Python seaborn 模块：heatmap() 实例源码

我们从 Python 开源项目中提取了以下 50 个代码示例，用于说明如何使用 seaborn.heatmap()。

项目:DeepLearning_PlantDiseases    作者:MarkoArsenovic    | 项目源码 | 文件源码
def Occlusion_exp(image,occluding_size,occluding_stride,model,preprocess,classes,groundTruth):
    """Compute an occlusion-sensitivity map for a single image.

    A square patch of zeros, ``occluding_size`` pixels wide, is slid over the
    image with step ``occluding_stride``. Each occluded copy is run through
    ``model`` and the softmax probability of class ``groundTruth`` is recorded.

    Parameters
    ----------
    image : array-like
        Image indexed ``[h, w, c]`` (unpacked as ``height, width, _``).
    occluding_size : int
        Side length of the square occluder, in pixels.
    occluding_stride : int
        Step between consecutive occluder positions, in pixels.
    model : torch.nn.Module
        Classifier; switched to eval mode here.
    preprocess : callable
        Applied to each occluded ``PIL.Image`` to produce a tensor.
    classes : object
        Unused; kept for interface compatibility.
    groundTruth : int
        Index of the class whose probability is mapped.

    Returns
    -------
    numpy.ndarray
        Shape ``(output_height, output_width)``: probability of
        ``groundTruth`` for each occluder position.
    """
    img = np.copy(image)
    height, width, _ = img.shape
    # Number of occluder positions per axis; the last patch may be clipped
    # at the image border.
    output_height = int(math.ceil((height - occluding_size) / occluding_stride + 1))
    output_width = int(math.ceil((width - occluding_size) / occluding_stride + 1))

    occluded_images = []
    for h in range(output_height):
        for w in range(output_width):
            # Occluder region, clipped to the image bounds.
            h_start = h * occluding_stride
            w_start = w * occluding_stride
            h_end = min(height, h_start + occluding_size)
            w_end = min(width, w_start + occluding_size)

            input_image = copy.copy(img)
            input_image[h_start:h_end, w_start:w_end, :] = 0
            occluded_images.append(preprocess(Image.fromarray(input_image)))

    # Every occluded copy carries the same (ground-truth) label.
    label_values = np.full(output_height * output_width, groundTruth, dtype=float)
    label_tensor = torch.from_numpy(label_values)
    tensor_images = torch.stack(occluded_images)
    dataset = torch.utils.data.TensorDataset(tensor_images, label_tensor)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=5, shuffle=False, num_workers=8)

    heatmap = np.empty(0)
    model.eval()
    # Softmax over the class dimension (the implicit-dim form is deprecated).
    softmax = nn.Softmax(dim=1)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for images, labels in dataloader:
            if use_gpu:
                # ``async=`` is a SyntaxError since Python 3.7 (``async`` is a
                # keyword); PyTorch renamed the argument to ``non_blocking``.
                images, labels = images.cuda(), labels.cuda(non_blocking=True)
            outputs = softmax(model(Variable(images)))
            # Move to CPU unconditionally: ``outs`` used to be bound only on the
            # GPU path, so CPU runs crashed with NameError.
            outs = outputs.cpu()
            heatmap = np.concatenate((heatmap, outs[:, groundTruth].data.numpy()))

    return heatmap.reshape((output_height, output_width))
项目:activity-browser    作者:LCA-ActivityBrowser    | 项目源码 | 文件源码
def __init__(self, parent, mlca, width=6, height=6, dpi=100):
        """Build a Matplotlib canvas showing ``mlca.results`` as an annotated heatmap.

        Rows are labelled with the activity of each functional unit and columns
        with the LCIA method tuples, each joined over several lines.
        """
        fig = Figure(figsize=(width, height), dpi=dpi, tight_layout=True)
        ax = fig.add_subplot(111)

        super(LCAResultsPlot, self).__init__(fig)
        self.setParent(parent)

        row_labels = [
            format_activity_label(next(iter(fu.keys())))
            for fu in mlca.func_units
        ]
        col_labels = ["\n".join(method) for method in mlca.methods]
        # Palette recipe from the seaborn colour-palette tutorial.
        palette = sns.cubehelix_palette(8, start=.5, rot=-.75, as_cmap=True)
        heat_ax = sns.heatmap(
            mlca.results,  # absolute values; no per-method normalization
            annot=True,
            linewidths=.05,
            cmap=palette,
            xticklabels=col_labels,
            yticklabels=row_labels,
            ax=ax,
            square=False,
        )
        heat_ax.tick_params(labelsize=8)

        self.setMinimumSize(self.size())
        # sns.set_context("notebook")
项目:kmeans-service    作者:MAYHEM-Lab    | 项目源码 | 文件源码
def plot_correlation_fig(data):
    """
    Build a heat map of the pairwise correlations between all columns.

    Parameters
    ----------
    data: Pandas DataFrame
        User data; ``data.corr()`` provides the matrix that is drawn.

    Returns
    -------
    Matplotlib Figure object holding the heat map, whose colour scale is
    fixed to the full correlation range [-1, 1].
    """
    sns.set(context='talk', style='white')
    figure = plt.figure()
    corr_matrix = data.corr()
    sns.heatmap(corr_matrix, vmin=-1, vmax=1)
    plt.tight_layout()
    return figure
项目:LinearCorex    作者:gregversteeg    | 项目源码 | 文件源码
def plot_heatmaps(data, mis, column_label, cont, topk=30, prefix=''):
    """Save one standardized heatmap per latent factor.

    For factor ``j`` the ``topk`` variables with the largest ``mis[j]`` are
    selected, samples are ordered by ``cont[:, j]``, and each selected
    variable is z-scored (NaN-aware) before plotting. NaN cells are masked.

    Parameters
    ----------
    data : numpy.ndarray
        2-D array indexed ``[sample, variable]`` (per ``data[:, inds]``).
    mis : numpy.ndarray
        2-D array ``(m, nv)``: one score per (factor, variable).
    column_label : sequence of str
        Variable names, indexed like the columns of ``data``.
    cont : numpy.ndarray
        2-D array whose column ``j`` orders the samples for factor ``j``.
    topk : int
        Maximum number of variables per heatmap.
    prefix : str
        Output directory; images go to ``<prefix>/heatmaps/group_num=<j>.png``.
        With the default ``''`` this is ``heatmaps/`` relative to the CWD
        (previously the path became ``/heatmaps/...`` at the filesystem root).
    """
    cmap = sns.cubehelix_palette(as_cmap=True, light=.9)
    m, nv = mis.shape
    out_dir = os.path.join(prefix, 'heatmaps')
    for j in range(m):
        inds = np.argsort(- mis[j, :])[:topk]
        if len(inds) < 2:
            continue
        plt.clf()
        order = np.argsort(cont[:, j])
        # Cast to float so the in-place standardization also works for
        # integer input (in-place float ops on an int array raise).
        subdata = data[:, inds][order].T.astype(float)
        subdata -= np.nanmean(subdata, axis=1, keepdims=True)
        subdata /= np.nanstd(subdata, axis=1, keepdims=True)
        columns = [column_label[i] for i in inds]
        sns.heatmap(subdata, vmin=-3, vmax=3, cmap=cmap, yticklabels=columns,
                    xticklabels=False, mask=np.isnan(subdata))
        filename = os.path.join(out_dir, 'group_num={}.png'.format(j))
        # exist_ok avoids a race between an existence check and creation.
        os.makedirs(out_dir, exist_ok=True)
        plt.title("Latent factor {}".format(j))
        plt.yticks(rotation=0)
        plt.savefig(filename, bbox_inches='tight')
        plt.close('all')
项目:Optimus    作者:ironmussa    | 项目源码 | 文件源码
def correlation(self, vec_col, method="pearson"):
        """
        Compute the correlation matrix of a column of Vectors and plot it.

        The matrix is computed with ``pyspark.ml.stat.Correlation.corr``.

        :param vec_col: Name of the column of vectors for which the correlation
            coefficients are computed. It must be a column of the dataset and
            contain Vector objects.
        :param method: Correlation method: ``'pearson'`` (default) or
            ``'spearman'``.
        :return: Matplotlib Axes holding the seaborn heatmap of the matrix.
        """
        assert isinstance(method, str), "Error, method argument provided must be a string."

        assert method in ('pearson', 'spearman'), \
            "Error, method only can be 'pearson' or 'spearman'."

        cor = Correlation.corr(self._df, vec_col, method).head()[0].toArray()
        # ``np.bool`` was removed in NumPy 1.24; the builtin ``bool`` is the
        # documented replacement. The all-False mask hides no cells.
        mask = np.zeros_like(cor, dtype=bool)
        return sns.heatmap(cor, mask=mask,
                           cmap=sns.diverging_palette(220, 10, as_cmap=True))
项目:Quantrade    作者:quant-trade    | 项目源码 | 文件源码
async def qindex_heatmap(broker):
    """Refresh the AI50 index heatmap image and yearly returns for ``broker``.

    Reads the ``<broker>_qndx`` portfolio file and converts the last 108
    months of ``LONG_PL`` to percentage returns. It regenerates the heatmap
    image when it is missing or older than 30 days, then updates the yearly
    returns. Any exception is printed rather than raised.

    Declared ``async`` because the body uses ``await``, which is a syntax
    error in a plain ``def``. Callers pass the result to ``gather``.
    """
    try:
        info = {"broker": broker, "symbol": "AI50", "period": "1440",
                "system": "AI50", "direction": "longs"}
        filename = join(settings.DATA_PATH, 'portfolios', '{}_qndx'.format(broker))
        image_filename = filename_constructor(info=info, folder="heatmap")
        data = await df_multi_reader(filename=filename)

        info["direction"] = 1
        returns = await convert_to_perc(data=data.last('108M').LONG_PL, info=info)

        if returns is not None:
            returns.columns = ['LONG_PL']
            stale_before = datetime.now() - timedelta(days=30)
            # Short-circuit ``or``: the bitwise ``|`` evaluated getmtime() even
            # for a missing file, which raised and skipped regeneration.
            if (not isfile(image_filename)
                    or datetime.fromtimestamp(getmtime(image_filename)) < stale_before):
                await save_qindex_heatmap(data=returns, image_filename=image_filename)
        await make_yearly_returns(returns=returns, info=info)
    except Exception as err:
        print(colored.red("At qindex_heatmap {}".format(err)))
项目:ModelFlow    作者:yuezPrincetechs    | 项目源码 | 文件源码
def cor_df(data, cols=None, xticklabels=False, yticklabels=False, close=True):
    '''
    Compute a correlation matrix and draw it as a heatmap.

    Parameters:
    data: DataFrame holding the variables.
    cols: list of column names to include; defaults to all columns of data.
    xticklabels, yticklabels: forwarded to seaborn.heatmap.
    close: if True, call plt.close('all') after drawing.

    Returns:
    corrmat: the correlation matrix as a DataFrame.
    fig: the Matplotlib figure containing the heatmap.
    '''
    selected = list(data.columns) if cols is None else cols
    corrmat = data[selected].corr()

    fig = plt.figure()
    ax = fig.add_subplot(111)
    sns.set(context='paper', font='monospace')
    sns.heatmap(corrmat, vmax=0.8, square=True, ax=ax,
                xticklabels=xticklabels, yticklabels=yticklabels)
    ax.set_title('Heatmap of Correlation Matrix')

    if close:
        plt.close('all')
    return corrmat, fig


#Distribution
项目:ModelFlow    作者:yuezPrincetechs    | 项目源码 | 文件源码
def heatmap(data,ax,xlabel=None,ylabel=None,xticklabels=None,yticklabels=None,title=None,fontsize=12):
    '''
    Draw ``data`` as a heatmap on ``ax`` with ``Axes.pcolor`` (Blues colormap).

    Ticks sit at the centre of each cell. Axis labels, tick labels and the
    title are applied only when given.

    Returns (pc, ax): ``pc`` is the mappable returned by ``pcolor`` and can be
    passed to ``matplotlib.pyplot.colorbar``.
    '''
    mesh = ax.pcolor(data, cmap=plt.cm.Blues)

    for set_label, text in ((ax.set_xlabel, xlabel), (ax.set_ylabel, ylabel)):
        if text is not None:
            set_label(text, fontsize=fontsize)

    # Columns map to the x-axis, rows to the y-axis; +0.5 centres the ticks.
    for dim, set_ticks, set_ticklabels, names in (
        (1, ax.set_xticks, ax.set_xticklabels, xticklabels),
        (0, ax.set_yticks, ax.set_yticklabels, yticklabels),
    ):
        set_ticks(np.arange(data.shape[dim]) + 0.5, minor=False)
        if names is not None:
            set_ticklabels(names, minor=False, fontsize=fontsize)

    if title is not None:
        ax.set_title(title, fontsize=fontsize)
    return mesh, ax


#????X?Y????
项目:deep-learning-for-genomics    作者:chgroenbech    | 项目源码 | 文件源码
def plotKLdivergenceHeatmap(KL_all, name = None):
    """Plot the KL divergence of each latent unit across epochs as a heatmap.

    ``log(KL_all)`` is drawn transposed, one column per epoch per the x-axis
    label. The figure is saved via ``data.saveFigure`` as
    ``KL_divergence_heatmap``, prefixed by ``name + "/"`` when ``name`` is
    truthy.
    """
    print("Plotting KL-divergence heatmap (activity of latent units).")

    base_name = "KL_divergence_heatmap"
    figure_name = name + "/" + base_name if name else base_name

    figure = pyplot.figure()
    axis = figure.add_subplot(1, 1, 1)

    KL_array = array(KL_all)
    print("Dimensions of KL-activations:")
    print(KL_array.shape)

    log_KL = log(KL_array.T)
    seaborn.heatmap(log_KL, xticklabels=True, yticklabels=False,
                    cbar=True, center=None, square=True, ax=axis)

    axis.set_xlabel("Epoch")
    axis.set_ylabel("$log KL(p_i||q_i)$")

    data.saveFigure(figure, figure_name, no_spine=False)
项目:cgpm    作者:probcomp    | 项目源码 | 文件源码
def plot_heatmap(
        D, xordering=None, yordering=None, xticklabels=None,
        yticklabels=None, vmin=None, vmax=None, ax=None):
    """Draw matrix ``D`` as a heatmap with optional row/column reordering.

    Columns of ``D`` map to the x-axis and rows to the y-axis, as in
    ``seaborn.heatmap``.

    Parameters
    ----------
    D : array-like, 2-D
        Matrix to draw; copied, never modified.
    xordering, yordering : sequence of int, optional
        Permutations applied to the columns / rows of ``D`` and to the
        matching tick labels.
    xticklabels, yticklabels : sequence, optional
        Labels for columns / rows; default to ``0..n-1``.
    vmin, vmax : float, optional
        Colour-scale limits forwarded to ``seaborn.heatmap``.
    ax : matplotlib Axes, optional
        Target axes; a new figure is created when omitted.

    Returns
    -------
    The Axes drawn on.
    """
    import seaborn as sns
    D = np.copy(D)

    if ax is None:
        _, ax = plt.subplots()
    # x-axis labels describe columns (shape[1]) and y-axis labels describe
    # rows (shape[0]); the old defaults were swapped and broke non-square D.
    if xticklabels is None:
        xticklabels = np.arange(D.shape[1])
    if yticklabels is None:
        yticklabels = np.arange(D.shape[0])
    if xordering is not None:
        # asarray so plain lists can be fancy-indexed by the ordering.
        xticklabels = np.asarray(xticklabels)[xordering]
        D = D[:, xordering]
    if yordering is not None:
        yticklabels = np.asarray(yticklabels)[yordering]
        D = D[yordering, :]

    sns.heatmap(
        D, yticklabels=yticklabels, xticklabels=xticklabels,
        linewidths=0.2, cmap='BuGn', ax=ax, vmin=vmin, vmax=vmax)
    ax.set_xticklabels(xticklabels, rotation=90)
    ax.set_yticklabels(yticklabels, rotation=0)
    return ax
项目:astetik    作者:mikkokotila    | 项目源码 | 文件源码
def correlation(data,title=''):
    """Plot the lower triangle of the Spearman correlation matrix of ``data``.

    Cells on and above the diagonal are masked. Global seaborn/Matplotlib
    settings (style, context, figure size, font family, dpi) are changed as a
    side effect. ``title`` is currently unused.
    """
    corr = data.corr(method='spearman')
    # Non-zero entries mark masked cells: the diagonal and upper triangle.
    hidden = np.zeros_like(corr)
    hidden[np.triu_indices_from(hidden)] = True

    sns.set(style="white")
    sns.set_context("notebook", font_scale=2, rc={"lines.linewidth": 0.3})

    rcParams['figure.figsize'] = 25, 12
    rcParams['font.family'] = 'Verdana'
    rcParams['figure.dpi'] = 300

    ax = sns.heatmap(corr, mask=hidden, linewidths=1, cmap="RdYlGn", annot=False)
    # Iterating a DataFrame yields its column names.
    ax.set_xticklabels(data, rotation=25, ha="right")
    plt.tick_params(axis='both', which='major', pad=15)
项目:python-machine-learning-book    作者:jeremyn    | 项目源码 | 文件源码
def visualize_housing_data(df):
    """Show a pair plot and an annotated correlation heatmap of housing data.

    Only the LSTAT, INDUS, NOX, RM and MEDV columns of ``df`` are used. Each
    plot is displayed with ``plt.show()``.
    """
    sns.set(style='whitegrid', context='notebook')
    features = ['LSTAT', 'INDUS', 'NOX', 'RM', 'MEDV']
    subset = df[features]

    sns.pairplot(subset, size=2.5)
    plt.show()

    # np.corrcoef treats rows as variables, hence the transpose.
    corr = np.corrcoef(subset.values.T)
    sns.set(font_scale=1.5)
    sns.heatmap(corr, cbar=True, annot=True, square=True, fmt='.2f',
                annot_kws={'size': 15},
                yticklabels=features, xticklabels=features)
    plt.show()
项目:LeaguePredictor    作者:dgarwin    | 项目源码 | 文件源码
def sns_triangle(matrix, plt_title, only_class=None):
    """Save a lower-triangle heatmap of ``matrix`` to ``images/triangle<only_class>.png``.

    Cells on and above the diagonal are masked. A diverging palette with
    ``vmax=.3`` is used, and every 5th row/column is labelled.

    Parameters
    ----------
    matrix : pandas.DataFrame or array-like
        Square matrix to plot.
    plt_title : str
        Figure title.
    only_class : str, optional
        Suffix appended to the output file name.
    """
    sns.set(style="white")
    # ``DataFrame.as_matrix`` was removed in pandas 1.0; ``np.asarray`` gives
    # the same ndarray for DataFrames and also accepts plain arrays.
    values = np.asarray(matrix)

    # Mask the diagonal and upper triangle. ``np.bool`` was removed in
    # NumPy 1.24; the builtin ``bool`` is the replacement.
    mask = np.zeros_like(values, dtype=bool)
    mask[np.triu_indices_from(mask)] = True

    _fig, ax = subplots(figsize=(11, 9))
    cmap = sns.diverging_palette(220, 10, as_cmap=True)

    sns.heatmap(values, mask=mask, cmap=cmap, vmax=.3,
                square=True, xticklabels=5, yticklabels=5,
                linewidths=.5, cbar_kws={"shrink": .5}, ax=ax)
    title(plt_title)
    xlabel('Preprocessed Features')
    ylabel('Preprocessed Features')
    if only_class is None:
        only_class = ''
    savefig('images/triangle' + only_class + '.png')
项目:fitbit-analyzer    作者:5agado    | 项目源码 | 文件源码
def plotSleepValueHeatmap(intradayStats, sleepValue=1):
    """Plot the ``sleepValue`` row of ``intradayStats`` as a one-row heatmap.

    Only every 60th column name is shown on the x-axis. The y-axis is labelled
    with ``sleepStats.SLEEP_VALUES[sleepValue]``.
    """
    sns.set_context("poster")
    sns.set_style("darkgrid")

    stepSize = 60
    column_names = list(intradayStats.columns.values)
    # Blank every label except each ``stepSize``-th one.
    xticks = ['' for _ in column_names]
    xticks[::stepSize] = column_names[::stepSize]

    plt.figure(figsize=(16, 4.2))
    # ``Series.reshape`` no longer exists in pandas; reshape the NumPy values.
    row = intradayStats.loc[sleepValue].values.reshape(1, -1)
    g = sns.heatmap(row)
    g.set_xticklabels(xticks, rotation=45)
    g.set_yticklabels([])
    g.set_ylabel(sleepStats.SLEEP_VALUES[sleepValue])
    plt.tight_layout()
    # ``sns.plt`` is no longer exposed by seaborn; call pyplot directly.
    plt.show()
项目:parisfellows_anonymize    作者:armgilles    | 项目源码 | 文件源码
def get_confusion(y_test, y_pred):
    """Plot a row-normalized confusion matrix as an annotated heatmap.

    Each row is divided by its total, so cell (i, j) is the fraction of
    samples with true label i that were predicted as j.

    Parameters
    ----------
    y_test : array-like
        True labels.
    y_pred : array-like
        Predicted labels.
    """
    # confusion_matrix orders rows/columns by the sorted union of labels.
    # Pass that order explicitly and hand the same labels to seaborn, so the
    # tick labels always line up with the matrix. The old code used
    # appearance order and a reversed y-axis, which mislabelled cells.
    labels = np.union1d(np.asarray(y_test), np.asarray(y_pred))
    cm = confusion_matrix(y_test, y_pred, labels=labels)
    cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

    sns.heatmap(cm_normalized, cmap='Greens', annot=True, linewidths=.5,
                xticklabels=labels, yticklabels=labels)
    plt.xlabel('Predicted label')
    plt.ylabel('True label')
    plt.yticks(rotation=0)
项目:meucci-python    作者:returnandrisk    | 项目源码 | 文件源码
def plot_corr_heatmap(corr, labels, heading):
    """Show an annotated lower-triangle heatmap of a correlation matrix.

    Cells on and above the diagonal are masked. ``labels`` name both axes and
    ``heading`` becomes the title.
    """
    sns.set(style="white")

    # Mask the diagonal and upper triangle. ``np.bool`` was removed in
    # NumPy 1.24; the builtin ``bool`` is the documented replacement.
    mask = np.zeros_like(corr, dtype=bool)
    mask[np.triu_indices_from(mask)] = True

    _fig, ax = plt.subplots(figsize=(8, 8))
    cmap = sns.diverging_palette(220, 10, as_cmap=True)

    sns.heatmap(corr, mask=mask, cmap=cmap, vmax=.3,
                square=True, xticklabels=labels, yticklabels=labels,
                linewidths=.5, ax=ax, cbar_kws={"shrink": .5}, annot=True)
    ax.set_title(heading)
    plt.show()
项目:kaggle-review    作者:daxiongshu    | 项目源码 | 文件源码
def corr_heatmap(df,cols=None,name=None):
    """Show, and optionally save, a heatmap of column correlations.

    Parameters
    ----------
    df : pandas.DataFrame
        Source data.
    cols : list of str, optional
        Columns to correlate. Defaults to every column whose dtype is not
        ``object``.
    name : str, optional
        If given, the figure is also saved to this path.
    """
    sns.set()
    if cols is None:
        cols = [c for c in df.columns.values if df[c].dtype != 'object']
    corr = df[cols].corr()
    print(corr.shape)
    ax = sns.heatmap(corr, annot=False)
    # Save before show(): after the window is closed the figure may already
    # be torn down on some backends, producing an empty image.
    if name is not None:
        ax.get_figure().savefig(name)
    plt.show()
项目:SNLI-Keras    作者:adamzjk    | 项目源码 | 文件源码
def plotHeatMap(df, psize=(8,8), filename='Heatmap'):
    """Save an annotated heatmap of ``df`` to ``filename``, sized ``psize`` inches.

    The colour scale is capped at 0.85 and no colour bar is drawn. The
    current figure is cleared afterwards.
    """
    ax = sns.heatmap(df, vmax=.85, square=True, cbar=False, annot=True)
    plt.xticks(rotation=40)
    plt.yticks(rotation=360)  # a full turn, i.e. horizontal labels
    fig = ax.get_figure()
    fig.set_size_inches(psize)
    fig.savefig(filename)
    plt.clf()
项目:AutoSleepScorerDev    作者:skjerns    | 项目源码 | 文件源码
def plot_confusion_matrix(fname, conf_mat, target_names, 
                          title='', cmap='Blues', perc=True,figsize=[6,5],cbar=True):
    """Plot a confusion matrix as an annotated heatmap and optionally save it.

    Colours encode the square root of the row-normalized percentages.
    Annotations show either those percentages (``perc=True``) or the raw
    counts. Column labels carry the predicted-class totals and row labels the
    true-class totals.

    Parameters
    ----------
    fname : str
        File name under ``plots/``. When ``''`` nothing is saved and no new
        figure is created (the current one is drawn into).
    conf_mat : numpy.ndarray
        Square count matrix; rows are true classes, columns predicted ones.
    target_names : list of str
        Class names; replaced by ``'0'..'n-1'`` if the length does not match.
    title : str
        Heatmap title.
    cmap : str
        Colormap name.
    perc : bool
        Annotate with percentages (True) or raw counts (False).
    figsize : list of two numbers
        Figure size; 0.6 is subtracted from the width when ``cbar`` is False.
    cbar : bool
        Whether to draw the colour bar.
    """
    figsize = deepcopy(figsize)  # never mutate the shared default list
    if cbar == False:
        figsize[0] = figsize[0] - 0.6

    if len(target_names) != len(conf_mat):
        target_names = [str(i) for i in np.arange(len(conf_mat))]
    c_names = []
    r_names = []
    for i, label in enumerate(target_names):
        col_total = str(int(np.sum(conf_mat[:, i])))
        row_total = str(int(np.sum(conf_mat[i, :])))
        c_names.append(label + '\n(' + col_total + ')')
        # ljust is a no-op for widths <= len(label). The previous
        # '{:{align}}' format raised ValueError for a zero or negative
        # width, which happens with long class names.
        width = len(row_total) + 3 - len(label)
        r_names.append(label.ljust(width) + '\n(' + row_total + ')')

    cm = 100 * conf_mat.astype('float') / conf_mat.sum(axis=1)[:, np.newaxis]

    df = pd.DataFrame(data=np.sqrt(cm), columns=c_names, index=r_names)
    if fname != '':
        plt.figure(figsize=figsize)
    g = sns.heatmap(df, annot=cm if perc else conf_mat, fmt=".1f" if perc else ".0f",
                    linewidths=.5, vmin=0, vmax=np.sqrt(100), cmap=cmap, cbar=cbar,
                    annot_kws={"size": 13})
    g.set_title(title)
    if cbar:
        # The colour data is sqrt(%); label the colour-bar ticks in plain %.
        colorbar = g.collections[0].colorbar
        colorbar.set_ticks(np.sqrt(np.arange(0, 100, 20)))
        colorbar.set_ticklabels(np.arange(0, 100, 20))
    g.set_ylabel('True sleep stage', fontdict={'fontsize': 12, 'fontweight': 'bold'})
    g.set_xlabel('Predicted sleep stage', fontdict={'fontsize': 12, 'fontweight': 'bold'})
    if fname != '':
        plt.tight_layout()
        g.figure.savefig(os.path.join('plots', fname))
项目:AutoSleepScorerDev    作者:skjerns    | 项目源码 | 文件源码
def plot_difference_matrix(fname, confmat1, confmat2, target_names, 
                          title='', cmap='Blues', perc=True,figsize=[5,4],cbar=True,
                          **kwargs):
    """Plot the cell-wise difference of two row-normalized confusion matrices.

    Both matrices are converted to row percentages, and ``confmat1`` is
    subtracted from ``confmat2``. Only the diagonal differences are coloured
    (off-diagonal colour values are 0), but every cell is annotated with its
    difference. The figure is saved to ``plots/<fname>``.

    ``cmap`` and ``perc`` are accepted but unused. Extra keyword arguments
    are forwarded to ``seaborn.heatmap``.
    """
    figsize = deepcopy(figsize)  # the default list must not be mutated
    if cbar == False:
        figsize[0] -= 0.6

    pct1 = 100 * confmat1.astype('float') / confmat1.sum(axis=1)[:, np.newaxis]
    pct2 = 100 * confmat2.astype('float') / confmat2.sum(axis=1)[:, np.newaxis]
    diff = pct2 - pct1

    on_diagonal = np.eye(len(diff), dtype=bool)
    diag_only = np.zeros_like(diff)
    diag_only[on_diagonal] = diff.diagonal()

    frame = pd.DataFrame(data=diag_only, columns=target_names, index=target_names)
    plt.figure(figsize=figsize)
    label_font = {'fontsize': 12, 'fontweight': 'bold'}
    g = sns.heatmap(frame, annot=diff, fmt=".1f",
                    linewidths=.5, vmin=-10, vmax=10,
                    cmap='coolwarm_r', annot_kws={"size": 13}, cbar=cbar, **kwargs)
    g.set_title(title)
    g.set_ylabel('True sleep stage', fontdict=label_font)
    g.set_xlabel('Predicted sleep stage', fontdict=label_font)
    plt.tight_layout()

    g.figure.savefig(os.path.join('plots', fname))
项目:activity-browser    作者:LCA-ActivityBrowser    | 项目源码 | 文件源码
def __init__(self, parent, data, labels, width=6, height=6, dpi=100):
        """Draw a correlation heatmap of ``data`` on a new Matplotlib canvas.

        Colours fill the strict lower triangle. Coefficients are written as
        text in the upper triangle and column labels on the diagonal. The
        colour scale is symmetric around 0, set by the largest absolute
        coefficient in the coloured cells.
        """
        figure = Figure(figsize=(width, height), dpi=dpi, tight_layout=True)
        axes = figure.add_subplot(111)

        super(CorrelationPlot, self).__init__(figure)
        self.setParent(parent)

        sns.set(style="darkgrid")

        df = pd.DataFrame(data=data, columns=labels)
        corr = df.corr()
        # Mask the diagonal and upper triangle. ``np.bool`` was removed in
        # NumPy 1.24; the builtin ``bool`` is the replacement.
        mask = np.zeros_like(corr, dtype=bool)
        mask[np.triu_indices_from(mask)] = True
        # Symmetric limits from the visible cells. A single column has no
        # visible cells, and max() of an empty array would raise.
        visible = np.abs(corr.values[~mask])
        vmax = visible.max() if visible.size else 1.0
        sns.heatmap(corr, mask=mask, cmap=plt.cm.PuOr, vmin=-vmax, vmax=vmax,
                    square=True, linecolor="lightgray", linewidths=1, ax=axes)
        n = len(corr)
        for i in range(n):
            axes.text(i + 0.5, i + 0.5, corr.columns[i],
                      ha="center", va="center", rotation=0)
            for j in range(i + 1, n):
                s = "{:.3f}".format(corr.values[i, j])
                axes.text(j + 0.5, i + 0.5, s,
                          ha="center", va="center")
        axes.axis("off")
        # Let the widget expand to fill the space it is given.
        self.setSizePolicy(QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Expanding)
        self.updateGeometry()
        self.setMinimumSize(self.size())
项目:q2-diversity    作者:qiime2    | 项目源码 | 文件源码
def beta_rarefaction(output_dir: str, table: biom.Table, metric: str,
                     sampling_depth: int, iterations: int=10,
                     phylogeny: skbio.TreeNode=None,
                     correlation_method: str='spearman',
                     color_scheme: str='BrBG') -> None:
    """Compare beta-diversity distance matrices across rarefaction iterations.

    ``iterations`` distance matrices come from ``_get_multiple_rarefaction``.
    Their pairwise Mantel correlations are written to ``output_dir`` in three
    forms:

    - a heatmap, ``heatmap.svg``;
    - a TSV, ``rarefaction-iteration-correlation.tsv``;
    - the rendered ``beta_rarefaction_assets`` index page.

    Raises
    ------
    ValueError
        If a phylogenetic metric is requested without ``phylogeny``.
    """
    if metric in phylogenetic_metrics():
        if phylogeny is None:
            raise ValueError("A phylogenetic metric (%s) was requested, "
                             "but a phylogenetic tree was not provided. "
                             "Phylogeny must be provided when using a "
                             "phylogenetic diversity metric." % metric)
        beta_func = functools.partial(beta_phylogenetic, phylogeny=phylogeny)
    else:
        beta_func = beta

    distance_matrices = _get_multiple_rarefaction(
        beta_func, metric, iterations, table, sampling_depth)

    sm_df = skbio.stats.distance.pwmantel(
        distance_matrices, method=correlation_method, permutations=0,
        strict=True)
    sm = sm_df[['statistic']]  # Drop all other DF columns
    sm = sm.unstack(level=0)  # Reshape for seaborn

    test_statistics = {'spearman': "Spearman's rho", 'pearson': "Pearson's r"}
    ax = sns.heatmap(
        sm, cmap=color_scheme, vmin=-1.0, vmax=1.0, center=0.0, annot=False,
        square=True, xticklabels=False, yticklabels=False,
        cbar_kws={'ticks': [1, 0.5, 0, -0.5, -1],
                  'label': test_statistics[correlation_method]})
    ax.set(xlabel='Iteration', ylabel='Iteration',
           title='Mantel correlation between iterations')
    fig = ax.get_figure()
    fig.savefig(os.path.join(output_dir, 'heatmap.svg'))
    # sns.heatmap drew on pyplot's current figure. Clear it so a later call
    # in the same process does not draw over this heatmap and stack another
    # colour bar on the same figure.
    fig.clf()

    similarity_mtx_fp = os.path.join(output_dir,
                                     'rarefaction-iteration-correlation.tsv')
    sm_df.to_csv(similarity_mtx_fp, sep='\t')

    index_fp = os.path.join(TEMPLATES, 'beta_rarefaction_assets', 'index.html')
    q2templates.render(index_fp, output_dir)
项目:sourcetracker2    作者:biota    | 项目源码 | 文件源码
def plot_heatmap(mpm, cm=plt.cm.viridis, xlabel='Sources', ylabel='Sinks',
                 title='Mixing Proportions (as Fraction)'):
    '''Make a basic heatmap of mixing proportions.

    Values are coloured on a fixed [0, 1] scale and annotated in every cell.
    Returns the new figure and its single axes.
    '''
    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)
    sns.heatmap(mpm, vmin=0, vmax=1.0, cmap=cm, annot=True, linewidths=.5,
                ax=ax)
    for setter, text in ((ax.set_xlabel, xlabel),
                         (ax.set_ylabel, ylabel),
                         (ax.set_title, title)):
        setter(text)
    return fig, ax
项目:palladio    作者:slipguru    | 项目源码 | 文件源码
def save_heatmap(df, title, tag):
    """Create and save a heatmap of ``df`` as ``<title><tag>.pdf``.

    Cells are annotated with two decimals. Rows are labelled as the number of
    samples n and columns as the number of dimensions d.
    """
    # ``sns.plt`` (an alias of pyplot) is no longer exposed by seaborn.
    import matplotlib.pyplot as plt
    fig = plt.figure()
    filename = title + tag + ".pdf"
    sns.heatmap(df, cmap="YlGn", annot=True, fmt='1.2f')
    plt.title(title)
    plt.ylabel(r'Number of samples $n$')
    plt.xlabel(r'Number of dimensions $d$')
    plt.savefig(filename)
    # Close the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
    print("\t{} saved".format(filename))
项目:palladio    作者:slipguru    | 项目源码 | 文件源码
def make_heatmaps(collection, tag, idx, cols):
    """Generate and save one heatmap per metric from a nested dictionary.

    ``collection[n][d][metric]`` holds a score for each ``n`` (row) and ``d``
    (column); both key levels are visited in sorted order. ``idx`` and
    ``cols`` label the rows and columns of every heatmap.
    """
    metrics = (
        ('acc', 'Accuracy'),
        ('bacc', 'Balanced Accuracy'),
        ('f1', 'F1'),
        ('prec', 'Precision'),
        ('rcll', 'Recall'),
    )
    row_keys = sorted(collection.keys())

    # Build every table first, then save, as before.
    tables = []
    for key, _ in metrics:
        grid = [[collection[n][d][key] for d in sorted(collection[n])]
                for n in row_keys]
        tables.append(pd.DataFrame(data=np.array(grid), index=idx, columns=cols))

    for table, (_, label) in zip(tables, metrics):
        save_heatmap(table, label, tag)
项目:guesswhat    作者:GuessWhatGame    | 项目源码 | 文件源码
def __init__(self, path, games, logger, suffix):
        """Plot the game success rate as a function of target-object position.

        Each game's bounding-box centre is normalized by the image size and
        binned into an 8x8 grid (indices 0..7). The heatmap shows successful
        games / all games per bin. Bins with no games are NaN and drawn blank.
        """
        super(SuccessPosition, self).__init__(path, self.__class__.__name__, suffix)
        x_bin = 7
        y_bin = 7

        success_sum = np.zeros((x_bin + 1, y_bin + 1))
        total_sum = np.zeros((x_bin + 1, y_bin + 1))

        for game in games:
            bbox = game.object.bbox
            picture = game.image

            x = int(bbox.x_center / picture.width * x_bin)
            y = int(bbox.y_center / picture.height * y_bin)

            total_sum[x][y] += 1.0
            if game.status == "success":
                success_sum[x][y] += 1.0

        # Empty bins give 0/0 = NaN by design; silence the runtime warnings.
        with np.errstate(divide='ignore', invalid='ignore'):
            ratio = success_sum / total_sum

        sns.set(style="whitegrid")

        # ratio is indexed [x][y], but seaborn puts rows on the y-axis.
        # Transpose so image width runs along x and image height along y,
        # matching the axis labels. No legend: nothing plotted has a label.
        f = sns.heatmap(ratio.T, robust=True, linewidths=.5, cbar_kws={"label" : "% Success"},
                        xticklabels=False, yticklabels=False)
        f.set_xlabel("normalized image width", {'size': '14'})
        f.set_ylabel("normalized image height", {'size': '14'})
项目:Odin    作者:JamesBrofos    | 项目源码 | 文件源码
def monthly_returns(self, fund, ax=None):
        """Draw a heatmap of the fund's month-over-month returns, in percent.

        ``fund.history`` is aggregated to monthly returns and unstacked so each
        column is a calendar month, renamed Jan..Dec. Values are rounded to
        three decimals, and missing months are drawn as 0.

        :param fund: object whose ``history`` is aggregated.
        :param ax: Matplotlib axes to draw on; defaults to ``plt.gca()``.
        :returns: the axes that were drawn on.
        """
        ax = plt.gca() if ax is None else ax

        month_abbr = {
            1: 'Jan', 2: 'Feb', 3: 'Mar', 4: 'Apr', 5: 'May', 6: 'Jun',
            7: 'Jul', 8: 'Aug', 9: 'Sep', 10: 'Oct', 11: 'Nov', 12: 'Dec',
        }
        history = fund.history
        by_month = self.__aggregate_returns(history, 'monthly').unstack()
        by_month = np.round(by_month, 3)
        by_month.rename(columns=month_abbr, inplace=True)

        # Percent values; months without data are shown as 0.
        pct = by_month.fillna(0) * 100.0
        sns.heatmap(
            pct, annot=True, fmt="0.1f",
            annot_kws={"size": 12}, alpha=1.0, center=0.0, cbar=False,
            cmap=cm.RdYlGn, ax=ax
        )
        ax.set_title('Monthly Returns (%)', fontweight='bold')
        ax.set_ylabel('')
        ax.set_yticklabels(ax.get_yticklabels(), rotation=0)
        ax.set_xlabel('')
        return ax
项目:NLP-JD    作者:ZexinYan    | 项目源码 | 文件源码
def show_heat_map(self):
            """Show a square heatmap of the pairwise correlations of ``self.data``.

            As a side effect, pandas' global display precision is set to 2.
            """
            # Use the full option name: bare 'precision' also matches the
            # ``styler.format.precision`` option in newer pandas, and the
            # ambiguous pattern raises OptionError.
            pd.set_option('display.precision', 2)
            plt.figure(figsize=(20, 6))
            sns.heatmap(self.data.corr(), square=True)
            plt.xticks(rotation=90)
            plt.yticks(rotation=360)
            plt.suptitle("Correlation Heatmap")
            plt.show()
项目:johnson-county-ddj-public    作者:dssg    | 项目源码 | 文件源码
def plot_deviations(self, feature_column):
        """Plot how observed feature distributions deviate from expected ones.

        For each predicted class, the relative deviation
        ``(observed - expected) / expected`` of the per-class proportions is
        drawn on a -1..1 diverging scale. Each cell is annotated with the
        observed (non-proportional) value.

        :param feature_column: name of the column whose distributions are compared
        :type feature_column: str
        :returns: heatmap of deviations
        :rtype: matplotlib Axes
        """
        expected = self.get_distribution_by_class(
            feature_column, self.model['labelling'][0], True)
        observed = self.get_distribution_by_class(
            feature_column, 'y_pred', True)
        observed_counts = self.get_distribution_by_class(
            feature_column, 'y_pred', False)

        relative_deviation = (observed - expected) / expected

        plot = sns.heatmap(relative_deviation, cmap='RdBu_r', vmin=-1, vmax=1,
                           annot=observed_counts, fmt='g')
        plot.set(xlabel=feature_column, ylabel='predicted class',
                 yticklabels=reversed(self.labels))
        return plot
项目:Quantrade    作者:quant-trade    | 项目源码 | 文件源码
async def save_qindex_heatmap(data, image_filename):
    """Write a month-by-year returns heatmap for the AI50 index to ``image_filename``.

    ``data`` is aggregated to monthly returns and unstacked so each column is
    a calendar month (renamed Jan..Dec). Values are rounded to three
    decimals, and missing months are drawn as 0. Errors are printed, not
    raised.

    Declared ``async`` because the body awaits ``aggregate_returns``, and
    ``await`` is a syntax error in a plain ``def``.
    """
    try:
        monthly_ret = await aggregate_returns(returns=data, convert_to='monthly')
        monthly_ret = round(monthly_ret.unstack(), 3)
        monthly_ret.rename(
            columns={1: 'Jan', 2: 'Feb', 3: 'Mar', 4: 'Apr',
                     5: 'May', 6: 'Jun', 7: 'Jul', 8: 'Aug',
                     9: 'Sep', 10: 'Oct', 11: 'Nov', 12: 'Dec'},
            inplace=True
        )
        # Use a dedicated figure rather than plt.gca(), which could reuse a
        # figure left open elsewhere. Always close it, even on error.
        fig, ax = plt.subplots()
        try:
            sns.heatmap(
                monthly_ret.fillna(0), # * 100.0,
                annot=True,
                fmt="0.1f",
                annot_kws={"size": 8},
                alpha=1.0,
                center=0.0,
                cbar=False,
                cmap=cm.RdYlGn,
                ax=ax)
            ax.set_title('A.I. Returns, %', fontweight='bold')
            fig.savefig(image_filename)
        finally:
            plt.close(fig)
        if settings.SHOW_DEBUG:
            print(colored.green("Wrote heatmap image for {}\n".format(image_filename)))
    except Exception as err:
        print(colored.red("At save_qindex_heatmap {}".format(err)))
项目:Quantrade    作者:quant-trade    | 项目源码 | 文件源码
async def write_h(image_filename, data):
    """Write a month-by-year returns heatmap for ``data`` to ``image_filename``.

    Returns are aggregated monthly and unstacked so each column is a calendar
    month (renamed Jan..Dec). Values are rounded to three decimals, and
    missing months are drawn as 0. Errors are printed, not raised.

    Declared ``async`` because the body awaits ``aggregate_returns``, and
    ``await`` is a syntax error in a plain ``def``.
    """
    try:
        monthly_ret = await aggregate_returns(returns=data, convert_to='monthly')
        monthly_ret = round(monthly_ret.unstack(), 3)
        monthly_ret.rename(
            columns={1: 'Jan', 2: 'Feb', 3: 'Mar', 4: 'Apr',
                     5: 'May', 6: 'Jun', 7: 'Jul', 8: 'Aug',
                     9: 'Sep', 10: 'Oct', 11: 'Nov', 12: 'Dec'},
            inplace=True
        )
        # Use a dedicated figure rather than plt.gca(), which could reuse a
        # figure left open elsewhere. Always close it, even on error.
        fig, ax = plt.subplots()
        try:
            sns.heatmap(
                monthly_ret.fillna(0), # * 100.0,
                annot=True,
                fmt="0.1f",
                annot_kws={"size": 8},
                alpha=1.0,
                center=0.0,
                cbar=False,
                cmap=cm.RdYlGn,
                ax=ax)
            ax.set_title('Returns heatmap, %', fontweight='bold')
            fig.savefig(image_filename)
        finally:
            plt.close(fig)
        if settings.SHOW_DEBUG:
            print(colored.green("Wrote heatmap image for {}\n".format(image_filename)))
    except Exception as err:
        print(colored.red("At write_heatmap {}".format(err)))
项目:Quantrade    作者:quant-trade    | 项目源码 | 文件源码
async def save_heatmap(data, info):
    """Regenerate the heatmap image for ``info`` if missing or older than 30 days.

    Errors are printed, not raised. Declared ``async`` because it awaits
    ``write_h``, and ``await`` is a syntax error in a plain ``def``.
    """
    try:
        image_filename = filename_constructor(info=info, folder="heatmap")
        stale_before = datetime.now() - timedelta(days=30)
        # Short-circuit ``or``: the bitwise ``|`` evaluated getmtime() even for
        # a missing file, which raised and prevented the image from ever
        # being created.
        if (not isfile(image_filename)
                or datetime.fromtimestamp(getmtime(image_filename)) < stale_before):
            await write_h(image_filename=image_filename, data=data)
    except Exception as err:
        print(colored.red("At save_heatmap {}".format(err)))
项目:Quantrade    作者:quant-trade    | 项目源码 | 文件源码
def generate_monthly_heatmaps(loop):
    """Run the heatmap jobs on ``loop``: performance files first, then each broker's AI50 index.

    Both batches are passed to ``gather`` with ``return_exceptions=True``.
    Assuming ``gather`` is ``asyncio.gather``, a failing job is returned as a
    result instead of aborting the batch.
    """
    brokers = Brokers.objects.all()
    path_to = join(settings.DATA_PATH, "performance")
    filenames = multi_filenames(path_to_history=path_to)

    performance_jobs = [
        make_heat_img(path_to=path_to, filename=filename)
        for filename in filenames
    ]
    loop.run_until_complete(gather(*performance_jobs, return_exceptions=True))

    # AI50 index heatmap, one job per broker.
    index_jobs = [qindex_heatmap(broker=broker.slug) for broker in brokers]
    loop.run_until_complete(gather(*index_jobs, return_exceptions=True))
项目:word2vec_pipeline    作者:NIHOPA    | 项目源码 | 文件源码
def compute(self, config):
    """Cluster the mean cluster vectors in two affinity passes and save the result.

    Pass 1 clusters the batched output of ``self._iterator_mean_cluster_vectors()``
    with ``size=self.cluster_n``. Pass 2 re-clusters that result with
    ``size=len(Z)``. The final result is handed to ``self.save(config, Z2)``.

    ``self.h5.close()`` runs in a ``finally`` block so the handle is released
    even if clustering or saving raises.
    """
    try:
        first_batches = self.iterator_batch(self._iterator_mean_cluster_vectors())
        Z = self.cluster_affinity_states(first_batches, size=self.cluster_n)
        print("Initial affinity grouping", Z.shape)

        second_batches = self.iterator_batch(Z)
        Z2 = self.cluster_affinity_states(second_batches, size=len(Z))
        print("Final affinity size", len(Z2))

        self.save(config, Z2)
    finally:
        self.h5.close()
项目:bio_corex    作者:gregversteeg    | 项目源码 | 文件源码
def plot_heatmaps(data, labels, alpha, mis, column_label, cont, topk=20, prefix='', focus=''):
    """Save one standardized heatmap per latent factor.

    For each latent factor ``j`` (row ``j`` of ``mis``):

    - The columns with ``alpha[j] > 0`` and ``mis[j] > 0`` are ranked by
      ``alpha[j] * mis[j]`` (descending), and the top ``topk`` are kept.
    - If ``focus`` is one of ``column_label`` and not already selected, its
      column is prepended.
    - Factors with fewer than two selected columns are skipped.
    - Samples are ordered by ``cont[:, j]``, and each selected column is
      z-scored, ignoring NaNs.
    - The heatmap is written to ``<prefix>/heatmaps/group_num=<j>.png``.
      With the default ``prefix=''`` this is a relative ``heatmaps/``
      directory, not the filesystem root.

    ``labels`` is accepted for interface compatibility but is not used.
    """
    cmap = sns.cubehelix_palette(as_cmap=True, light=.9)
    n_factors, _ = mis.shape
    for j in range(n_factors):
        inds = np.where(np.logical_and(alpha[j] > 0, mis[j] > 0.))[0]
        inds = inds[np.argsort(- alpha[j, inds] * mis[j, inds])][:topk]
        if focus in column_label:
            ifocus = column_label.index(focus)
            if ifocus not in inds:
                inds = np.insert(inds, 0, ifocus)
        if len(inds) < 2:
            continue

        plt.clf()
        order = np.argsort(cont[:, j])
        # Cast to float: the in-place centering/scaling below raises a
        # casting error on integer-valued data.
        subdata = data[:, inds][order].T.astype(float)
        subdata -= np.nanmean(subdata, axis=1, keepdims=True)
        subdata /= np.nanstd(subdata, axis=1, keepdims=True)
        columns = [column_label[i] for i in inds]
        sns.heatmap(subdata, vmin=-3, vmax=3, cmap=cmap, yticklabels=columns,
                    xticklabels=False, mask=np.isnan(subdata))
        filename = os.path.join(prefix, 'heatmaps', 'group_num={}.png'.format(j))
        os.makedirs(os.path.dirname(filename), exist_ok=True)
        plt.title("Latent factor {}".format(j))
        plt.savefig(filename, bbox_inches='tight')
        plt.close('all')
项目:decoding-brain-challenge-2016    作者:alexandrebarachant    | 项目源码 | 文件源码
def plot_confusion_matrix(targets, predictions, target_names,
                          title='Confusion matrix', cmap="Blues"):
    """Plot a row-normalized confusion matrix and return its Axes.

    Each row (true label) is converted to percentages of that row's total.
    The result is drawn as an annotated heatmap on a fixed 0-100 colour scale,
    labelled with *target_names* on both axes.
    """
    counts = confusion_matrix(targets, predictions)
    row_totals = counts.sum(axis=1)[:, np.newaxis]
    percent = 100 * counts.astype('float') / row_totals

    frame = pd.DataFrame(data=percent, columns=target_names, index=target_names)
    axes = sns.heatmap(frame, annot=True, fmt=".1f", linewidths=.5,
                       vmin=0, vmax=100, cmap=cmap)
    axes.set_title(title)
    axes.set_ylabel('True label')
    axes.set_xlabel('Predicted label')
    return axes
项目:waffle-reviewer    作者:gabraganca    | 项目源码 | 文件源码
def plot_activity(series, savename='activity.png'):
    """Plot the reviewers' activity as a weekday-by-week heatmap and save it.

    Steps:

    1. Pad the series with ``fill_week`` to the end of the week.
    2. Bring it to exactly 371 days (53 weeks): keep only the last 371 days,
       or pad it with ``fill_year``.
    3. Reshape the values into 7 rows (weekdays) by 53 columns (weeks).
    4. Label each week column with its month abbreviation, only where the
       month changes.
    5. Write the image to *savename*.

    The figure is closed after saving, so repeated calls do not accumulate
    open pyplot figures.
    """
    n_weekdays = 7
    number_of_days = 371  # 53 full weeks

    # Fill up to the next Saturday (end of the week).
    series = fill_week(series)
    if series.shape[0] > number_of_days:
        # Keep only the most recent 371 days.
        series = series[-number_of_days:]
    elif series.shape[0] < number_of_days:
        # Fill the remaining values with zero.
        series = fill_year(series)
        assert series.shape[0] == number_of_days

    # One month label per week column (first day of each week), blanked when
    # it repeats the previous column's month.
    week_months = series.index.map(lambda x: x.strftime('%b')).tolist()[::n_weekdays]
    months = []
    previous = ''
    for month in week_months:
        months.append('' if month == previous else month)
        previous = month

    fig, ax = plt.subplots()
    sns.heatmap(series.values.reshape(-1, n_weekdays).T, ax=ax,
                cmap='YlGn', cbar=False, linewidths=1, square=True,
                xticklabels=months,
                yticklabels=['', 'M', '', 'W', '', 'F', ''])
    ax.xaxis.tick_top()

    plt.savefig(savename, bbox_inches='tight')
    plt.close(fig)
项目:rl-rc-car    作者:harvitronix    | 项目源码 | 文件源码
def visualize_sensors(state):
    """Redraw ``state[0]`` as a one-row heatmap in the current figure.

    ``state[0]`` (presumably a 1-D sequence of sensor readings — confirm
    against callers) is wrapped in a list so seaborn receives 2-D data.
    pyplot is used directly because the ``seaborn.plt`` alias no longer
    exists in current seaborn releases.
    """
    import matplotlib.pyplot as plt

    # Clear.
    plt.clf()

    # Make a 2d list.
    cols = [state[0]]

    # Plot it.
    sns.heatmap(data=cols, cmap="Blues_r", yticklabels=False)

    # Draw it, then pause briefly so the GUI event loop can repaint.
    plt.draw()
    plt.pause(0.05)
项目:monthly-returns-heatmap    作者:ranaroussi    | 项目源码 | 文件源码
def plot(returns,
         title="Monthly Returns (%)\n",
         title_color="black",
         title_size=14,
         annot_size=10,
         figsize=None,
         cmap='RdYlGn',
         cbar=True,
         square=False,
         is_prices=False,
         eoy=False):
    """Show an annotated heatmap of monthly returns, in percent.

    *returns* is normalized by ``get(returns, eoy=eoy, is_prices=is_prices)``
    and scaled by 100. It is drawn with two-decimal cell annotations and a
    colormap centred on 0. The figure is shown and then closed.

    When *figsize* is None, the width is the default figure width from
    ``rcParams['figure.figsize']`` and the height is half of it (floor-divided).
    """
    returns = get(returns, eoy=eoy, is_prices=is_prices)
    returns *= 100

    if figsize is None:
        # Read the default size from rcParams rather than creating and then
        # closing a throwaway figure; plt.gcf() + plt.close() would close
        # whatever figure the caller currently has open.
        default_width = plt.rcParams['figure.figsize'][0]
        figsize = (default_width, default_width // 2)

    fig, ax = plt.subplots(figsize=figsize)
    ax = sns.heatmap(returns, ax=ax, annot=True, center=0,
                     annot_kws={"size": annot_size},
                     fmt="0.2f", linewidths=0.5,
                     square=square, cbar=cbar, cmap=cmap)
    ax.set_title(title, fontsize=title_size,
                 color=title_color, fontweight="bold")

    fig.subplots_adjust(hspace=0)
    plt.yticks(rotation=0)
    plt.show()
    plt.close()
项目:tf_practice    作者:juho-lee    | 项目源码 | 文件源码
def test():
    """Restore the checkpoint and save diagnostic figures for the MNIST model.

    Writes ``original.png``, ``reconstructed.png``, ``generated.png`` and
    ``feature_activation.png`` under ``FLAGS.save_dir``, then shows them.
    """
    saver.restore(sess, FLAGS.save_dir+'/model.ckpt')

    def show_tiles(fig_name, images, grid, out_name):
        # Draw a tiled grayscale image in a named figure and save it.
        fig = plt.figure(fig_name)
        plt.gray()
        plt.axis('off')
        plt.imshow(batchmat_to_tileimg(images, (height, width), grid))
        fig.savefig(FLAGS.save_dir + '/' + out_name)

    batch_x, _ = mnist.test.next_batch(batch_size)
    show_tiles('original', batch_x, (10, 10), 'original.png')
    show_tiles('reconstructed', sess.run(p, {x: batch_x}), (10, 10),
               'reconstructed.png')

    # One-hot factor codes: each block of n_fac rows switches on one factor,
    # paired with random latent samples.
    batch_w = np.zeros((n_fac*n_fac, n_fac))
    for k in range(n_fac):
        batch_w[k*n_fac:(k+1)*n_fac, k] = 1.0
    batch_z = np.random.normal(size=(n_fac*n_fac, n_lat))
    p_gen = sess.run(p, {w: batch_w, z: batch_z})
    show_tiles('generated', p_gen, (n_fac, n_fac), 'generated.png')

    # Count, per label (10 rows), how often each factor is active (w > 0).
    # Iterate num_examples // batch_size batches (about one pass over the
    # test set); looping num_examples times would draw batch_size times as
    # many samples. Assumes batch_y holds integer labels, not one-hot rows.
    fig = plt.figure('factor activation heatmap')
    hist = np.zeros((10, n_fac))
    n_batches = mnist.test.num_examples // batch_size
    for _ in range(n_batches):
        batch_x, batch_y = mnist.test.next_batch(batch_size)
        batch_w = sess.run(w, {x: batch_x})
        for label, codes in zip(batch_y, batch_w):
            hist[label, codes > 0] += 1
    sns.heatmap(hist)
    fig.savefig(FLAGS.save_dir+'/feature_activation.png')

    plt.show()
项目:tf_practice    作者:juho-lee    | 项目源码 | 文件源码
def test():
    """Restore the checkpoint and save diagnostic figures for ``test_x[0:100]``.

    Writes ``original.png``, ``reconstructed.png`` and ``generated.png`` under
    ``FLAGS.save_dir``, then shows them.
    """
    saver.restore(sess, FLAGS.save_dir+'/model.ckpt')

    def show_tiles(fig_name, images, grid, out_name):
        # Draw a tiled grayscale image in a named figure and save it.
        fig = plt.figure(fig_name)
        plt.gray()
        plt.axis('off')
        plt.imshow(batchmat_to_tileimg(images, (height, width), grid))
        fig.savefig(FLAGS.save_dir + '/' + out_name)

    batch_x = test_x[0:100]
    show_tiles('original', batch_x, (10, 10), 'original.png')
    show_tiles('reconstructed', sess.run(p, {x: batch_x}), (10, 10),
               'reconstructed.png')

    # One-hot factor codes: each block of n_fac rows switches on one factor,
    # paired with random latent samples.
    batch_w = np.zeros((n_fac*n_fac, n_fac))
    for k in range(n_fac):
        batch_w[k*n_fac:(k+1)*n_fac, k] = 1.0
    batch_z = np.random.normal(size=(n_fac*n_fac, n_lat))
    p_gen = sess.run(p, {w: batch_w, z: batch_z})
    show_tiles('generated', p_gen, (n_fac, n_fac), 'generated.png')

    # NOTE: no factor-activation heatmap here. It needs per-example labels
    # (batch_y), and none are available in this function.

    plt.show()
项目:robot-dream    作者:research-team    | 项目源码 | 文件源码
def build_column_key(column, neurotransmitter, dt=None, heatmap=None):
    """Return the device key ``column_<column>_<name>_<dt>_<heatmap>``.

    ``<name>`` is ``get_neurotransmitter_name(neurotransmitter)``. Every
    field is rendered with ``str.format``'s default conversion.
    """
    template = "column_{col}_{nt}_{dt}_{hm}"
    return template.format(col=column,
                           nt=get_neurotransmitter_name(neurotransmitter),
                           dt=dt,
                           hm=heatmap)
项目:robot-dream    作者:research-team    | 项目源码 | 文件源码
def build_layer_key(index, neurotransmitter, dt=None, heatmap=None):
    """Return the device key ``layer_<index>_<name>_<dt>_<heatmap>``.

    ``<name>`` is ``get_neurotransmitter_name(neurotransmitter)``. Every
    field is rendered with ``str.format``'s default conversion.
    """
    template = "layer_{idx}_{nt}_{dt}_{hm}"
    return template.format(idx=index,
                           nt=get_neurotransmitter_name(neurotransmitter),
                           dt=dt,
                           hm=heatmap)
项目:robot-dream    作者:research-team    | 项目源码 | 文件源码
def set_flag_to_column(column, neurotransmitter, heatmap=False, dt=1, multimeter=False):
    """Attach spike detectors (and optionally multimeters) to one column.

    For every layer index in ``range(len(Cortex))``:

    - A new spike detector is connected to the first ``N_detect`` neurons of
      ``Cortex[layer][column][nt]``.
    - With ``multimeter=True``, a new multimeter is also connected to every
      ``len(neurons) // N_volt``-th neuron (at least every neuron).

    Devices are stored per layer in ``spike_detectors[key]`` /
    ``multimeters[key]``, where ``key`` comes from ``build_column_key``.
    Passing ``both`` as *neurotransmitter* repeats this for ``Glu`` and
    ``GABA``.
    """
    kinds = (Glu, GABA) if neurotransmitter == both else (neurotransmitter,)
    for nt in kinds:
        key = build_column_key(column, nt, dt, heatmap)
        spike_detectors[key] = dict()
        if multimeter:
            multimeters[key] = dict()
        for layer in range(len(Cortex)):
            neurons = Cortex[layer][column][nt]
            if multimeter:
                # Integer stride on both Python 2 and 3 ('/' yields a float
                # slice step on Python 3). It is at least 1, because a group
                # smaller than N_volt would otherwise give step 0 (ValueError).
                step = max(1, int(len(neurons) // N_volt))
                multimeters[key][layer] = nest.Create('multimeter', params=multimeter_param)
                nest.Connect(multimeters[key][layer], neurons[::step])
            spike_detectors[key][layer] = nest.Create('spike_detector', params=detector_param)
            nest.Connect(neurons[:N_detect], spike_detectors[key][layer])
项目:robot-dream    作者:research-team    | 项目源码 | 文件源码
def set_flag_to_layer(layer, neurotransmitter=Glu, heatmap=True, dt=1, multimeter=False):
    """Attach spike detectors (and optionally multimeters) to every column of one layer.

    For every column index in ``range(column_number)``:

    - A new spike detector is connected to the first ``N_detect`` neurons of
      ``Cortex[layer][column][nt]``.
    - With ``multimeter=True``, a new multimeter is also connected to every
      ``len(neurons) // N_volt``-th neuron (at least every neuron).

    Devices are stored per column in ``spike_detectors[key]`` /
    ``multimeters[key]``, where ``key`` comes from ``build_layer_key``.
    Passing ``both`` as *neurotransmitter* repeats this for ``Glu`` and
    ``GABA``.
    """
    kinds = (Glu, GABA) if neurotransmitter == both else (neurotransmitter,)
    for nt in kinds:
        key = build_layer_key(layer, nt, dt, heatmap)
        spike_detectors[key] = dict()
        if multimeter:
            multimeters[key] = dict()
        for column in range(column_number):
            neurons = Cortex[layer][column][nt]
            if multimeter:
                # Integer stride on both Python 2 and 3 ('/' yields a float
                # slice step on Python 3). It is at least 1, because a group
                # smaller than N_volt would otherwise give step 0 (ValueError).
                step = max(1, int(len(neurons) // N_volt))
                multimeters[key][column] = nest.Create('multimeter', params=multimeter_param)
                nest.Connect(multimeters[key][column], neurons[::step])
            spike_detectors[key][column] = nest.Create('spike_detector', params=detector_param)
            nest.Connect(neurons[:N_detect], spike_detectors[key][column])
项目:robot-dream    作者:research-team    | 项目源码 | 文件源码
def save_layer_data(key, value, isMultimeter=False):
    """Save the plots for the NEST devices recorded under one layer key.

    :param key: key of the form ``"<area>_<index>_<neurotransmitter>_<dt>_<heatmap>"``
        (the layout produced by ``build_layer_key``).
    :param value: mapping of column -> NEST device.
    :param isMultimeter: if True, *value* holds multimeters, and a membrane
        potential trace per column is saved under ``<parent>/voltage``.
        Otherwise *value* holds spike detectors: a raster plot per column is
        saved under ``<parent>/spikes``, and, when the key's heatmap field is
        ``"True"``, ``heatmap_builder`` writes into ``<parent>/heatmap``.
    :return: None
    """
    # Get parameters from string
    params = str(key).split("_")
    area = params[0]
    layer_name = get_layer_name(int(params[1]))
    neurotransmitter = params[2]

    parent_dir = "{0}_{1}[{2}]".format(area, layer_name, neurotransmitter)
    if not os.path.exists(parent_dir):
        os.mkdir(parent_dir)

    if isMultimeter:
        out_dir = create_subdir('voltage', parent_dir)
        for column, device in value.items():
            nest.voltage_trace.from_device(device, title="Membrane potential in {0} column {1}".format(layer_name, column))
            plt.savefig("{0}/{1}.{2}".format(out_dir, column, image_format), dpi=dpi_n, format=image_format)
            plt.close()
    else:
        dt = int(params[3])
        # The flag was serialized with str(): bool("False") is True, so
        # compare the text instead of calling bool() on it.
        heatmap = params[4] == "True"
        out_dir = create_subdir('spikes', parent_dir)
        for column, device in value.items():
            try:
                nest.raster_plot.from_device(device, hist=True, title="Spikes {0} column {1}".format(layer_name, column))
                plt.savefig("{0}/{1}.{2}".format(out_dir, column, image_format), dpi=dpi_n, format=image_format)
                plt.close()
            except nest.NESTError:
                print("From column {0} {1}[{2}] activity was not found".format(column, layer_name, neurotransmitter))
        if heatmap:
            out_dir = create_subdir('heatmap', parent_dir)
            heatmap_builder(out_dir, value, dt, isColumn=False)
项目:icing    作者:slipguru    | 项目源码 | 文件源码
def show_heatmap(filename):
    """Show the confusion matrix of a partis-generated, tab-delimited database.

    True and estimated clone labels are read with
    ``get_clones_real_estimated(filename)``. The matrix and its row/column
    labels come from this project's ``confusion_matrix``. pyplot is used
    directly because the ``seaborn.plt`` alias no longer exists in current
    seaborn releases.
    """
    import matplotlib.pyplot as plt

    true_labels, estimated_labels = get_clones_real_estimated(filename)
    counts, rows, cols = confusion_matrix(true_labels, estimated_labels)
    df = pd.DataFrame(counts, index=rows, columns=cols)
    sns.heatmap(df)
    plt.show()
项目:pyEPR    作者:zlatko-minev    | 项目源码 | 文件源码
def xarr_heatmap(fg, title=None, kwheat=None, fmt=('%.3f', '%.2f'), fig=None):
    ''' Draw an annotated seaborn heatmap of a 2-D xarray object.

    Needs seaborn and xarray.

    :param fg: object providing ``to_pandas()`` and ``dims`` (presumably a
        2-D xarray DataArray — confirm with callers).
    :param title: title for the axes.
    :param kwheat: extra keyword arguments for ``sns.heatmap``. They may
        override ``annot=True`` or supply their own ``ax``.
    :param fmt: ``(row_fmt, col_fmt)`` printf formats used to round the index
        and column values before plotting.
    :param fig: figure to draw into; a new figure is created when None.
    '''
    fig = plt.figure() if fig is None else fig
    df = fg.to_pandas()
    # Round the coordinate values so the tick labels stay short.
    df.index = [float(fmt[0] % x) for x in df.index]
    df.columns = [float(fmt[1] % x) for x in df.columns]
    import seaborn as sns

    heat_kw = {'annot': True}
    heat_kw.update(kwheat or {})
    if 'ax' not in heat_kw:
        # Draw into *fig*; otherwise seaborn uses pyplot's current axes,
        # which need not belong to the figure the caller passed in.
        heat_kw['ax'] = fig.gca()
    ax = sns.heatmap(df, **heat_kw)
    ax.invert_yaxis()
    ax.set_title(title)
    ax.set_xlabel(fg.dims[1])
    ax.set_ylabel(fg.dims[0])
项目:deepcpg    作者:cangermueller    | 项目源码 | 文件源码
def plot_filter_heatmap(weights, filename=None):
    """Draw a diverging heatmap of a filter weight matrix.

    - The colour scale is symmetric around zero (``±max(|weights|)``).
    - The figure is one inch per cell.
    - x ticks are labelled 1..n_columns; y ticks use ``ALPHABET_R`` in
      reversed row order.

    If *filename* is given, the figure is saved there and closed; otherwise
    it is left open.
    """
    n_rows, n_cols = weights.shape[0], weights.shape[1]
    limit = abs(weights).max()

    fig, ax = plt.subplots(figsize=(n_cols, n_rows))
    sns.heatmap(weights, cmap='RdYlBu_r', linewidths=0.2,
                vmin=-limit, vmax=limit, ax=ax)
    ax.set_xticklabels(range(1, n_cols + 1))
    row_labels = [ALPHABET_R[r] for r in range(n_rows - 1, -1, -1)]
    ax.set_yticklabels(row_labels, rotation='horizontal', size=10)
    if filename:
        plt.savefig(filename)
        plt.close()
项目:MLAB_Intuit    作者:rykard95    | 项目源码 | 文件源码
def generate_confusion_matrix(y_test, y_pred, labels, title, filename, show=False):
    """Save an annotated confusion-matrix heatmap to *filename*.

    Rows are actual labels and columns are predicted labels, both in the
    order given by *labels*. Cell counts are annotated as integers. If *show*
    is true, the figure is also displayed. Either way it is closed at the
    end, so repeated calls do not accumulate open figures.
    """
    cm = confusion_matrix(y_test, y_pred, labels=labels)
    df_cm = pd.DataFrame(cm, index=labels, columns=labels)
    fig = plt.figure(figsize=(12, 8))
    # fmt="d": seaborn's default '.2g' would print counts >= 100 in
    # scientific notation (e.g. 1.2e+02).
    ax = sn.heatmap(df_cm, annot=True, fmt="d")
    plt.ylabel("Actual Label", fontsize=14, fontweight='bold')
    plt.xlabel("Predicted Label", fontsize=14, fontweight='bold')
    plt.title(title, fontsize=16, fontweight='bold')

    # Nudge the title up slightly so it clears the top of the heatmap.
    ax.title.set_position([0.5, 1.03])
    plt.savefig(filename)

    if show:
        plt.show()
    plt.close(fig)