Python seaborn 模块 set_palette() 实例源码

我们从 Python 开源项目中提取了以下 9 个代码示例，用于说明如何使用 seaborn.set_palette()。

项目:keras-utilities    作者:cbaziotis    | 项目源码 | 文件源码
def on_train_begin(self, logs=None):
        """Set up the live-plotting figure when training begins.

        Applies a "whitegrid" seaborn style with 0.5-wide grid/axes/plot
        lines and a flat-UI colour palette. Then switches pyplot to
        interactive mode and creates ``self.fig``, whose width is
        ``self.width * (1 + len(self.get_metrics(logs)))`` and whose height
        is ``self.height`` (inches). Finally moves the figure to screen
        position (25, 25) via ``move_figure``.

        Args:
            logs: Dict handed in by the training loop (presumably the Keras
                callback ``logs`` dict -- confirm); ``None`` is treated as
                an empty dict.
        """
        # Fresh dict per call instead of a shared mutable default argument.
        logs = {} if logs is None else logs

        sns.set_style("whitegrid")
        # sns.set_style() silently ignores rc keys that are not seaborn
        # *style* parameters, and line widths such as grid.linewidth /
        # lines.linewidth are not. Apply them to matplotlib directly so
        # the thin lines actually take effect.
        plt.rcParams.update({"grid.linewidth": 0.5,
                             "lines.linewidth": 0.5,
                             "axes.linewidth": 0.5})
        flatui = ["#9b59b6", "#3498db", "#95a5a6", "#e74c3c", "#34495e",
                  "#2ecc71"]
        sns.set_palette(sns.color_palette(flatui))

        plt.ion()  # interactive mode so the figure can be redrawn live
        width_factor = 1 + len(self.get_metrics(logs))
        self.fig = plt.figure(
            figsize=(self.width * width_factor,
                     self.height))  # width, height in inches

        # move it to the upper left corner
        move_figure(self.fig, 25, 25)
项目:openai_lab    作者:kengz    | 项目源码 | 文件源码
def scoped_mpl_import():
    """Import and configure matplotlib, pyplot and seaborn.

    The backend is set to ``MPL_BACKEND`` before pyplot is imported, the
    pyplot toolbar is disabled, and seaborn is configured with a
    "whitegrid" theme (colour codes on, font scale 1.0, 1.0-wide lines)
    plus the "Blues_d" palette in reversed order.

    Returns:
        tuple: ``(matplotlib, matplotlib.pyplot, seaborn)``.
    """
    import matplotlib
    # Choose the backend before pyplot gets imported below.
    matplotlib.rcParams['backend'] = MPL_BACKEND

    import matplotlib.pyplot as plt
    plt.rcParams['toolbar'] = 'None'  # hide the matplotlib toolbar

    import seaborn as sns
    theme_rc = {
        'lines.linewidth': 1.0,
        'backend': matplotlib.rcParams['backend'],
    }
    sns.set(style="whitegrid", color_codes=True, font_scale=1.0,
            rc=theme_rc)

    blues = sns.color_palette("Blues_d")
    blues.reverse()  # use the palette's colours in the opposite order
    sns.set_palette(blues)

    return matplotlib, plt, sns
项目:cohorts    作者:hammerlab    | 项目源码 | 文件源码
def set_styling():
    """Apply the project's global plotting style.

    Sets the seaborn "white" style and a palette of: a custom red, a custom
    blue, then the classic seaborn "deep" green, purple, yellow and cyan.
    Also sets a 6x6 default figure size, bold axis labels and larger fonts
    via ``mpl.rcParams``.
    """
    sb.set_style("white")
    red = colors.hex2color("#bb3f3f")
    blue = colors.hex2color("#5a86ad")
    # seaborn 0.9 reordered and extended the "deep" palette to 10 colours
    # (index 1 became orange), so indexing sb.color_palette("deep") no
    # longer yields green and deep[3:] would re-add a red. Pin the classic
    # 6-colour "deep" values instead. This gives identical colours on old
    # seaborn and the intended ones on new seaborn. The classic blue/red
    # are replaced by the custom shades above.
    green = colors.hex2color("#55A868")
    classic_deep_tail = [colors.hex2color(hex_code)
                         for hex_code in ("#8172B2", "#CCB974", "#64B5CD")]
    custom_palette = [red, blue, green] + classic_deep_tail
    sb.set_palette(custom_palette)
    mpl.rcParams.update({"figure.figsize": np.array([6, 6]),
                         "legend.fontsize": 12,
                         "font.size": 16,
                         "axes.labelsize": 16,
                         "axes.labelweight": "bold",
                         "xtick.labelsize": 16,
                         "ytick.labelsize": 16})
项目:extract    作者:dblalock    | 项目源码 | 文件源码
def makeDishwasherFig(ax=None, zNorm=True, save=True):
    """Plot the dishwasher time series with one dimension highlighted.

    Args:
        ax: Axes passed through to ``ts.plot`` (presumably ``None`` lets it
            create its own -- confirm).
        zNorm: Forwarded to ``getFig1Ts`` to control z-normalization of the
            loaded series.
        save: If true, save the current figure via
            ``saveFigWithName('dishwasher')``.
    """
    ts = getFig1Ts(zNorm=zNorm, whichTs=WHICH_DISHWASHER_TS)
    nDims = ts.data.shape[1]

    # Cycle the base palette so every dimension gets a colour, however many
    # dimensions there are. Then override the one dimension to emphasize.
    palette = DISHWASHER_COLOR_PALETTE
    colors = [palette[i % len(palette)] for i in range(nDims)]
    if DISHWASHER_DIM_TO_HIGHLIGHT < nDims:
        colors[DISHWASHER_DIM_TO_HIGHLIGHT] = DISHWASHER_HIGHLIGHT_COLOR

    ts.data[:, 2] /= 2  # scale the ugliest dim to make pic prettier
    # NOTE: per the original author, ts.plot resets the seaborn palette,
    # hence the explicit colours.
    ax = ts.plot(showLabels=False, showBounds=False, capYLim=900, ax=ax,
                 colors=colors)

    sb.despine(left=True)
    ax.set_title("Dishwasher", y=TITLE_Y_POS)
    plt.tight_layout()
    if save:
        saveFigWithName('dishwasher')

# ------------------------------------------------ MSRC
项目:temci    作者:parttimenerd    | 项目源码 | 文件源码
def reset_plt(self):
        """Reset the global matplotlib/seaborn plot style.

        Sets the current figure's bottom margin to 0.15 and restores seaborn's
        default rc settings. Then, if the ``report/xkcd_like_plots`` setting
        is truthy, selects the "agg" backend and enables ``plt.xkcd()``.
        Otherwise applies the "darkgrid" style with the "muted" palette and
        then selects the "agg" backend.
        """
        import matplotlib.pyplot as plt
        plt.gcf().subplots_adjust(bottom=0.15)
        xkcd_wanted = Settings()["report/xkcd_like_plots"]

        import seaborn as sns
        sns.reset_defaults()
        # NOTE(review): mpl.use() runs after pyplot is imported; depending on
        # the matplotlib version this may warn or switch backends (closing
        # open figures) -- confirm this is intended.
        if not xkcd_wanted:
            sns.set_style("darkgrid")
            sns.set_palette(sns.color_palette("muted"))
            mpl.use("agg")
            return
        mpl.use("agg")
        plt.xkcd()
项目:pygcam    作者:JGCRI    | 项目源码 | 文件源码
def setupPalette(count, pal=None):
    """Install a global seaborn palette sized for ``count`` series.

    Args:
        count: Number of colours, passed to seaborn as ``n_colors``.
        pal: Optional seaborn palette spec. When falsy, a fixed list of
            xkcd colour names is used instead.
    """
    if pal:
        palette = sns.color_palette(palette=pal, n_colors=count)
    else:
        # Names from http://xkcd.com/color/rgb/, picked to be visually
        # distinct from one another.
        xkcd_names = [
            'grass green', 'canary yellow', 'dirty pink', 'azure',
            'tangerine', 'strawberry', 'yellowish green', 'gold',
            'sea blue', 'lavender', 'orange brown', 'turquoise',
            'royal blue', 'cranberry', 'pea green', 'vermillion',
            'sandy yellow', 'greyish brown', 'magenta', 'silver',
            'ivory', 'carolina blue', 'very light brown',
        ]
        palette = sns.xkcd_palette(xkcd_names)

    sns.set_palette(palette, n_colors=count)


# For publications, call setupPlot("paper", font_scale=1.5)
项目:taut-sensoranalysis-python    作者:pjhartin    | 项目源码 | 文件源码
def plot_example(missed, acknowledged):
    """Plot two calibrated sensor recordings and save their statistics.

    Each recording is loaded with ``import_sensorfile``, passed through
    ``process_input`` and ``window_data``, and median-calibrated with
    ``mf.calibrate_median``. The calibrated signals are drawn in two stacked
    subplots of figure 0: "Recording 1" (missed) on top and "Recording 2"
    (acknowledged) below. Their statistics are written via
    ``write_to_csv(..., 'example_waves')`` and the figure is shown.

    Args:
        missed: Missed-reminder recording; passed to ``import_sensorfile``.
        acknowledged: Acknowledged-reminder recording; passed to
            ``import_sensorfile``.
    """
    sensor_miss = import_sensorfile(missed)
    sensor_ack = import_sensorfile(acknowledged)

    # Window data
    mag_miss = window_data(process_input(sensor_miss))
    mag_ack = window_data(process_input(sensor_ack))

    # Calibrate the raw (unfiltered) signals -- these are what get plotted
    mag_miss_cal = mf.calibrate_median(mag_miss)
    mag_ack_cal = mf.calibrate_median(mag_ack)

    # PLOT
    sns.set_style("white")
    sns.set_palette(sns.color_palette('muted'))

    plt.figure(0)

    # Top panel: missed reminder (raw, calibrated)
    plt.subplot2grid((2, 1), (0, 0))
    plt.ylim([-1.5, 1.5])
    plt.ylabel('Acceleration (g)')
    plt.plot(mag_miss_cal, label='Recording 1')
    plt.legend(loc='lower left')

    # Bottom panel: acknowledged reminder (raw, calibrated)
    plt.subplot2grid((2, 1), (1, 0))
    plt.ylim([-1.5, 1.5])
    plt.ylabel('Acceleration (g)')
    plt.xlabel('t (ms)')
    plt.plot(mag_ack_cal, linestyle='-', label='Recording 2')
    plt.legend(loc='lower left')

    # CALC AND SAVE STATS
    stats_one = sp.calc_stats_for_data_stream_as_dictionary(mag_miss_cal)
    stats_two = sp.calc_stats_for_data_stream_as_dictionary(mag_ack_cal)
    write_to_csv([stats_one, stats_two], 'example_waves')

    plt.show()
项目:sdp_kmeans    作者:simonsfoundation    | 项目源码 | 文件源码
def test_reconstruction(X, gt, n_clusters, filename, from_file=False):
    """Plot SDP k-means solutions and reconstruction error versus rank.

    Runs ``sdp_kmeans(X, n_clusters, method='cvx')``. For every rank ``k``
    in ``1 .. 200 + len(X)``, it records ``check_completely_positivity`` of
    50 independent ``symnmf_admm`` factorizations of the last matrix
    returned. These errors are saved to ``'{dir_name}{filename}.mat'``; with
    ``from_file=True`` they are loaded from that file instead of recomputed.

    Writes two figures to ``dir_name``:
    ``{filename}_solution.pdf`` shows the clustered data plus one panel per
    matrix returned by ``sdp_kmeans``. ``{filename}_curve.pdf`` shows the
    mean error +/- 2 std against rank on a log scale, with a dashed line at
    ``r = n_clusters``.
    """
    n_trials = 50  # factorizations per rank
    Ds = sdp_kmeans(X, n_clusters, method='cvx')
    mat_path = '{}{}.mat'.format(dir_name, filename)

    if from_file:
        data = scipy.io.loadmat(mat_path)
        rec_errors = data['rec_errors']
        k_values = data['k_values']
    else:
        k_values = np.arange(200 + len(X)) + 1
        rec_errors = []
        for k in k_values:
            print('{} / {}'.format(k, k_values[-1]))
            rec_errors_k = []
            for _ in range(n_trials):
                Y = symnmf_admm(Ds[-1], k=k)
                rec_errors_k.append(check_completely_positivity(Ds[-1], Y))
            rec_errors.append(rec_errors_k)
        rec_errors = np.array(rec_errors)
        scipy.io.savemat(mat_path,
                         dict(rec_errors=rec_errors,
                              k_values=k_values))

    sns.set_style('white')

    plt.figure(tight_layout=True)
    # One panel for the clustered data plus one per matrix in Ds, so the
    # layout never runs out of slots.
    gs = gridspec.GridSpec(1, len(Ds) + 1)

    ax = plt.subplot(gs[0])
    plot_data_clustered(X, gt, ax=ax)

    for i, D_input in enumerate(Ds):
        ax = plt.subplot(gs[i + 1])
        plot_matrix(D_input, ax=ax)
        if i == 0:
            ax.set_title('Original Gramian')
        else:
            ax.set_title('Layer {} (k={})'.format(i, n_clusters))
    plt.savefig('{}{}_solution.pdf'.format(dir_name, filename))

    plt.figure(tight_layout=True)
    # loadmat returns 2-D arrays, so flatten k for plotting.
    k_axis = np.squeeze(k_values)
    mean = np.mean(rec_errors, axis=1)
    std = np.std(rec_errors, axis=1)
    sns.set_palette('muted')
    plt.fill_between(k_axis, mean - 2 * std, mean + 2 * std, alpha=0.3)
    plt.semilogy(k_axis, mean, linewidth=2)
    plt.semilogy([n_clusters, n_clusters], [mean.min(), mean.max()],
                 linestyle='--', linewidth=2)
    plt.xlabel('$r$', size='xx-large')
    plt.ylabel('Relative reconstruction error', size='xx-large')
    plt.ylim(np.floor(rec_errors.min() * 1e3) / 1e3, 1)
    plt.savefig('{}{}_curve.pdf'.format(dir_name, filename))
项目:saw_release    作者:kovibalu    | 项目源码 | 文件源码
def plot_2D_arrays(arrs, title='', xlabel='', xinterval=None, ylabel='', yinterval=None, line_names=None, simplified=False):
    """ Plots multiple arrays in the same plot based on the specifications.

    Args:
        arrs: Iterable of arrays of shape ``(n, 2)``; column 0 is x and
            column 1 is y. The last array is drawn in black.
        title: Plot title, truncated to 30 characters.
        xlabel: X-axis label.
        xinterval: Optional x limits passed to ``plt.xlim`` when truthy.
        ylabel: Y-axis label.
        yinterval: Optional y limits passed to ``plt.ylim`` when truthy.
        line_names: Optional legend entries, placed to the right of the axes.
        simplified: If true, omit title, axis labels and legend.

    Raises:
        ValueError: If any array is not 2-D with exactly two columns. This
            is checked before any global matplotlib state is touched.
    """
    arrs = list(arrs)
    # Validate up front so a bad input neither switches the backend nor
    # leaves a cleared, half-drawn figure behind.
    for arr in arrs:
        if arr.ndim != 2 or arr.shape[1] != 2:
            raise ValueError(
                'The array should be 2D and the second dimension should be 2!'
                ' Shape: %s' % str(arr.shape)
            )

    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt
    import seaborn as sns

    plt.clf()
    # sns.set() applies the "darkgrid" style itself, so a separate
    # set_style('darkgrid') beforehand would be overridden anyway.
    sns.set(style='darkgrid', font_scale=1.5)
    sns.set_palette('husl', 8)

    last_index = len(arrs) - 1
    for i, arr in enumerate(arrs):
        # Plot last one with black
        if i == last_index:
            plt.plot(arr[:, 0], arr[:, 1], color='black')
        else:
            plt.plot(arr[:, 0], arr[:, 1])

    # If simplified, we don't show text anywhere
    if not simplified:
        plt.title(title[:30])
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)
        if line_names:
            plt.legend(line_names, loc='center left', bbox_to_anchor=(1, 0.5))

    if xinterval:
        plt.xlim(xinterval)
    if yinterval:
        plt.ylim(yinterval)

    plt.tight_layout()


###############
# String handling
###############