Python seaborn 模块，color_palette() 实例源码

我们从Python开源项目中，提取了以下50个代码示例，用于说明如何使用seaborn.color_palette()。

项目:Flavor-Network    作者:lingcheng99    | 项目源码 | 文件源码
def tsne_cluster_cuisine(df,sublist):
    """Scatter-plot a 2-D t-SNE embedding of recipes, one colour per cuisine.

    Args:
        df: DataFrame with 'cuisine' and 'recipeName' columns plus feature
            columns (every other column is used as a feature).
        sublist: cuisines to include, in plotting order.
    """
    # Collect each cuisine's rows in order; ``lenlist`` holds cumulative row
    # offsets so the embedding can later be sliced per cuisine. A single
    # concat avoids re-copying the growing frame once per cuisine.
    frames = [df[df['cuisine'] == cuisine] for cuisine in sublist]
    lenlist = [0]
    for frame in frames:
        lenlist.append(lenlist[-1] + frame.shape[0])
    df_sub = pd.concat(frames, axis=0, ignore_index=True)
    df_X = df_sub.drop(['cuisine', 'recipeName'], axis=1)
    print(df_X.shape, lenlist)

    dist = squareform(pdist(df_X, metric='cosine'))
    # init='random' (the old default): scikit-learn >= 1.2 defaults to 'pca',
    # which it refuses to combine with metric='precomputed'.
    tsne = TSNE(metric='precomputed', init='random').fit_transform(dist)

    palette = sns.color_palette("hls", len(sublist))
    plt.figure(figsize=(10, 10))
    for i, cuisine in enumerate(sublist):
        start, stop = lenlist[i], lenlist[i + 1]
        # color= rather than c=: a bare RGB tuple given as ``c`` is read as
        # per-point values when a cuisine happens to have 3 or 4 recipes.
        plt.scatter(tsne[start:stop, 0], tsne[start:stop, 1],
                    color=palette[i], label=cuisine)
    plt.legend()

# Interactive plot with Bokeh; set up for four categories with a color palette; pass in the df for either ingredients or flavors
项目:PorousMediaLab    作者:biogeochemistry    | 项目源码 | 文件源码
def saturation_index_countour(lab, elem1, elem2, Ks, labels=False):
    """Filled contour plot of a saturation index over time and depth.

    The plotted value is
    ``log10((c1 + 1e-8) * (c2 + 1e-8) / lab.constants[Ks])``, where c1 and c2
    are the concentrations of ``elem1`` and ``elem2``.

    Args:
        lab: model object providing ``time``, ``x``, ``species`` and
            ``constants``.
        elem1, elem2: species whose concentration product is used.
        Ks: key of the constant in ``lab.constants``.
        labels: if True, draw inline contour labels.

    Returns:
        The matplotlib Axes holding the contour plot.
    """
    plt.figure()
    plt.title('Saturation index %s%s' % (elem1, elem2))
    resolution = 100
    # Keep about ``resolution`` time columns. int() keeps the slice step an
    # integer on Python 2, where math.ceil returns a float.
    n = int(math.ceil(lab.time.size / float(resolution)))
    c1 = lab.species[elem1]['concentration'][:, ::n]
    c2 = lab.species[elem2]['concentration'][:, ::n]
    # The 1e-8 offsets avoid log10(0) for vanishing concentrations.
    z = np.log10((c1 + 1e-8) * (c2 + 1e-8) / lab.constants[Ks])
    # Levels symmetric about 0 so negative and positive index values fall on
    # opposite halves of the diverging colormap.
    lim = np.max(np.abs(z))
    levels = np.linspace(-lim - 0.1, lim + 0.1, 51)
    X, Y = np.meshgrid(lab.time[::n], -lab.x)
    plt.xlabel('Time')
    # ``levels`` fixes the contour levels; the former positional 20 was
    # ignored by contourf in its favour and is omitted.
    CS = plt.contourf(X, Y, z, levels=levels, extend='both',
                      cmap=ListedColormap(sns.color_palette("RdBu_r", 101)),
                      origin='lower')
    # Label the contours once (they used to be labelled twice).
    if labels:
        plt.clabel(CS, inline=1, fontsize=10, colors='w')
    cbar = plt.colorbar(CS)
    plt.ylabel('Depth')
    ax = plt.gca()
    ax.ticklabel_format(useOffset=False)
    cbar.ax.set_ylabel('Saturation index %s%s' % (elem1, elem2))
    return ax
项目:PorousMediaLab    作者:biogeochemistry    | 项目源码 | 文件源码
def contour_plot_of_rates(lab, r, labels=False, last_year=False):
    """Filled contour plot of one estimated rate over time and depth.

    Args:
        lab: model object providing ``time``, ``x``, ``dt`` and
            ``estimated_rates``.
        r: key into ``lab.estimated_rates``; also used as the title.
        labels: if True, draw inline contour labels.
        last_year: if True, the window starts ``int(1 / lab.dt)`` samples
            earlier than the subsampling step; otherwise it starts at 1.

    Returns:
        The matplotlib Axes holding the contour plot.
    """
    plt.figure()
    plt.title('{}'.format(r))
    step = math.ceil(lab.time.size / 100)  # roughly 100 time columns
    start = step - int(1 / lab.dt) if last_year else 1
    rates = lab.estimated_rates[r][:, start - 1:-1:step]
    time_grid, depth_grid = np.meshgrid(lab.time[start::step], -lab.x)
    plt.xlabel('Time')
    blues = ListedColormap(sns.color_palette("Blues", 51))
    contours = plt.contourf(time_grid, depth_grid, rates, 20, cmap=blues)
    if labels:
        plt.clabel(contours, inline=1, fontsize=10, colors='w')
    colorbar = plt.colorbar(contours)
    plt.ylabel('Depth')
    axes = plt.gca()
    axes.ticklabel_format(useOffset=False)
    colorbar.ax.set_ylabel('Rate %s [M/V/T]' % r)
    return axes
项目:PorousMediaLab    作者:biogeochemistry    | 项目源码 | 文件源码
def contour_plot_of_delta(lab, element, labels=False, last_year=False):
    """Filled contour plot of a species' rate of change over time and depth.

    Args:
        lab: model object providing ``time``, ``x``, ``dt`` and ``species``.
        element: species whose ``'rates'`` array is plotted.
        labels: if True, draw inline contour labels.
        last_year: if True, the window starts ``int(1 / lab.dt)`` samples
            earlier than the subsampling step; otherwise it starts at 1.

    Returns:
        The matplotlib Axes holding the contour plot.
    """
    plt.figure()
    plt.title('Rate of %s consumption/production' % element)
    resolution = 100
    # int() keeps the slice step integral on Python 2 as well.
    n = int(math.ceil(lab.time.size / float(resolution)))
    if last_year:
        k = n - int(1 / lab.dt)
    else:
        k = 1
    z = lab.species[element]['rates'][:, k - 1:-1:n]
    # Symmetric levels so positive and negative rates fall on opposite
    # halves of the diverging colormap.
    lim = np.max(np.abs(z))
    levels = np.linspace(-lim - 0.1, lim + 0.1, 51)
    # time[k::n] has exactly as many samples as z's columns k-1:-1:n; the
    # former time[k:-1:n] could be one short (always so when n == 1), and
    # contourf then rejected the mismatched shapes.
    X, Y = np.meshgrid(lab.time[k::n], -lab.x)
    plt.xlabel('Time')
    CS = plt.contourf(X, Y, z, levels=levels, extend='both',
                      cmap=ListedColormap(sns.color_palette("RdBu_r", 101)),
                      origin='lower')
    if labels:
        plt.clabel(CS, inline=1, fontsize=10, colors='w')
    cbar = plt.colorbar(CS)
    plt.ylabel('Depth')
    ax = plt.gca()
    ax.ticklabel_format(useOffset=False)
    # Raw string: '\D' is an invalid escape sequence in a plain literal.
    cbar.ax.set_ylabel(r'Rate of %s change $[\Delta/T]$' % element)
    return ax
项目:coquery    作者:gkunter    | 项目源码 | 文件源码
def set_defaults(self):
        """Fill in the default palette, axis labels and legend title.

        Raises:
            VisualizationInvalidDataError: if there are no numeric columns.
        """
        opts = self.options
        opts["color_palette"] = "Paired"
        opts["color_number"] = len(self._levels[-1]) if self._levels else 1

        numeric = self._number_columns
        n_numeric = len(numeric)
        if n_numeric == 0:
            raise VisualizationInvalidDataError

        # Two or more numeric columns: the second-to-last labels the y axis;
        # a single column gets a density-style label instead.
        if n_numeric > 1:
            opts["label_y_axis"] = numeric[-2]
        elif self.cumulative:
            opts["label_y_axis"] = "Cumulative probability"
        else:
            opts["label_y_axis"] = "Density"
        opts["label_x_axis"] = numeric[-1]

        if len(self._groupby) == 1:
            opts["label_legend"] = self._groupby[-1]

        super(Visualizer, self).set_defaults()
项目:coquery    作者:gkunter    | 项目源码 | 文件源码
def set_defaults(self):
        """Fill in the default palette, axis labels and legend title.

        With one numeric column the x axis is labelled ``self._default``;
        otherwise the last two numeric columns label the x and y axes.

        Raises:
            VisualizationInvalidDataError: if there are no numeric columns.
        """
        opts = self.options
        opts["color_palette"] = "Paired"
        opts["color_number"] = len(self._levels[-1]) if self._levels else 1

        numeric = self._number_columns
        n_numeric = len(numeric)
        if n_numeric == 0:
            raise VisualizationInvalidDataError

        opts["label_x_axis"] = self._default if n_numeric == 1 else numeric[-2]
        opts["label_y_axis"] = numeric[-1]

        if len(self._groupby) == 1:
            opts["label_legend"] = self._groupby[-1]

        super(Visualizer, self).set_defaults()
项目:coquery    作者:gkunter    | 项目源码 | 文件源码
def accept(self):
        """Copy the dialog's widget values into ``self.options``.

        Stores labels, legend column count, transparency (only if the
        slider widget exists), palette name and values, and sample fonts.
        The palette values are repeated until there are at least
        ``color_number`` (default 6) colours. Afterwards the base-class
        ``accept`` runs and the dialog size is saved in the settings.
        """
        ui = self.ui
        for key in ("label_main", "label_x_axis", "label_y_axis", "label_legend"):
            self.options[key] = str(getattr(ui, key).text())
        self.options["label_legend_columns"] = int(ui.spin_columns.value())

        try:
            transparency = float(ui.slide_transparency.value())
        except AttributeError:
            pass  # this dialog variant has no transparency slider
        else:
            self.options["color_transparency"] = transparency

        self.options["color_palette"] = self.palette_name
        values = self.get_current_palette()
        wanted = self.options.get("color_number", 6)
        if len(values) < wanted:
            # Cycle the palette to provide enough colours.
            values = (values * wanted)[:wanted]
        self.options["color_palette_values"] = values

        for part in ("main", "x_axis", "x_ticks", "y_axis", "y_ticks", "legend", "legend_entries"):
            sample = getattr(ui, "label_sample_{}".format(part))
            self.options["font_{}".format(part)] = sample.font()

        super(FigureOptions, self).accept()
        options.settings.setValue("figureoptions_size", self.size())
项目:augur    作者:nextstrain    | 项目源码 | 文件源码
def plot_all(self):
        """Plot global clade frequencies and each training-interval estimate.

        Global trajectories whose maximum exceeds 0.2 are drawn dashed;
        per-interval trajectories from ``self.train_frequencies`` exceeding
        0.2 are drawn solid. Colours come from a 6-colour seaborn palette
        indexed by ``clade % 6`` (assumes integer clade ids — confirm).
        """
        from matplotlib import pyplot as plt
        import seaborn as sns
        cols = sns.color_palette(n_colors=6)
        pivots = self.global_pivots
        plt.figure(figsize=(20, 7))
        ax = plt.subplot(111)
        for clade, f in self.global_freqs.items():
            if np.max(f) > 0.2:
                ax.plot(pivots, f, c=cols[clade % len(cols)], alpha=0.3, ls='--')

        # .items() works on Python 2 and 3; iteritems() is Python 2 only.
        for _tint, (p, freq) in self.train_frequencies.items():
            for clade, clade_freq in freq.items():
                if np.max(clade_freq) > 0.2:
                    ax.plot(p, clade_freq, c=cols[clade % len(cols)], alpha=0.5)
项目:openai_lab    作者:kengz    | 项目源码 | 文件源码
def scoped_mpl_import():
    """Import matplotlib, pyplot and seaborn configured for this project.

    The backend is set to ``MPL_BACKEND`` before pyplot is imported, the
    toolbar is disabled, and seaborn gets a whitegrid style with the
    "Blues_d" palette in reversed order.

    Returns:
        tuple: ``(matplotlib, plt, sns)``.
    """
    import matplotlib
    matplotlib.rcParams['backend'] = MPL_BACKEND

    import matplotlib.pyplot as plt
    plt.rcParams['toolbar'] = 'None'  # mute matplotlib toolbar

    import seaborn as sns
    overrides = {
        'lines.linewidth': 1.0,
        'backend': matplotlib.rcParams['backend'],
    }
    sns.set(style="whitegrid", color_codes=True, font_scale=1.0, rc=overrides)
    blues = sns.color_palette("Blues_d")
    blues.reverse()
    sns.set_palette(blues)

    return matplotlib, plt, sns
项目:hypertools    作者:ContextLab    | 项目源码 | 文件源码
def vals2colors(vals,cmap='GnBu_d',res=100):
    """Map numeric values to RGB colours from a seaborn palette.

    Args:
        vals (list or list of lists): values to map; if any element is a
            list, the elements are chained into one flat list first.
        cmap (str): seaborn palette name (default: 'GnBu_d').
        res (int): number of palette colours and value bins (default: 100).

    Returns:
        list of RGB tuples, one per (flattened) value.
    """
    # flatten if list of lists
    if any(isinstance(el, list) for el in vals):
        vals = list(itertools.chain(*vals))

    # get palette from seaborn
    palette = np.array(sns.color_palette(cmap, res))
    # ``res`` equal-width bins over [min, max + 1]: the +1 keeps max(vals)
    # below the last edge, so every rank lies in 0..res-1.
    edges = np.linspace(np.min(vals), np.max(vals) + 1, res + 1)
    ranks = np.digitize(vals, edges) - 1
    return [tuple(i) for i in palette[ranks, :]]
项目:qtim_ROP    作者:QTIM-Lab    | 项目源码 | 文件源码
def plot_heatmaps(img_arr, img_names, titles, heatmaps, labels, out_dir):
    """Save every image with its heatmap overlaid, grouped by class.

    Args:
        img_arr: images; each is transposed with axes (1, 2, 0) for display,
            i.e. presumably channel-first (C, H, W) — confirm with callers.
        img_names: output file name for each image.
        titles: plot title for each image.
        heatmaps: heatmap drawn over each image.
        labels: class index of each image, looked up in ``CLASSES``.
        out_dir: root directory; each class gets its own sub-directory.
    """
    # construct cmap: diverging blue/red palette with a dark centre
    pal = sns.diverging_palette(240, 10, n=30, center="dark")
    my_cmap = ListedColormap(sns.color_palette(pal).as_hex())

    for img, img_name, h_map, title, y in zip(img_arr, img_names, heatmaps, titles, labels):
        fig = plt.figure()
        img = np.transpose(img, (1, 2, 0))
        plt.imshow(img, cmap='Greys', interpolation='bicubic')
        plt.imshow(h_map, cmap=my_cmap, alpha=0.7, interpolation='nearest')
        plt.colorbar()
        plt.axis('off')
        plt.title(title)
        class_dir = make_sub_dir(out_dir, CLASSES[y])
        plt.savefig(join(class_dir, img_name), bbox_inches='tight', dpi=300)
        # Close each figure; previously one open figure accumulated per image.
        plt.close(fig)
项目:MLAlgorithms    作者:rushter    | 项目源码 | 文件源码
def plot(self, ax=None, holdon=False):
        """Scatter each cluster's points and mark the centroids.

        Args:
            ax: matplotlib Axes to draw on; a new one is created if None.
            holdon: if False, call ``plt.show()`` when finished.
        """
        sns.set(style="white")

        data = self.X

        if ax is None:
            _, ax = plt.subplots()

        # Build the palette once instead of once per cluster.
        palette = sns.color_palette("hls", self.K + 1)
        for i, index in enumerate(self.clusters):
            point = np.array(data[index]).T
            # color= (not c=) so a single RGB tuple is never taken as
            # per-point values when a cluster has 3 or 4 members.
            ax.scatter(*point, color=palette[i])

        for point in self.centroids:
            ax.scatter(*point, marker='x', linewidths=10)

        if not holdon:
            plt.show()
项目:unrolled-gan    作者:musyoku    | 项目源码 | 文件源码
def plot_kde(data, dir=None, filename="kde", color="Greens"):
    """Save a filled 2-D KDE plot of ``data`` as ``<dir>/<filename>.png``.

    Args:
        data: 2-D array; columns 0 and 1 are the x and y samples.
        dir: output directory (created if missing). Required.
        filename: file name without extension.
        color: colormap/palette name; its first colour is also used as the
            axes background.

    Raises:
        ValueError: if ``dir`` is not given.
    """
    if dir is None:
        raise ValueError("plot_kde: an output directory `dir` is required")
    try:
        os.mkdir(dir)
    except OSError:
        pass  # best effort: usually the directory already exists
    fig = pylab.gcf()
    fig.set_size_inches(16.0, 16.0)
    pylab.clf()
    bg_color = sns.color_palette(color, n_colors=256)[0]
    ax = sns.kdeplot(data[:, 0], data[:, 1], shade=True, cmap=color,
                     n_levels=30, clip=[[-4, 4]] * 2)
    # set_axis_bgcolor was removed in matplotlib 2.2; set_facecolor replaces it.
    ax.set_facecolor(bg_color)
    kde = ax.get_figure()
    pylab.xlim(-4, 4)
    pylab.ylim(-4, 4)
    kde.savefig("{}/{}.png".format(dir, filename))
项目:unrolled-gan    作者:musyoku    | 项目源码 | 文件源码
def plot_kde(data, dir=None, filename="kde", color="Greens"):
    """Save a filled 2-D KDE plot of ``data`` as ``<dir>/<filename>``.

    Args:
        data: 2-D array; columns 0 and 1 are the x and y samples.
        dir: output directory (created if missing). Required.
        filename: file name, used as given (include the extension).
        color: colormap/palette name; its first colour is also used as the
            axes background.

    Raises:
        ValueError: if ``dir`` is not given.
    """
    if dir is None:
        raise ValueError("plot_kde: an output directory `dir` is required")
    try:
        os.mkdir(dir)
    except OSError:
        pass  # best effort: usually the directory already exists
    fig = pylab.gcf()
    fig.set_size_inches(16.0, 16.0)
    pylab.clf()
    bg_color = sns.color_palette(color, n_colors=256)[0]
    ax = sns.kdeplot(data[:, 0], data[:, 1], shade=True, cmap=color,
                     n_levels=30, clip=[[-4, 4]] * 2)
    # set_axis_bgcolor was removed in matplotlib 2.2; set_facecolor replaces it.
    ax.set_facecolor(bg_color)
    kde = ax.get_figure()
    pylab.xlim(-4, 4)
    pylab.ylim(-4, 4)
    kde.savefig("{}/{}".format(dir, filename))
项目:Iris-Classification-with-Heroku    作者:gaborvecsei    | 项目源码 | 文件源码
def plotPrediction(pred):
    """
    Plot the three class scores as a bar chart and encode it to base64.

    :param pred: one score per class (setosa, versicolor, virginica),
                 drawn on a 0-1 y axis
    :return: base64-encoded JPEG image (``bytes`` on Python 3)
    """
    import io

    import matplotlib.pyplot as plt

    labels = ['setosa', 'versicolor', 'virginica']
    # A fresh, explicitly sized figure per call. sns.set_context(rc=...)
    # silently ignores "figure.figsize" (not a context key), and reusing the
    # current figure stacked bars from earlier calls.
    fig = plt.figure(figsize=(5, 5))
    with sns.color_palette("RdBu_r", 3):
        ax = sns.barplot(x=labels, y=pred)
    ax.set(ylim=(0, 1))

    # Base64 encode the plot. ``sns.plt`` no longer exists (seaborn >= 0.8)
    # and cStringIO is Python 2 only; io.BytesIO works on both.
    buffer = io.BytesIO()
    fig.savefig(buffer, format='jpg')
    plt.show()
    plt.close(fig)
    buffer.seek(0)
    return base64.b64encode(buffer.read())
项目:pymoku    作者:liquidinstruments    | 项目源码 | 文件源码
def phase1_plot_setup():
    """Create the 1x2 "Phase 1 - Rise Times" figure and maximise its window.

    Also switches seaborn to the 'poster' plotting context.

    Returns:
        tuple: ``(figure, left_axes, right_axes)``.
    """
    # Set up a 1x2 plot
    f, (ax1, ax2) = plt.subplots(1, 2)
    f.suptitle('Phase 1 - Rise Times', fontsize=18, fontweight='bold')

    # Font size/style
    sns.set_context('poster')

    # Maximise the plotting window. Compare lower-cased names: matplotlib
    # reports the wx backend as 'WXAgg', which the old 'wxAgg' test never
    # matched; Qt5/Qt backends expose the same showMaximized() window.
    plot_backend = matplotlib.get_backend().lower()
    mng = plt.get_current_fig_manager()
    if plot_backend == 'tkagg':
        mng.resize(*mng.window.maxsize())
    elif plot_backend == 'wxagg':
        mng.frame.Maximize(True)
    elif plot_backend in ('qt4agg', 'qt5agg', 'qtagg'):
        mng.window.showMaximized()

    return f, ax1, ax2
项目:pymoku    作者:liquidinstruments    | 项目源码 | 文件源码
def phase2_plot_setup():
    """Create the 1x1 "Phase 2 - Line Width" figure and maximise its window.

    Also switches seaborn to the 'poster' plotting context.

    Returns:
        tuple: ``(figure, axes)``.
    """
    # Set up a 1x1 plot
    f, ax1 = plt.subplots(1, 1)
    f.suptitle('Phase 2 - Line Width', fontsize=18, fontweight='bold')

    # Font size/style
    sns.set_context('poster')

    # Maximise the plotting window. Compare lower-cased names: matplotlib
    # reports the wx backend as 'WXAgg', which the old 'wxAgg' test never
    # matched; Qt5/Qt backends expose the same showMaximized() window.
    plot_backend = matplotlib.get_backend().lower()
    mng = plt.get_current_fig_manager()
    if plot_backend == 'tkagg':
        mng.resize(*mng.window.maxsize())
    elif plot_backend == 'wxagg':
        mng.frame.Maximize(True)
    elif plot_backend in ('qt4agg', 'qt5agg', 'qtagg'):
        mng.window.showMaximized()

    return f, ax1
项目:flexCE    作者:bretthandrews    | 项目源码 | 文件源码
def get_colors(cfg):
    """Return plot colours from the config, falling back to seaborn.

    Args:
        cfg (dict): config settings; colours are read from
            ``cfg['Plot']['colors']``.

    Returns:
        list: the configured colours (a non-list value is wrapped in a
        one-element list), or ``sns.color_palette('bright')`` when either
        key is missing.
    """
    try:
        configured = cfg['Plot']['colors']
    except KeyError:
        return sns.color_palette('bright')
    if isinstance(configured, list):
        return configured
    return [configured]
项目:VASC    作者:wang-research    | 项目源码 | 文件源码
def print_heatmap( points,label,id_map ):
    '''
    Draw a clustered heatmap of the samples with a colour strip per label.

    points: N_samples * N_features
    label: (int) N_samples
    id_map: map label id to its name

    Returns the matplotlib Figure of the clustermap.
    '''
    # One row colour per label id. ``current_palette`` used to be read
    # without being defined here (its definition was commented out), which
    # raises NameError unless some module-level global happens to exist.
    # Build it locally, sized so every label id has a colour.
    current_palette = sns.color_palette("RdBu_r", max(label) + 1)

    index = [id_map[i] for i in label]
    df = DataFrame(
            points,
            columns = list(range(points.shape[1])),
            index = index
            )
    row_color = [current_palette[i] for i in label]

    cmap = sns.cubehelix_palette(as_cmap=True, rot=-.3, light=1)
    g = sns.clustermap(df, cmap=cmap, row_colors=row_color, col_cluster=False,
                       xticklabels=False, yticklabels=False)

    return g.fig
项目:LSGAN    作者:musyoku    | 项目源码 | 文件源码
def plot_kde(data, dir=None, filename="kde", color="Greens"):
    """Save a filled 2-D KDE plot of ``data`` as ``<dir>/<filename>.png``.

    Args:
        data: 2-D array; columns 0 and 1 are the x and y samples.
        dir: output directory (created if missing). Required.
        filename: file name without extension.
        color: colormap/palette name; its first colour is also used as the
            axes background.

    Raises:
        ValueError: if ``dir`` is not given.
    """
    if dir is None:
        raise ValueError("plot_kde: an output directory `dir` is required")
    try:
        os.mkdir(dir)
    except OSError:
        pass  # best effort: usually the directory already exists
    fig = pylab.gcf()
    fig.set_size_inches(16.0, 16.0)
    pylab.clf()
    bg_color = sns.color_palette(color, n_colors=256)[0]
    ax = sns.kdeplot(data[:, 0], data[:, 1], shade=True, cmap=color,
                     n_levels=30, clip=[[-4, 4]] * 2)
    # set_axis_bgcolor was removed in matplotlib 2.2; set_facecolor replaces it.
    ax.set_facecolor(bg_color)
    kde = ax.get_figure()
    pylab.xlim(-4, 4)
    pylab.ylim(-4, 4)
    kde.savefig("{}/{}.png".format(dir, filename))
项目:LSGAN    作者:musyoku    | 项目源码 | 文件源码
def plot_kde(data, dir=None, filename="kde", color="Greens"):
    """Save a filled 2-D KDE plot of ``data`` as ``<dir>/<filename>``.

    Args:
        data: 2-D array; columns 0 and 1 are the x and y samples.
        dir: output directory (created if missing). Required.
        filename: file name, used as given (include the extension).
        color: colormap/palette name; its first colour is also used as the
            axes background.

    Raises:
        ValueError: if ``dir`` is not given.
    """
    if dir is None:
        raise ValueError("plot_kde: an output directory `dir` is required")
    try:
        os.mkdir(dir)
    except OSError:
        pass  # best effort: usually the directory already exists
    fig = pylab.gcf()
    fig.set_size_inches(16.0, 16.0)
    pylab.clf()
    bg_color = sns.color_palette(color, n_colors=256)[0]
    ax = sns.kdeplot(data[:, 0], data[:, 1], shade=True, cmap=color,
                     n_levels=30, clip=[[-4, 4]] * 2)
    # set_axis_bgcolor was removed in matplotlib 2.2; set_facecolor replaces it.
    ax.set_facecolor(bg_color)
    kde = ax.get_figure()
    pylab.xlim(-4, 4)
    pylab.ylim(-4, 4)
    kde.savefig("{}/{}".format(dir, filename))
项目:cohorts    作者:hammerlab    | 项目源码 | 文件源码
def set_styling():
    """Apply the project's seaborn and matplotlib styling.

    Uses a white seaborn style and a palette made of a custom red, a custom
    blue, then entry 1 and entries 3 onward of seaborn's "deep" palette.
    Also sets a 6x6 default figure size and larger, bold-label fonts.
    """
    sb.set_style("white")
    deep = sb.color_palette("deep")
    palette = [
        colors.hex2color("#bb3f3f"),  # custom red
        colors.hex2color("#5a86ad"),  # custom blue
        deep[1],
    ]
    palette += deep[3:]
    sb.set_palette(palette)
    mpl.rcParams.update({
        "figure.figsize": np.array([6, 6]),
        "legend.fontsize": 12,
        "font.size": 16,
        "axes.labelsize": 16,
        "axes.labelweight": "bold",
        "xtick.labelsize": 16,
        "ytick.labelsize": 16,
    })
项目:Default-Credit-Card-Prediction    作者:AlexPnt    | 项目源码 | 文件源码
def visualize_pca2D(X,y):
    """
    Visualize the first two principal components

    Keyword arguments:
    X -- The feature vectors
    y -- The target vector (0 is plotted as "Paid", 1 as "Default")

    The figure is written via ``save_fig`` to img/pca2D.png.
    """
    pca = PCA(n_components = 2)
    principal_components = pca.fit_transform(X)

    palette = sea.color_palette()
    point_style = dict(alpha=0.5, edgecolor='#262626', linewidth=0.15)
    paid = principal_components[y == 0]
    defaulted = principal_components[y == 1]
    # Only facecolor is given: matplotlib lets an explicit facecolor and
    # edgecolor override ``color``, so the former color='green'/'red' kwargs
    # never had any effect.
    plt.scatter(paid[:, 0], paid[:, 1], marker='s', label="Paid",
                facecolor=palette[1], **point_style)
    plt.scatter(defaulted[:, 0], defaulted[:, 1], marker='^', label="Default",
                facecolor=palette[2], **point_style)

    leg = plt.legend(loc='upper right', fancybox=True)
    leg.get_frame().set_alpha(0.5)
    plt.title("Two-Dimensional Principal Component Analysis")
    plt.tight_layout()  # was referenced without the call, i.e. a no-op

    #save fig
    output_dir='img'
    save_fig(output_dir,'{}/pca2D.png'.format(output_dir))
项目:Default-Credit-Card-Prediction    作者:AlexPnt    | 项目源码 | 文件源码
def visualize_pca3D(X,y):
    """
    Visualize the first three principal components in a 3-D scatter plot

    Keyword arguments:
    X -- The feature vectors
    y -- The target vector (0 is plotted as "Paid", 1 as "Default")
    """
    pca = PCA(n_components = 3)
    principal_components = pca.fit_transform(X)

    fig = pylab.figure()
    # Axes3D(fig) no longer attaches itself to the figure in matplotlib
    # >= 3.6, leaving the plot empty; the '3d' projection of add_subplot
    # works across versions once Axes3D (mpl_toolkits.mplot3d) is imported.
    ax = fig.add_subplot(111, projection='3d')

    palette = sea.color_palette()
    paid = principal_components[y == 0]
    defaulted = principal_components[y == 1]
    # color= (not c=) so one RGB tuple is never read as per-point values
    # when a group has exactly 3 or 4 points.
    ax.scatter(paid[:, 0], paid[:, 1], paid[:, 2], label="Paid", alpha=0.5,
               edgecolor='#262626', color=palette[1], linewidth=0.15)
    ax.scatter(defaulted[:, 0], defaulted[:, 1], defaulted[:, 2], label="Default",
               alpha=0.5, edgecolor='#262626', color=palette[2], linewidth=0.15)

    ax.legend()
    plt.show()
项目:extract    作者:dblalock    | 项目源码 | 文件源码
def makeGarbageDimTs():
    """Show a seeded synthetic random-walk series as a thick red line.

    The x range is padded by ``750 / 17`` on each side, the y limits are the
    series minimum and maximum times 2, and the left spine is removed.
    """
    np.random.seed(123)  # reproducible series
    n_points = 750
    pad = n_points / 17.
    series = synth.notSoRandomWalk(n_points, std=.05,
                                   trendFilterLength=(n_points // 2),
                                   lpfLength=2)

    sb.set_style('white')
    _, ax = plt.subplots()
    ax.plot(series, lw=4, color="#CC0000")  # red I'm using in keynote
    ax.set_xlim([-pad, n_points + pad])
    ax.set_ylim([np.min(series) * 2, np.max(series) * 2])

    sb.despine(left=True)
    plt.show()

# def makeMethodsWarpedTs():


# ================================================================ Better Fig1
项目:crop-seq    作者:epigen    | 项目源码 | 文件源码
def get_level_colors(index):
    """Map the values of a pandas (Multi)Index to "colorblind" palette colours.

    Args:
        index: a pandas Index or MultiIndex.

    Returns:
        list of lists: one colour list per level for a MultiIndex (one colour
        per entry), or a single such list for a flat index. Distinct values
        cycle through seaborn's "colorblind" palette.
    """
    import itertools

    palette = sns.color_palette("colorblind")

    def _colors_for(values, uniques):
        # Pair each distinct value with a colour. itertools.cycle repeats
        # the palette lazily; the former ``palette * int(1e6)`` built a
        # list of millions of tuples on every call.
        color_dict = dict(zip(uniques, itertools.cycle(palette)))
        return [color_dict[x] for x in values]

    colors = list()

    if hasattr(index, "levels"):
        for position, level in enumerate(index.levels):
            # Select by position: ``level.name`` may be None or duplicated.
            colors.append(_colors_for(index.get_level_values(position), level))
    else:
        # Distinct values in first-appearance order, so the value->colour
        # mapping is reproducible (set order varies between runs for str).
        colors.append(_colors_for(index, list(dict.fromkeys(index))))

    return colors
项目:IgDiscover    作者:NBISweden    | 项目源码 | 文件源码
def plot_clustermap(sequences, title, plotpath, size=300, dpi=200):
    """
    Plot a clustermap of the given sequences and save it to ``plotpath``.

    size -- Downsample to this many sequences
    title -- plot title (no title when None)
    dpi -- resolution of the saved image

    Return the number of distinct clusters.
    """
    logger.info('Clustering %d sequences (downsampled to at most %d)', len(sequences), size)
    sampled = downsampled(sequences, size)
    df, linkage, clusters = cluster_sequences(sampled)

    # Palette entry 0 is dark grey; cluster id i uses entry i.
    palette = sns.color_palette([(0.15, 0.15, 0.15)]) + sns.color_palette(
        'Spectral', n_colors=max(clusters), desat=0.9)
    row_colors = [palette[cid] for cid in clusters]
    grid = sns.clustermap(
        df,
        row_linkage=linkage,
        col_linkage=linkage,
        row_colors=row_colors,
        linewidths=None,
        linecolor='none',
        figsize=(210 / 25.4, 210 / 25.4),  # 210 mm square, in inches
        cmap='Blues',
        xticklabels=False,
        yticklabels=False,
    )
    if title is not None:
        grid.fig.suptitle(title)
    grid.savefig(plotpath, dpi=dpi)

    # free the memory used by the plot
    import matplotlib.pyplot as plt
    plt.close('all')

    return len(set(clusters))
项目:exatomic    作者:exa-analytics    | 项目源码 | 文件源码
def plot_energy(curv, color=None, title='', figsize=(21,5),
                nylabel=3, nxlabel=5, fontsize=24):
    """
    Accepts the output of compute_curvature or combine_curvature and
    returns a figure with appropriate styling.

    The first axes shows every column of ``curv`` against column 'n'; the
    third shows each column minus the two linear ramps built by
    ``_deltaE``; the second axes is hidden. ``nylabel`` is accepted but not
    used here.
    """
    def _deltaE(col):
        # Subtract two 51-point ramps, col[0] -> 0 then 0 -> col[-1]; this
        # presumes 102 rows — TODO confirm against compute_curvature.
        if col.name == 'n': return col
        cat = np.linspace(col.values[0], 0, 51)
        an = np.linspace(0, col.values[-1], 51)
        return col - np.hstack([cat, an])
    figargs = {'figsize': figsize}
    fig = _gen_figure(nxplot=1, nyplot=3, nxlabel=nxlabel,
                      figargs=figargs, fontsize=fontsize)
    ax, axnone, ax1 = fig.get_axes()
    axnone.set_visible(False)
    if color is None:
        color = sns.color_palette('cubehelix', curv.shape[1] - 1)
    plargs = {'x': 'n', 'color': color, 'title': title, 'legend': False}
    curvy = curv.apply(_deltaE)
    curv.plot(ax=ax, **plargs)
    ax.set_ylim([curv.min().min(), curv.max().max()])
    # Raw strings: '\D' is an invalid escape sequence in a plain literal.
    ax.set_ylabel(r'$\Delta$E (eV)', fontsize=fontsize)
    ax.set_xlabel(r'$\Delta$N', fontsize=fontsize)
    curvy.plot(ax=ax1, **plargs)
    del curvy['n']
    ax1.set_ylim([curvy.min().min(), curvy.max().max()])
    ax1.set_ylabel(r'$\Delta \Delta$E (eV)', fontsize=fontsize)
    # This label belongs to ax1; it used to be set on ``ax`` a second time.
    ax1.set_xlabel(r'$\Delta$N', fontsize=fontsize)
    # 25. keeps the legend offset fractional on Python 2 as well.
    loc = [1.2, (9 - curv.shape[1]) / 25.]
    ax.legend(*ax.get_legend_handles_labels(), loc=loc)
    return fig
项目:QDREN    作者:andreamad8    | 项目源码 | 文件源码
def plot_dist(train_y,dev_y,test_y):
    """Plot normalised answer histograms for the train/validation/test splits.

    Draws three stacked panels with x limited to 0-500, saves them to
    checkpoints/label_dist.pdf and then shows the figure.
    """
    import seaborn as sns
    import matplotlib.pyplot as plt
    plt.rc('text', usetex=True)
    plt.rc('font', family='Times-Roman')
    sns.set_style(style='white')
    fig = plt.figure(figsize=(8, 12))

    # One panel per split; the former unused "Set2" palette local is gone.
    splits = [(train_y, 'Training', "blue"),
              (dev_y, 'Validation', "green"),
              (test_y, 'Test', "red")]
    for row, (values, name, colour) in enumerate(splits, start=1):
        ax = fig.add_subplot(3, 1, row)
        sns.distplot(values, kde=False, label=name, hist=True,
                     norm_hist=True, color=colour, ax=ax)
        ax.set_xlabel("Answer")
        ax.set_ylabel("Frequency")
        ax.set_xlim([0, 500])
        ax.legend(loc='best')

    plt.savefig('checkpoints/label_dist.pdf', format='pdf', dpi=300)

    plt.show()
项目:PorousMediaLab    作者:biogeochemistry    | 项目源码 | 文件源码
def contour_plot(lab, element, labels=False, days=False, last_year=False):
    """Filled contour plot of a species concentration over time and depth.

    Args:
        lab: model object providing ``time``, ``x``, ``dt`` and ``species``.
        element: species name; 'Temperature' and 'pH' get their own title
            and colour-bar label.
        labels: if True, draw inline contour labels.
        days: if True, the time axis values are multiplied by 365.
        last_year: if True, the window starts ``int(1 / lab.dt)`` samples
            earlier than the subsampling step; otherwise it starts at 1.

    Returns:
        The matplotlib Axes holding the contour plot.
    """
    plt.figure()
    special = {
        'Temperature': ('Temperature contour plot', 'Temperature, C'),
        'pH': ('pH contour plot', 'pH'),
    }
    fallback = (element + ' concentration', '%s [M/V]' % element)
    title, bar_label = special.get(element, fallback)
    plt.title(title)
    step = math.ceil(lab.time.size / 100)  # roughly 100 time columns
    start = step - int(1 / lab.dt) if last_year else 1
    times = lab.time[start::step]
    if days:
        times = times * 365
    time_grid, depth_grid = np.meshgrid(times, -lab.x)
    plt.xlabel('Time')
    conc = lab.species[element]['concentration'][:, start - 1:-1:step]
    blues = ListedColormap(sns.color_palette("Blues", 51))
    contours = plt.contourf(time_grid, depth_grid, conc, 51, cmap=blues,
                            origin='lower')
    if labels:
        plt.clabel(contours, inline=1, fontsize=10, colors='w')
    colorbar = plt.colorbar(contours)
    plt.ylabel('Depth')
    axes = plt.gca()
    axes.ticklabel_format(useOffset=False)
    colorbar.ax.set_ylabel(bar_label)
    return axes
项目:coquery    作者:gkunter    | 项目源码 | 文件源码
def set_defaults(self):
        """
        Set the plot defaults: palette, axis labels and legend title.
        """
        # "Paired" when the relevant grouping factor has 2, 4, 6, 8 or 12
        # levels; otherwise the qualitative "Set3" with two grouping
        # factors, or the sequential red-purple "RdPu" with one.
        n_levels = len(self._levels[1 if len(self._groupby) == 2 else 0])
        if n_levels in (2, 4, 6, 8, 12):
            palette = "Paired"
        elif len(self._groupby) == 2:
            palette = "Set3"
        else:
            palette = "RdPu"
        self.options["color_palette"] = palette
        super(Visualizer, self).set_defaults()

        self.options["label_x_axis"] = "Percentage" if self.percentage else "Frequency"

        session = options.cfg.main_window.Session

        if len(self._groupby) == 2:
            self.options["label_y_axis"] = session.translate_header(self._groupby[0])
            self.options["label_legend"] = session.translate_header(self._groupby[1])
            return

        self.options["label_legend"] = session.translate_header(self._groupby[0])
        if self.percentage:
            self.options["label_y_axis"] = ""
        else:
            self.options["label_y_axis"] = session.translate_header(self._groupby[0])
项目:coquery    作者:gkunter    | 项目源码 | 文件源码
def set_defaults(self):
        """Set palette defaults and axis labels for the corpus-position plot.

        ``color_number`` is the level count of the first grouping factor,
        or 1 when there are no levels.
        """
        self.options["color_palette"] = "Paired"
        # Guarded like the label check below: len(self._levels[0]) raised
        # IndexError when there were no levels, so that check never helped.
        self.options["color_number"] = len(self._levels[0]) if self._levels else 1
        super(Visualizer, self).set_defaults()
        self.options["label_x_axis"] = "Corpus position"
        if not self._levels or len(self._levels[0]) < 2:
            self.options["label_y_axis"] = ""
        else:
            self.options["label_y_axis"] = self._groupby[0]
项目:coquery    作者:gkunter    | 项目源码 | 文件源码
def set_defaults(self):
        """Fill in colour, legend and font defaults.

        ``color_number``, ``label_legend_columns``, ``color_palette`` and
        ``color_palette_values`` are only set when not already truthy, except
        that ``color_number`` is reset to 1 when a palette is chosen here and
        there are no levels. ``figure_font`` is always refreshed from the
        application font.
        """
        # The former ``if self.numerical_axes and False:`` branch could never
        # run and has been removed.
        if not self.options.get("color_number"):
            # Guarded: self._levels[-1] raised IndexError with no levels,
            # before the no-levels case below could ever be reached.
            self.options["color_number"] = (
                len(self._levels[-1]) if self._levels else 1)
        if not self.options.get("label_legend_columns"):
            self.options["label_legend_columns"] = 1
        if not self.options.get("color_palette"):
            if len(self._levels) == 0:
                self.options["color_palette"] = "Paired"
                self.options["color_number"] = 1
            elif len(self._levels[-1]) in (2, 4, 6) or len(self._groupby) == 2:
                self.options["color_palette"] = "Paired"
            else:
                self.options["color_palette"] = "RdPu"

        self.options["figure_font"] = (
            QtWidgets.QApplication.instance().font())

        if not self.options.get("color_palette_values"):
            self.set_palette_values(self.options["color_number"])
项目:coquery    作者:gkunter    | 项目源码 | 文件源码
def set_palette_values(self, n=None):
        """
        Set the number of palette colours and, unless the palette is
        "custom", regenerate the palette values with that many colours.

        A falsy ``n`` keeps the current ``color_number`` option.
        """
        if n:
            self.options["color_number"] = n
        else:
            n = self.options["color_number"]

        if self.options["color_palette"] == "custom":
            return
        self.options["color_palette_values"] = sns.color_palette(
            self.options["color_palette"], n)
项目:coquery    作者:gkunter    | 项目源码 | 文件源码
def show_palette(self):
        """Fill the colour test area with 12 swatches of the named palette."""
        # NOTE(review): other methods read ``self.palette_name``; confirm that
        # ``_palette_name`` is the intended attribute here.
        area = self.ui.color_test_area
        area.clear()
        for red, green, blue in sns.color_palette(self._palette_name, 12):
            swatch = QtWidgets.QListWidgetItem()
            area.addItem(swatch)
            swatch.setBackground(QtGui.QBrush(QtGui.QColor(
                int(red * 255), int(green * 255), int(blue * 255))))
项目:coquery    作者:gkunter    | 项目源码 | 文件源码
def test_palette(self):
        """Show the selected palette's colours in the colour test area.

        The custom palette is shown as-is; a named palette is sampled with
        as many colours as the spin box requests.
        """
        if self.palette_name != "custom":
            colours = sns.color_palette(self.palette_name, int(self.ui.spin_number.value()))
        else:
            colours = self.custom_palette
        area = self.ui.color_test_area
        area.clear()
        for colour in colours:
            area.addItem(CoqColorItem(colour))
项目:physt    作者:janpipek    | 项目源码 | 文件源码
def _get_cmap(kwargs):
    """Get the colour map for plots that support it.

    Parameters
    ----------
    kwargs : dict
        Plot keyword arguments. The "cmap" entry is popped (removed) and may
        be a str, a colors.Colormap or a list of colors; a str may also name
        a seaborn palette (if seaborn is installed). Defaults to
        ``default_cmap`` when absent.

    Returns
    -------
    colors.Colormap
    """
    from matplotlib.colors import ListedColormap

    cmap = kwargs.pop("cmap", default_cmap)
    if isinstance(cmap, list):
        return ListedColormap(cmap)
    if isinstance(cmap, str):
        try:
            cmap = plt.get_cmap(cmap)
        except (ValueError, KeyError) as exc:
            # Not a matplotlib colormap name: try a seaborn palette instead.
            # (``except BaseException`` used to swallow KeyboardInterrupt.)
            try:
                import seaborn as sns
            except ImportError:
                raise exc
            sns_palette = sns.color_palette(cmap, n_colors=256)
            cmap = ListedColormap(sns_palette, name=cmap)
    return cmap
项目:tensorforce-benchmark    作者:reinforceio    | 项目源码 | 文件源码
def make_palette(self):
        """Create a "husl" palette with one colour per benchmark, if unset."""
        if self.palette:
            return
        self.palette = sns.color_palette("husl", len(self.benchmarks))
项目:score_card_base_python    作者:zzstrwolf    | 项目源码 | 文件源码
def plot_br_chart(self,column):
        """Show a bar chart of bin counts with the bad rate on a twin axis.

        ``self.woe_dicts[column]`` maps each bin to a sequence whose item
        ``[1]`` is the bar height ('count') and item ``[2]`` the plotted
        'bad rate'. Bins are ordered with ``self.sort_dict`` when the keys
        are exactly ``str``, otherwise by key.
        """
        # list() so the first item can be indexed on Python 3, where
        # dict.items() returns a non-indexable view.
        items = list(self.woe_dicts[column].items())
        if type(items[0][0]) == str:
            woe_lists = sorted(items, key=self.sort_dict)
        else:
            woe_lists = sorted(items, key=lambda item: item[0])
        sns.set_style(rc={"axes.facecolor": "#EAEAF2",
                          "axes.edgecolor": "#EAEAF2",
                          "axes.linewidth": 1,
                          "grid.color": "white"})
        tick_label = [i[0] for i in woe_lists]
        counts = [i[1][1] for i in woe_lists]
        br_data = [i[1][2] for i in woe_lists]
        x = list(range(len(counts)))
        _, ax1 = plt.subplots(figsize=(12, 8))
        # Keyword x/y: seaborn >= 0.12 rejects data passed positionally.
        sns.barplot(x=x, y=counts, ax=ax1,
                    palette=sns.husl_palette(n_colors=20, l=.7))
        plt.xticks(x, tick_label, rotation=30, fontsize=12)
        plt.title(column, fontsize=18)
        ax1.set_ylabel('count', fontsize=15)
        ax1.tick_params('y', direction='in', length=6, width=0.5, labelsize=12)

        # Bad rate on a secondary y axis over the same bar positions.
        ax2 = ax1.twinx()
        ax2.plot(x, br_data, color='black')
        ax2.set_ylabel('bad rate', fontsize=15)
        ax2.tick_params('y', direction='in', length=6, width=0.5, labelsize=12)
        # Pad the x range and scale the top y limit by 1.1.
        plot_margin = 0.25
        x0, x1, y0, y1 = ax1.axis()
        ax1.axis((x0 - plot_margin, x1 + plot_margin, y0, y1 * 1.1))
        plt.show()
项目:score_card_base_python    作者:zzstrwolf    | 项目源码 | 文件源码
def save_br_chart(self, column, path):
        """Save a bar chart of per-bin counts with the bad rate overlaid.

        ``self.woe_dicts[column]`` maps each bin label to a sequence whose
        element 1 is drawn as the bar height ('count') and element 2 as a
        black line on a secondary axis ('bad rate'). Bins with string labels
        are ordered with ``self.sort_dict``; other labels by natural order.
        The figure is closed after saving so repeated calls do not keep
        accumulating open figures.

        Args:
            column: key into ``self.woe_dicts`` selecting the feature to plot.
            path: destination passed to ``plt.savefig``.

        Raises:
            ValueError: if the WOE dictionary for *column* is empty.
        """
        woe_dict = self.woe_dicts[column]
        if not woe_dict:
            raise ValueError('no WOE bins for column %r' % (column,))
        # dict.items() is not indexable on Python 3; peek at the first key instead.
        first_key = next(iter(woe_dict))
        if isinstance(first_key, str):
            woe_lists = sorted(woe_dict.items(), key=self.sort_dict)
        else:
            woe_lists = sorted(woe_dict.items(), key=lambda item: item[0])

        tick_label = [label for label, _ in woe_lists]
        counts = [stats[1] for _, stats in woe_lists]
        br_data = [stats[2] for _, stats in woe_lists]
        x = list(range(len(counts)))

        fig, ax1 = plt.subplots(figsize=(12, 8))
        # x/y must be keywords: seaborn >= 0.12 only accepts `data` positionally.
        sns.barplot(x=x, y=counts, ax=ax1,
                    palette=sns.husl_palette(n_colors=20, l=.7))
        # xticks/title act on the current axes, which is still ax1 here.
        plt.xticks(x, tick_label, rotation=30, fontsize=12)
        plt.title(column, fontsize=18)
        ax1.set_ylabel('count', fontsize=15)
        ax1.tick_params('y', labelsize=12)

        # Bad rate on a secondary y-axis sharing ax1's x positions.
        ax2 = ax1.twinx()
        ax2.plot(x, br_data, color='black')
        ax2.set_ylabel('bad rate', fontsize=15)
        ax2.tick_params('y', labelsize=12)

        # Pad the x-range and leave 10% headroom above the tallest bar.
        plot_margin = 0.25
        x0, x1, y0, y1 = ax1.axis()
        ax1.axis((x0 - plot_margin, x1 + plot_margin, y0, y1 * 1.1))
        plt.savefig(path)
        # Release the figure; otherwise every saved chart stays open in pyplot.
        plt.close(fig)
项目:mriqc    作者:poldracklab    | 项目源码 | 文件源码
def plot(self):
        """Lay out and draw the full figure on a single-column GridSpec.

        Rows, top to bottom:
        - one per entry in ``self.spikes``;
        - one per entry in ``self.confounds``, coloured from a "husl" palette;
        - a final row, 3.5x taller than the others, for the carpet plot.

        The GridSpec is stored on ``self.grid``.
        """
        n_conf = len(self.confounds)
        n_spk = len(self.spikes)
        total_rows = n_spk + n_conf + 1

        ratios = [1] * (total_rows - 1)
        ratios.append(3.5)
        grid = mgs.GridSpec(total_rows, 1, wspace=0.0, hspace=0.2,
                            height_ratios=ratios)

        row = 0
        for spike_ts, spike_title, is_zscored in self.spikes:
            spikesplot(spike_ts, title=spike_title, outer_gs=grid[row],
                       tr=self.tr, zscored=is_zscored)
            row += 1

        if self.confounds:
            colors = color_palette("husl", n_conf)
            for idx, (series, plot_kwargs) in enumerate(self.confounds):
                confoundplot(series, grid[row], tr=self.tr,
                             color=colors[idx], **plot_kwargs)
                row += 1

        # The carpet plot always takes the last (tallest) row.
        fmricarpetplot(self.func_data, self.seg_data, grid[-1], tr=self.tr)

        self.grid = grid
项目:Comparative-Annotation-Toolkit    作者:ComparativeGenomicsToolkit    | 项目源码 | 文件源码
def generic_unstacked_barplot(df, pdf, title_string, legend_labels, ylabel, names, box_label,
                              bbox_to_anchor=(1.12, 0.7)):
    """Draw side-by-side (unstacked) bars: one group per column, one bar per row.

    The module-level ``bar_width`` is divided evenly among the rows of *df*,
    so each group of bars spans ``bar_width``. Legend, labels and output are
    delegated to ``_generic_histogram``.
    """
    fig, ax = plt.subplots()
    positions = np.arange(len(df.columns))
    shorter_bar_width = bar_width / len(df)
    # Ask for exactly one colour per row. seaborn cycles the default palette
    # when more colours are requested than it holds, so more rows than palette
    # colours no longer raises IndexError. The first colours are unchanged.
    colors = sns.color_palette(n_colors=len(df))
    bars = [ax.bar(positions + shorter_bar_width * i, d, shorter_bar_width,
                   color=colors[i], linewidth=0.0)
            for i, (_, d) in enumerate(df.iterrows())]
    _generic_histogram(bars, legend_labels, title_string, pdf, ax, fig, ylabel, names, box_label, bbox_to_anchor)
项目:Comparative-Annotation-Toolkit    作者:ComparativeGenomicsToolkit    | 项目源码 | 文件源码
def generic_stacked_barplot(df, pdf, title_string, legend_labels, ylabel, names, box_label, bbox_to_anchor=(1.12, 0.7)):
    """Draw stacked bars: for each column, row segments are piled on one another.

    Row ``i`` is coloured with ``choose_palette(legend_labels)[i]`` and drawn
    with the module-level ``bar_width``. Legend, labels and output are
    delegated to ``_generic_histogram``.
    """
    fig, ax = plt.subplots()
    positions = np.arange(len(df.columns))
    running_total = np.zeros(len(df.columns))
    colors = choose_palette(legend_labels)
    bars = []
    for row_idx, (_, row) in enumerate(df.iterrows()):
        segment = ax.bar(positions, row, bar_width, bottom=running_total,
                         color=colors[row_idx], linewidth=0.0)
        bars.append(segment)
        # The next row's segments start where this row's end.
        running_total += row
    _generic_histogram(bars, legend_labels, title_string, pdf, ax, fig, ylabel, names, box_label, bbox_to_anchor)


###
# Shared functions
###
项目:Comparative-Annotation-Toolkit    作者:ComparativeGenomicsToolkit    | 项目源码 | 文件源码
def choose_palette(ordered_genomes):
    """Choose a palette when each genome gets its own colour.

    Six or fewer genomes use seaborn's current default palette unchanged.
    More genomes get a "Set2" palette with exactly one colour per genome.
    """
    n_genomes = len(ordered_genomes)
    return sns.color_palette("Set2", n_genomes) if n_genomes > 6 else sns.color_palette()
项目:augur    作者:nextstrain    | 项目源码 | 文件源码
def plot_frequencies(flu, gene, mutation=None, plot_regions=None, all_muts=False, ax=None, **kwargs):
    """Plot per-region frequency traces (or region counts) onto *ax*.

    Args:
        flu: object providing ``pivots``, ``mutation_frequencies``,
            ``mutation_frequency_confidence`` and ``mutation_frequency_counts``.
        gene: gene name used in the ``(region, gene)`` lookup keys.
        mutation: an int selects every mutation ``x`` with ``x[0] == mutation``
            in the ('global', gene) table; any other non-None value is plotted
            as a single mutation. With None, the per-region counts
            ``flu.mutation_frequency_counts[region]`` are plotted instead.
        plot_regions: regions to draw; defaults to the module-level ``regions``.
        all_muts: when False, an int *mutation* keeps only mutations whose
            ``freq[0]`` in the global table is below 0.5.
        ax: axes to draw on; a new figure/subplot is created when None.
        **kwargs: forwarded to ``plot_trace``.
    """
    import seaborn as sns
    sns.set_style('whitegrid')
    cols = sns.color_palette()
    linestyles = ['-', '--', '-.', ':']
    if plot_regions is None:
        plot_regions = regions
    pivots = flu.pivots
    if ax is None:
        plt.figure()
        ax = plt.subplot(111)

    if isinstance(mutation, int):
        # .items() works on both Python 2 and 3 (.iteritems() is Py2-only).
        mutations = [x for x, freq in flu.mutation_frequencies[('global', gene)].items()
                     if x[0] == mutation and (freq[0] < 0.5 or all_muts)]
    elif mutation is not None:
        mutations = [mutation]
    else:
        mutations = None

    if mutations is None:
        for ri, region in enumerate(plot_regions):
            count = flu.mutation_frequency_counts[region]
            # Draw on the requested axes, not whatever pyplot considers current.
            ax.plot(pivots, count, c=cols[ri % len(cols)], label=region)
    else:
        print("plotting mutations", mutations)
        # Colour encodes the region, line style encodes the mutation.
        for ri, region in enumerate(plot_regions):
            for mi, mut in enumerate(mutations):
                if mut in flu.mutation_frequencies[(region, gene)]:
                    freq = flu.mutation_frequencies[(region, gene)][mut]
                    err = flu.mutation_frequency_confidence[(region, gene)][mut]
                    c = cols[ri % len(cols)]
                    # Label shows mut[0] + 1 (presumably 0-based -> 1-based position) and mut[1].
                    label_str = str(mut[0] + 1) + mut[1] + ', ' + region
                    plot_trace(ax, pivots, freq, err, c=c,
                               ls=linestyles[mi % len(linestyles)], label=label_str, **kwargs)
                else:
                    print(mut, 'not found in region', region)
    ax.ticklabel_format(useOffset=False)
    ax.legend(loc=2)
项目:augur    作者:nextstrain    | 项目源码 | 文件源码
def plot_sequence_count(flu, fname=None, fs=12):
    """Draw a stacked bar chart of sample counts per date bin, split by region.

    - A grey "Other" bar per bin shows the 'global' counts.
    - Each non-global region in ``region_groups`` is stacked bottom-up on top,
      coloured by its entry in the local region palette.
    - The first three pivots are left out.
    - The figure is saved to *fname* when one is given.
    """
    import seaborn as sns
    dates = pivots_to_dates(flu.pivots)
    sns.set_style('ticks')
    region_label = {'global': 'Global', 'NA': 'N America', 'AS': 'Asia', 'EU': 'Europe', 'OC': 'Oceania'}
    abbreviations = ['global', 'NA', 'AS', 'EU', 'OC']
    region_colors = dict(zip(abbreviations,
                             sns.color_palette(n_colors=len(abbreviations))))
    fig, ax = plt.subplots(figsize=(8, 3))
    counts = flu.mutation_frequency_counts
    skip = 3
    stacked = np.zeros(len(flu.pivots[skip:]))
    visible_dates = dates[skip:]

    plt.bar(visible_dates, counts['global'][skip:], width=18,
            linewidth=0, label="Other", color="#bbbbbb", clip_on=False)
    for region in (r for r in region_groups if r != 'global'):
        region_counts = counts[region][skip:]
        plt.bar(visible_dates, region_counts, bottom=stacked, width=18,
                linewidth=0, label=region_label[region],
                color=region_colors[region], clip_on=False)
        stacked += region_counts

    make_date_ticks(ax, fs=fs)
    ax.set_ylabel('Sample count')
    ax.legend(loc=3, ncol=1, bbox_to_anchor=(1.02, 0.53))
    plt.subplots_adjust(left=0.1, right=0.82, top=0.94, bottom=0.22)
    sns.despine()
    if fname is not None:
        plt.savefig(fname)
项目:augur    作者:nextstrain    | 项目源码 | 文件源码
def plot_prediction(self):
        '''
        Plot the global frequencies, the predicted frequencies, and the
        frequencies in the short interval used for learning.

        Left panel:
        - vertical markers at ``self.t_cut`` (dashed) and at the end of
          ``self.current_prediction_interval`` (solid);
        - for each node whose prediction exceeds 0.02 somewhere after the
          first training pivot, three lines in that clade's colour:
          prediction (dashed, future pivots only), global frequency (solid)
          and training frequency (dash-dot).

        Right panel: ``self.prediction_error()`` with all non-future pivots
        set to zero.
        '''
        from matplotlib import pyplot as plt
        import seaborn as sns
        _, axs = plt.subplots(1, 2, figsize=(12, 6))

        axs[0].plot(self.t_cut * np.ones(2), [0, 1], lw=3, alpha=0.3, c='k', ls='--')
        axs[0].plot(self.current_prediction_interval[1] * np.ones(2), [0, 1], lw=3, alpha=0.3, c='k')

        train_pivots = self.train_frequencies[self.current_prediction_interval][0]
        train_freqs = self.train_frequencies[self.current_prediction_interval][1]
        cols = sns.color_palette()
        future_pivots = self.global_pivots > train_pivots[-1]
        # Loop-invariant mask: pivots after the start of the training window.
        after_train_start = self.global_pivots > train_pivots[0]
        for node in self.predictions:
            if np.max(self.predictions[node][after_train_start]) > 0.02:
                # Index modulo the actual palette size (not a hard-coded 6) so all
                # palette colours are used and short palettes cannot IndexError.
                color = cols[node.clade % len(cols)]
                axs[0].plot(self.global_pivots[future_pivots],
                            self.predictions[node][future_pivots], ls='--', c=color)
                axs[0].plot(self.global_pivots, self.global_freqs[node.clade], ls='-', c=color)
                axs[0].plot(train_pivots, train_freqs[node.clade], ls='-.', c=color)

        axs[0].set_xlim(train_pivots[0] - 2, train_pivots[-1] + 2)
        # Zero the non-future pivots in a new array rather than writing into
        # whatever prediction_error() returned (it may be shared state).
        dev = np.where(future_pivots, self.prediction_error(), 0.0)
        axs[1].plot(self.global_pivots, dev)
        axs[1].set_xlim(train_pivots[0], train_pivots[-1] + 2)
        axs[1].set_ylim(0, 3)
项目:harpreif    作者:harpribot    | 项目源码 | 文件源码
def scatter(x, colors):
        """Scatter 2-D points coloured by integer label and annotate label medians.

        Args:
            x: array indexed as ``x[:, 0]`` / ``x[:, 1]`` (the 2-D coordinates).
            colors: integer-valued label array, one entry per row of *x*;
                values index into a 258-colour "hls" palette.

        Returns:
            Tuple ``(figure, axes, scatter artist, list of label Text objects)``.
            Only labels 0-9 get a text annotation.
        """
        # One colour per possible label value.
        palette = np.array(sea.color_palette("hls", 258))

        f = plt.figure(figsize=(8, 8))
        ax = plt.subplot(aspect='equal')
        # np.int was removed in NumPy 1.24; the builtin int is the same dtype.
        sc = ax.scatter(x[:, 0], x[:, 1], lw=0, s=40,
                        c=palette[colors.astype(int)])
        plt.xlim(-25, 25)
        plt.ylim(-25, 25)
        ax.axis('off')
        ax.axis('tight')

        # Put each digit's label at the median of its points, with a white outline.
        txts = []
        for i in range(10):
            xtext, ytext = np.median(x[colors == i, :], axis=0)
            txt = ax.text(xtext, ytext, str(i), fontsize=24)
            txt.set_path_effects([
                patheffects.Stroke(linewidth=5, foreground="w"),
                patheffects.Normal()])
            txts.append(txt)

        plt.show()
        return f, ax, sc, txts
项目:fake_news    作者:bmassman    | 项目源码 | 文件源码
def word_count_by_label(articles: pd.DataFrame):
    """Show KDE curves of article word counts, true vs fake news.

    Args:
        articles: frame with a 'labels' column (0 is plotted as 'True News',
            1 as 'Fake News') and a 'word_count' column.
    """
    # seaborn no longer exposes pyplot as ``sns.plt``; use matplotlib directly.
    import matplotlib.pyplot as plt

    palette = sns.color_palette(palette='hls', n_colors=2)
    true_news_wc = articles[articles['labels'] == 0]['word_count']
    fake_news_wc = articles[articles['labels'] == 1]['word_count']
    sns.kdeplot(true_news_wc, bw=3, color=palette[0], label='True News')
    sns.kdeplot(fake_news_wc, bw=3, color=palette[1], label='Fake News')
    plt.legend()
    plt.show()
项目:facenet_pytorch    作者:liorshk    | 项目源码 | 文件源码
def visual_feature_space(features, labels, num_classes, name_dict):
    """Scatter 2-D feature activations coloured by class and save to 'center_loss.png'.

    Args:
        features: array indexed as ``features[:, 0]`` / ``features[:, 1]``.
        labels: integer-valued class labels, one per row of *features*.
        num_classes: number of classes; sizes the "hls" palette and the labels.
        name_dict: maps class index ``i`` to the text drawn at that class's median.

    Returns:
        Tuple ``(figure, axes, scatter artist, list of label Text objects)``.
    """
    title_font = {'fontname': 'Arial', 'size': '20', 'color': 'black', 'weight': 'normal',
                  'verticalalignment': 'bottom'}  # Bottom vertical alignment for more space
    axis_font = {'fontname': 'Arial', 'size': '20'}

    palette = np.array(sns.color_palette("hls", num_classes))

    f = plt.figure(figsize=(8, 8))
    ax = plt.subplot(aspect='equal')
    # np.int was removed in NumPy 1.24; the builtin int is the same dtype.
    sc = ax.scatter(features[:, 0], features[:, 1], lw=0, s=40,
                    c=palette[labels.astype(int)])

    # Put each class name at the median of its points, with a white outline.
    txts = []
    for i in range(num_classes):
        xtext, ytext = np.median(features[labels == i, :], axis=0)
        txt = ax.text(xtext, ytext, name_dict[i])
        txt.set_path_effects([
            PathEffects.Stroke(linewidth=5, foreground="w"),
            PathEffects.Normal()])
        txts.append(txt)
    ax.set_xlabel('Activation of the 1st neuron', **axis_font)
    ax.set_ylabel('Activation of the 2nd neuron', **axis_font)
    ax.set_title('softmax_loss + center_loss', **title_font)
    # Axes.set_axis_bgcolor was removed from matplotlib; set_facecolor replaces it.
    ax.set_facecolor('grey')
    f.savefig('center_loss.png')
    plt.show()
    return f, ax, sc, txts