Python matplotlib.pyplot 模块,subplots() 实例源码

我们从Python开源项目中,提取了以下50个代码示例,用于说明如何使用matplotlib.pyplot.subplots()

项目:visual-search    作者:GYXie    | 项目源码 | 文件源码
def main():
    """Render each directory of images under ``args.input_data_dir`` as a PDF grid.

    Every directory (walked recursively) that contains files gets one figure:
    up to 48 images drawn on a fixed 4x12 grid of tick-less axes, saved to
    ``args.output_data_dir`` + the directory path relative to the input root
    + ``'.pdf'``.  ``args``, ``imread``, ``plt``, ``math`` and ``os`` are
    module-level names defined outside this block.
    """
    args.input_data_dir = os.path.abspath(args.input_data_dir)
    if not os.path.exists(args.output_data_dir):
        os.mkdir(args.output_data_dir)
    for dir_path, dir_names, file_names in os.walk(args.input_data_dir):
        if not file_names:
            continue
        print(dir_path)
        rows = int(math.ceil(len(file_names) / 6.0))
        print(rows)
        # NOTE(review): the grid is fixed at 4x12 (48 cells) even though
        # ``rows`` is computed above; files beyond the 48th are not drawn.
        fig, axes = plt.subplots(4, 12, subplot_kw={'xticks': [], 'yticks': []})
        fig.subplots_adjust(hspace=0.01, wspace=0.01)
        for ax, file_name in zip(axes.flat, file_names):
            print(file_name)
            img = imread(os.path.join(dir_path, file_name))
            ax.imshow(img)
        fig.savefig(args.output_data_dir + dir_path.replace(args.input_data_dir, '') + '.pdf')
        # One figure is created per directory; pyplot keeps every figure alive
        # until it is closed, so close it to avoid unbounded memory growth.
        plt.close(fig)
项目:PersonalizedMultitaskLearning    作者:mitmedialab    | 项目源码 | 文件源码
def saveHintonPlot(self, matrix, num_tests, max_weight=None, ax=None):
        """Draw a Hinton diagram of ``matrix`` and save it as an EPS file.

        Each entry ``w`` at index ``(i, j)`` is drawn as a square centred at
        ``x=i, y=j`` (y axis inverted), white when ``w > 0`` and black
        otherwise, on a gray background.  The figure is written to
        ``self.figures_path + self.save_prefix + '-Hinton.eps'`` and closed.

        Args:
            matrix: 2-D array-like of weights.
            num_tests: Normaliser for the square sizes; a square's area is
                ``0.5 * |w| / num_tests``.
            max_weight: Accepted for backward compatibility; not used.
            ax: Accepted for backward compatibility; not used (a new figure
                is always created).
        """
        # Squares are scaled by ``num_tests`` only, so no ``max_weight``
        # default is computed (doing so warned on an all-zero matrix via log(0)).
        fig, ax = plt.subplots(1, 1)

        ax.patch.set_facecolor('gray')
        ax.set_aspect('equal', 'box')
        ax.xaxis.set_major_locator(plt.NullLocator())
        ax.yaxis.set_major_locator(plt.NullLocator())

        for (x, y), w in np.ndenumerate(matrix):
            color = 'white' if w > 0 else 'black'
            # Area (size**2) is 0.5*|w|/num_tests, i.e. within [0, 0.5]
            # whenever |w| <= num_tests.
            size = np.sqrt(np.abs(0.5*w/num_tests))
            rect = plt.Rectangle([x - size / 2, y - size / 2], size, size,
                                 facecolor=color, edgecolor=color)
            ax.add_patch(rect)

        ax.autoscale_view()
        ax.invert_yaxis()
        fig.savefig(self.figures_path + self.save_prefix + '-Hinton.eps')
        plt.close(fig)
项目:fingerprint-securedrop    作者:freedomofpress    | 项目源码 | 文件源码
def plot_ROC(test_labels, test_predictions):
    """Plot the ROC curve of binary predictions and return the figure.

    Label 1 is the positive class.  The curve is drawn against a dashed
    chance diagonal, and the AUC (2 decimals) appears in the title.  The
    figure is styled with matplotlib's 'ggplot' style.
    """
    fpr, tpr, _thresholds = metrics.roc_curve(test_labels, test_predictions, pos_label=1)
    auc_text = "%.2f" % metrics.auc(fpr, tpr)
    with plt.style.context('ggplot'):
        fig, ax = plt.subplots()
        ax.plot(fpr, tpr, "#000099", label='ROC curve')
        ax.plot([0, 1], [0, 1], 'k--', label='Baseline')
        ax.set_xlim([0.0, 1.0])
        ax.set_ylim([0.0, 1.05])
        ax.set_xlabel('False Positive Rate')
        ax.set_ylabel('True Positive Rate')
        ax.legend(loc='lower right')
        ax.set_title('ROC Curve, AUC = ' + auc_text)
    return fig
项目:nanoQC    作者:wdecoster    | 项目源码 | 文件源码
def per_base_sequence_content_and_quality(fqbin, qualbin, outdir, figformat):
    """Save a 2x2 panel of per-base nucleotide content and quality.

    Top row: ``plot_nucleotide_diversity`` on ``fqbin``; bottom row:
    ``plot_qual`` on ``qualbin``; the right column is drawn with
    ``invert=True``.  A shared legend sits on a small invisible axes in the
    middle of the figure.  The result is written to
    ``<outdir>/PerBaseSequenceContentQuality.<figformat>`` at 500 dpi, and the
    figure is then closed.
    """
    fig, axs = plt.subplots(2, 2, sharex='col', sharey='row')
    lines = plot_nucleotide_diversity(axs[0, 0], fqbin)
    plot_nucleotide_diversity(axs[0, 1], fqbin, invert=True)
    l_Q = plot_qual(axs[1, 0], qualbin)
    plot_qual(axs[1, 1], qualbin, invert=True)
    # Hide tick labels and axis labels on the shared inner edges.
    plt.setp([a.get_xticklabels() for a in axs[0, :]], visible=False)
    plt.setp([a.get_yticklabels() for a in axs[:, 1]], visible=False)
    for ax in axs[:, 1]:
        ax.set_ylabel('', visible=False)
    for ax in axs[0, :]:
        ax.set_xlabel('', visible=False)
    # Since axes are shared I should only invert once. Twice will restore the original axis order!
    axs[0, 1].invert_xaxis()
    fig.suptitle("Per base sequence content and quality")
    # Invisible axes whose only purpose is to host the combined legend.
    axl = fig.add_axes([0.4, 0.4, 0.2, 0.2])
    axl.axis('off')
    lines.append(l_Q)
    axl.legend(lines, ['A', 'T', 'G', 'C', 'Quality'], loc="center", ncol=5)
    fig.savefig(os.path.join(outdir, "PerBaseSequenceContentQuality." +
                             figformat), format=figformat, dpi=500)
    plt.close(fig)
项目:kaggle_dsb2017    作者:astoc    | 项目源码 | 文件源码
def get_masks(scans,masks_list):
    """Build square masks from ``masks_list`` and display every scan with them.

    Each entry of ``masks_list`` is indexed as ``(slice, x, y, r, ...)``.  In
    the returned mask array, the (2r+1)-sided square centred at column ``x``,
    row ``y`` of slice ``slice`` is set to 1.  The square's border is also
    drawn with value 255 on a copy of ``scans``.  Each slice is then shown
    side by side with its mask.

    ``img_rows`` / ``img_cols`` are module-level globals not visible here.
    Returns the mask array of shape ``(scans.shape[0], 1, img_rows, img_cols)``.
    """
    #%matplotlib inline
    outlined = scans.copy()
    masks = np.zeros(shape=(scans.shape[0], 1, img_rows, img_cols))
    for entry in masks_list:
        z, cx, cy, r = entry[0], entry[1], entry[2], entry[3]
        offsets = range(-r, r + 1)
        # Fill the square in the mask.
        for dy in offsets:
            for dx in offsets:
                masks[z, 0, cy + dy, cx + dx] = 1
        # Draw the square's four edges on the display copy.
        for d in offsets:
            outlined[z, 0, cy + d, cx + r] = 255
            outlined[z, 0, cy + d, cx - r] = 255
            outlined[z, 0, cy + r, cx + d] = 255
            outlined[z, 0, cy - r, cx + d] = 255
    for i in range(scans.shape[0]):
        print('scan ' + str(i))
        f, ax = plt.subplots(1, 2, figsize=(10, 5))
        ax[0].imshow(outlined[i, 0, :, :], cmap=plt.cm.gray)
        ax[1].imshow(masks[i, 0, :, :], cmap=plt.cm.gray)
        plt.show()
    return masks
项目:soccerstan    作者:Torvaney    | 项目源码 | 文件源码
def plot_parameter(data, title, alpha=0.05, axes_colour='dimgray'):
    """ Plot 1-dimensional parameters.

    Draws a 50-bin, density-normalised black histogram of ``data`` with a
    figure title and muted axes (top/right spines hidden), and returns the
    Figure.  ``alpha`` is accepted but currently unused.
    """
    fig, ax = plt.subplots(figsize=(8, 6))

    # ``density`` replaces ``normed``, which matplotlib deprecated in 2.1 and
    # removed in 3.1 (passing ``normed`` now raises).
    ax.hist(data, bins=50, density=True, color='black', edgecolor='None')

    # Add title
    fig.suptitle(title, fontsize=16, color=axes_colour)
    # Add axis labels
    ax.set_xlabel('', fontsize=16, color=axes_colour)
    ax.set_ylabel('', fontsize=16, color=axes_colour)

    # Change axes colour
    ax.spines["bottom"].set_color(axes_colour)
    ax.spines["left"].set_color(axes_colour)
    ax.tick_params(colors=axes_colour)
    # Remove top and right spines
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    # Remove extra ticks
    ax.get_xaxis().tick_bottom()
    ax.get_yaxis().tick_left()

    return fig
项目:mazerunner    作者:lucasdavid    | 项目源码 | 文件源码
def main():
    """Render the stored Q-table as a heat map and save it to ``report.png``."""
    Q = ModelStorage.load(MODEL_NAME)
    # Centre on the mean and scale by the value range.
    Q_ = (Q - Q.mean()) / (Q.max() - Q.min())

    fig, ax = plt.subplots()
    ax.pcolor(Q_, cmap=plt.cm.YlOrBr, alpha=0.8)

    fig.set_size_inches(8, 8)
    ax.set_frame_on(False)

    # NOTE(review): these labels are attached to whatever tick positions the
    # default locator picked; presumably they should name 4 columns — confirm.
    ax.set_xticklabels([1, 2, 3, 4], minor=False)
    ax.grid(False)

    fig.savefig('report.png')
    plt.close(fig)
项目:speccer    作者:bensimner    | 项目源码 | 文件源码
def make_new_pie_from_callers(callers, call_name=None):
    """Show an interactive pie chart of how ``callers`` break down.

    Every wedge is pickable.  Clicking a wedge whose entry in the callbacks
    returned by ``make_pie_from_callers`` is truthy opens a new pie for it.
    The 'other' wedge is titled ``'<call_name>/other'``.
    """
    fig, ax = plt.subplots()

    if call_name:
        ax.set_title('Breakdown of {} callees'.format(call_name))

    labels, sizes, callbacks = make_pie_from_callers(callers)
    wedges, _ = ax.pie(sizes, labels=labels)
    for wedge in wedges:
        wedge.set_picker(True)

    def on_pick(event):
        label = event.artist.get_label()
        sub_callers = callbacks[label]
        if not sub_callers:
            return
        if label == 'other':
            label = '{}/other'.format(call_name)
        make_new_pie_from_callers(sub_callers, call_name=label)

    fig.canvas.mpl_connect('pick_event', on_pick)
    ax.axis('equal')
    plt.show()
项目:chash    作者:luhsra    | 项目源码 | 文件源码
def plot_build_time_composition_graph(parseTimes, hashTimes, compileTimes, diffToBuildTime): # times in s
    """Stack parse/hash/compile/remaining build time per commit, in minutes.

    Inputs are per-commit times in seconds.  The stacked-area chart is saved
    to ``abs_path(BUILD_TIME_COMPOSITION_FILENAME)`` with its legend outside
    the plot on the right.  The average of each series is then printed.
    """
    series = [parseTimes, hashTimes, compileTimes, diffToBuildTime]
    in_minutes = [[t / 60 for t in values] for values in series]
    n_commits = len(parseTimes)

    fig, ax = plt.subplots()
    ax.stackplot(np.arange(1, n_commits + 1), in_minutes,
                 colors=[parseColor, hashColor, compileColor, remainColor],
                 edgecolor='none')
    ax.set_xlim(1, n_commits)
    ax.set_xlabel('commits')
    ax.set_ylabel('time [min]')
    # Legend lists the layers top-down, i.e. the reverse of the stack order.
    handles = [mpatches.Patch(color=c)
               for c in (remainColor, compileColor, hashColor, parseColor)]
    lgd = ax.legend(handles,
                    ['remaining build time', 'compile time', 'hash time', 'parse time'],
                    loc='center left', bbox_to_anchor=(1, 0.5))
    fig.savefig(abs_path(BUILD_TIME_COMPOSITION_FILENAME), bbox_extra_artists=(lgd,), bbox_inches='tight')
    for values, name in zip(series, ('parse', 'hash', 'compile', 'remainder')):
        print_avg(values, name)
项目:chash    作者:luhsra    | 项目源码 | 文件源码
def plotTimeMultiHistogram(parseTimes, hashTimes, compileTimes, filename): # times in ms
    """Save a grouped histogram and horizontal boxplots of per-file times.

    Inputs are per-file times in milliseconds.  The histogram (50 edges over
    0..5000 ms) is saved to ``filename``.  Boxplots in seconds (whiskers at
    the 5th/95th percentiles, fliers as red squares) are saved to
    ``filename[:-4] + '_boxplots' + GRAPH_EXTENSION``.
    """
    labels = ['parsing', 'hashing', 'compiling']

    edges = np.linspace(0, 5000, 50)
    stacked = np.vstack([parseTimes, hashTimes, compileTimes]).T
    hist_fig, hist_ax = plt.subplots()
    hist_ax.hist(stacked, edges, alpha=0.7, label=labels,
                 color=[parseColor, hashColor, compileColor])
    hist_ax.legend(loc='upper right')
    hist_ax.set_xlabel('time [ms]')
    hist_ax.set_ylabel('#files')
    hist_fig.savefig(filename)

    box_fig, box_ax = plt.subplots()
    in_seconds = [[t / 1000 for t in values]
                  for values in (parseTimes, hashTimes, compileTimes)]
    box_ax.boxplot(in_seconds, notch=0, sym='rs', vert=0, whis=[5, 95])
    box_ax.set_xlabel('time [s]')
    box_ax.set_yticks([1, 2, 3])
    box_ax.set_yticklabels(labels)
    box_fig.savefig(filename[:-4] + '_boxplots' + GRAPH_EXTENSION)
项目:chash    作者:luhsra    | 项目源码 | 文件源码
def plot_build_time_composition_graph(parse_times, hash_times, compile_times, diff_to_build_time): # times in ns
    """Stack parse/hash/compile time per commit on a log y axis, in ms.

    Inputs are per-commit times in nanoseconds.  ``diff_to_build_time`` is
    accepted but not plotted.  The chart is saved to
    ``abs_path(BUILD_TIME_FILENAME)`` with the legend outside on the right.
    """
    layers = [[t / 1e6 for t in values]  # ns -> ms
              for values in (parse_times, hash_times, compile_times)]
    n_commits = len(parse_times)

    fig, ax = plt.subplots()
    ax.stackplot(np.arange(1, n_commits + 1), layers,
                 colors=[parse_color, hash_color, compile_color], edgecolor='none')
    ax.set_xlim(1, n_commits)
    ax.set_xlabel('commits')
    ax.set_ylabel('time [ms]')
    ax.set_yscale('log')
    # Legend lists the layers top-down, i.e. the reverse of the stack order.
    handles = [mpatches.Patch(color=c) for c in (compile_color, hash_color, parse_color)]
    lgd = ax.legend(handles,
                    ['compile time', 'hash time', 'parse time'],
                    loc='center left', bbox_to_anchor=(1, 0.5))
    fig.savefig(abs_path(BUILD_TIME_FILENAME), bbox_extra_artists=(lgd,), bbox_inches='tight')



################################################################################
项目:demcoreg    作者:dshean    | 项目源码 | 文件源码
def genplot(x, y, fit, xdata=None, ydata=None, maxpts=10000):
    """Plot a fitted cosine model (``nuth_func``) against binned and raw samples.

    Args:
        x, y: Bin positions (aspect, degrees) and bin medians, drawn as red dots.
        fit: Three parameters passed to ``nuth_func`` and shown in the legend
            as ``y=fit0*cos(fit1-x)+fit2``.
        xdata, ydata: Optional raw samples (numpy arrays).  When both are
            given, at most ``maxpts`` randomly chosen samples are drawn as
            black dots.
        maxpts: Maximum number of raw samples drawn.

    Returns:
        The matplotlib Figure.
    """
    bin_range = (0, 360)
    a = (np.arange(*bin_range))
    f_a = nuth_func(a, fit[0], fit[1], fit[2])
    nuth_func_str = r'$y=%0.2f*cos(%0.2f-x)+%0.2f$' % tuple(fit)
    f, ax = plt.subplots()
    ax.set_xlabel('Aspect (deg)')
    ax.set_ylabel('dh/tan(slope) (m)')
    if xdata is not None and ydata is not None:
        if xdata.size > maxpts:
            import random
            # Draw exactly ``maxpts`` samples: a fixed count ignored the limit
            # and made random.sample fail when maxpts < size < that count.
            idx = random.sample(list(range(xdata.size)), maxpts)
        else:
            idx = np.arange(xdata.size)
        ax.plot(xdata[idx], ydata[idx], 'k.', label='Orig pixels')
    ax.plot(x, y, 'ro', label='Bin median')
    ax.axhline(color='k')
    ax.plot(a, f_a, 'b', label=nuth_func_str)
    ax.set_xlim(*bin_range)
    pad = 0.2 * np.max([np.abs(y.min()), np.abs(y.max())])
    ax.set_ylim(y.min() - pad, y.max() + pad)
    ax.legend(prop={'size':8})
    return f

# Function copied from openPIV pyprocess
项目:chainer-visualization    作者:hvy    | 项目源码 | 文件源码
def save_ims(filename, ims, dpi=100, scale=0.5):
    """Save a batch of images as a single borderless grid figure.

    Args:
        filename: Output path.
        ims: Array of shape ``(n, c, h, w)``; channels are moved last for
            display, and single-channel images are drawn as 2-D arrays.
        dpi: Resolution of the figure and of the saved file.
        scale: Size factor relative to the images' native pixel size.
    """
    n, c, h, w = ims.shape

    rows = int(math.ceil(math.sqrt(n)))
    cols = int(round(math.sqrt(n)))

    # squeeze=False keeps ``axes`` a 2-D array even when n == 1, so
    # ``axes.flat`` always works.
    fig, axes = plt.subplots(rows, cols, figsize=(w*cols/dpi*scale, h*rows/dpi*scale),
                             dpi=dpi, squeeze=False)

    for i, ax in enumerate(axes.flat):
        if i < n:
            im = ims[i].transpose((1, 2, 0))
            if c == 1:
                # imshow rejects (h, w, 1) arrays; drop the singleton channel.
                im = im[:, :, 0]
            ax.imshow(im)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.axis('off')

    fig.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=0.1, hspace=0.1)
    # Correctly spelled ``bbox_inches`` (a misspelled kwarg is not applied).
    fig.savefig(filename, dpi=dpi, bbox_inches='tight', transparent=True)
    plt.close(fig)
项目:fingerprint-securedrop    作者:freedomofpress    | 项目源码 | 文件源码
def plot_feature_importances(feature_names, feature_importances, N=30):
    """Return a horizontal bar chart of the ``N`` most important features.

    Features are ranked by absolute importance.  The selected ones are then
    ordered by signed importance, so the most positive feature is drawn at
    the top.
    """
    importances = list(zip(feature_names, list(feature_importances)))
    importances = pd.DataFrame(importances, columns=["Feature", "Importance"])
    importances = importances.set_index("Feature")

    # Sort by the absolute value of the importance of the feature.
    # DataFrame.sort() was removed in pandas 0.20; sort_values() replaces it.
    importances["sort"] = abs(importances["Importance"])
    importances = importances.sort_values(by="sort", ascending=False).drop("sort", axis=1)
    importances = importances[0:N]

    # Show the most important positive feature at the top of the graph
    importances = importances.sort_values(by="Importance", ascending=True)

    with plt.style.context(('ggplot')):
        fig, ax = plt.subplots(figsize=(16,12))
        ax.tick_params(labelsize=16)
        importances.plot(kind="barh", legend=False, ax=ax)
        ax.set_frame_on(False)
        ax.set_xlabel("Relative importance", fontsize=20)
        ax.set_ylabel("Feature name", fontsize=20)
    plt.tight_layout()
    plt.title("Most important features for attack", fontsize=20).set_position([.5, 0.99])
    return fig
项目:linkedin_recommend    作者:duggalr2    | 项目源码 | 文件源码
def pieGraph(data_count):
    """
    Graphs a pie chart of the counts in ``data_count``; only entries whose
    count is greater than 1 are included.
    Parameter: -data_count: dict mapping label -> count
    """
    kept = [(label, n) for label, n in data_count.items() if n > 1]
    names = [label for label, _ in kept]
    count = [n for _, n in kept]

    fig1, ax1 = plt.subplots()
    ax1.pie(count, labels=names, autopct='%1.1f%%', shadow=True, startangle=90)
    # Equal aspect ratio keeps the pie circular.
    ax1.axis('equal')
    plt.show()
项目:linkedin_recommend    作者:duggalr2    | 项目源码 | 文件源码
def pie_graph(data_count):
    """
    Graphs a pie chart of ``data_count``, showing only entries (e.g. schools)
    that appear more than once.
    Parameter: -data_count: dict mapping label -> count
    """
    frequent = {label: n for label, n in data_count.items() if n > 1}
    names = list(frequent)
    count = list(frequent.values())

    fig1, ax1 = plt.subplots()
    ax1.pie(count, labels=names, autopct='%1.1f%%', shadow=True, startangle=90)
    # Equal aspect ratio keeps the pie circular.
    ax1.axis('equal')
    plt.show()
项目:linkedin_recommend    作者:duggalr2    | 项目源码 | 文件源码
def barGraph(data_count):
    """Show a horizontal bar chart of ``data_count`` values, largest at the top.

    Parameter: -data_count: dict mapping category name -> count.
    """
    ranked = sorted(data_count.items(), key=operator.itemgetter(1), reverse=True)
    names = [name for name, _ in ranked]
    count_in = [n for _, n in ranked]

    plt.rcdefaults()
    fig, ax = plt.subplots()
    y_pos = np.arange(len(names))
    ax.barh(y_pos, count_in, align='center', color='green', ecolor='black')
    ax.set_yticks(y_pos)
    ax.set_yticklabels(names)
    ax.invert_yaxis()  # first (largest) entry reads at the top
    # NOTE(review): the x axis carries the counts but is labelled 'Categories'.
    ax.set_xlabel('Categories')
    ax.set_title('# of job titles in each category')
    plt.show()
项目:BISIP    作者:clberube    | 项目源码 | 文件源码
def compare_fits():
    """Overlay the data, every fit in ``sol`` and their mean on two panels.

    ``sol`` and the ``plot_data`` / ``plot_fit`` / ``plot_mean_fit`` helpers
    are module-level names defined outside this block.
    """
    fig, ax = plt.subplots(1, 2, figsize=(10,4))
    first = sol[0]
    plot_data(first, ax)
    plot_fit(first, ax)
    plot_mean_fit(sol, ax)
    for result in sol:
        plot_fit(result, ax)
    for panel in ax:
        panel.set_title("Sample B")
    ax[0].set_ylim([0, 1.1])
    ax[1].set_ylim([-0.01, None])
    fig.tight_layout()
项目:dsb3    作者:EliasVansteenkiste    | 项目源码 | 文件源码
def plot_slice_3d_3axis(input, pid, img_dir=None, idx=None):
    """Show three orthogonal slices of a 3-D volume on a 2x2 grid.

    Args:
        input: 3-D array-like (converted with ``np.asarray``, e.g. from CUDA arrays).
        pid: Identifier used as the window title and output file name.
        img_dir: If given, save to ``<img_dir>/<pid>.png`` instead of showing.
        idx: ``(i0, i1, i2)`` slice indices along axes 0, 1 and 2.  Defaults
            to the middle index of each axis.
    """
    # to convert cuda arrays to numpy array
    input = np.asarray(input)
    if idx is None:
        idx = [s // 2 for s in input.shape]

    fig, ax = plt.subplots(2, 2, figsize=[8, 8])
    # FigureCanvasBase.set_window_title was removed in matplotlib 3.6; the
    # figure manager owns the window title (and may be None on some backends).
    manager = getattr(fig.canvas, 'manager', None)
    if manager is not None:
        manager.set_window_title(pid)
    ax[0, 0].imshow(input[idx[0], :, :], cmap=plt.cm.gray)
    ax[1, 0].imshow(input[:, idx[1], :], cmap=plt.cm.gray)
    ax[0, 1].imshow(input[:, :, idx[2]], cmap=plt.cm.gray)

    if img_dir is not None:
        fig.savefig(img_dir + '/%s.png' % (pid), bbox_inches='tight')
    else:
        plt.show()
    fig.clf()
    plt.close('all')
项目:dsb3    作者:EliasVansteenkiste    | 项目源码 | 文件源码
def plot_all_slices(input, pid, img_dir=None):
    """Plot consecutive axis-0 slices of a volume, four per 2x2 figure.

    Figures start at slice indices 0, 4, 8, ... (trailing slices that do not
    fill a group of four are skipped).  Each figure is saved to
    ``<img_dir>_<pid>_<idx>.png`` when ``img_dir`` is given, otherwise shown.
    """
    # to convert cuda arrays to numpy array
    input = np.asarray(input)

    for idx in range(0, input.shape[0]-3, 4):
        fig, ax = plt.subplots(2, 2, figsize=[8, 8])
        # FigureCanvasBase.set_window_title was removed in matplotlib 3.6.
        manager = getattr(fig.canvas, 'manager', None)
        if manager is not None:
            manager.set_window_title(pid)
        ax[0, 0].imshow(input[idx, :, :], cmap=plt.cm.gray)
        ax[1, 0].imshow(input[idx+1, :, :], cmap=plt.cm.gray)
        ax[0, 1].imshow(input[idx+2, :, :], cmap=plt.cm.gray)
        ax[1, 1].imshow(input[idx+3, :, :], cmap=plt.cm.gray)

        if img_dir is not None:
            fig.savefig(img_dir + '_' + str(pid) + '_' + str(idx) + '.png' , bbox_inches='tight')
        else:
            plt.show()
        fig.clf()
        plt.close('all')
项目:dsb3    作者:EliasVansteenkiste    | 项目源码 | 文件源码
def plot_all_slices(ct_scan, mask, pid, img_dir=None):
    """Plot mask and CT slices side by side, stepping 2 slices per figure.

    Each 2x2 figure shows ``mask[idx]``, ``ct_scan[idx+1]``, ``mask[idx+2]``
    and ``ct_scan[idx+3]``.  It is saved to ``<img_dir>_<pid>_<idx>.png``
    when ``img_dir`` is given, otherwise shown.
    """
    # to convert cuda arrays to numpy array
    ct_scan = np.asarray(ct_scan)
    mask = np.asarray(mask)

    for idx in range(0, mask.shape[0]-3, 2):
        fig, ax = plt.subplots(2, 2, figsize=[8, 8])
        # FigureCanvasBase.set_window_title was removed in matplotlib 3.6.
        manager = getattr(fig.canvas, 'manager', None)
        if manager is not None:
            manager.set_window_title(pid)
        # NOTE(review): mask and scan panels use different slice offsets
        # (idx, idx+1, idx+2, idx+3) — confirm this pairing is intended.
        ax[0, 0].imshow(mask[idx, :, :], cmap=plt.cm.gray)
        ax[1, 0].imshow(ct_scan[idx+1, :, :], cmap=plt.cm.gray)
        ax[0, 1].imshow(mask[idx+2, :, :], cmap=plt.cm.gray)
        ax[1, 1].imshow(ct_scan[idx+3, :, :], cmap=plt.cm.gray)

        if img_dir is not None:
            fig.savefig(img_dir + '_' + str(pid) + '_' + str(idx) + '.png' , bbox_inches='tight')
        else:
            plt.show()
        fig.clf()
        plt.close('all')
项目:dsb3    作者:EliasVansteenkiste    | 项目源码 | 文件源码
def plot_4_slices(input, pid, img_dir=None, idx=None):
    """Show orthogonal slices of a 3-D volume on a 2x2 grid.

    Args:
        input: 3-D array-like (converted with ``np.asarray``).
        pid: Identifier used as the window title and output file name.
        img_dir: If given, save to ``<img_dir>/<pid>.png`` instead of showing.
        idx: ``(i0, i1, i2)`` slice indices along axes 0, 1 and 2.  Defaults
            to the middle index of each axis.
    """
    # to convert cuda arrays to numpy array
    input = np.asarray(input)
    if idx is None:
        idx = [s // 2 for s in input.shape]

    fig, ax = plt.subplots(2, 2, figsize=[8, 8])
    # FigureCanvasBase.set_window_title was removed in matplotlib 3.6.
    manager = getattr(fig.canvas, 'manager', None)
    if manager is not None:
        manager.set_window_title(pid)
    ax[0, 0].imshow(input[idx[0], :, :], cmap=plt.cm.gray)
    ax[1, 0].imshow(input[:, idx[1], :], cmap=plt.cm.gray)
    ax[0, 1].imshow(input[:, :, idx[2]], cmap=plt.cm.gray)
    # NOTE(review): this panel repeats the ax[0, 1] slice — confirm intent.
    ax[1, 1].imshow(input[:, :, idx[2]], cmap=plt.cm.gray)

    if img_dir is not None:
        fig.savefig(img_dir + '/%s.png' % (pid), bbox_inches='tight')
    else:
        plt.show()
    fig.clf()
    plt.close('all')
项目:structured-output-ae    作者:sbelharbi    | 项目源码 | 文件源码
def plot_cdf_model_and_meansh(self, cdfs, tag, cdf0_1s, aucs, bx, dx):
        """Plot the NRMSE CDF of the model against the mean-shape baseline.

        ``cdfs[0]`` / ``cdfs[1]`` are plotted over ``np.arange(0, bx, dx)``.
        The title reports CDF0.1 and AUC for both curves, formatted with the
        module-level ``prec2`` format string.  All open figures are closed
        first; the new Figure is returned.
        """
        plt.close("all")
        x = np.arange(0, bx, dx)
        fig, ax = plt.subplots(nrows=1, ncols=1)
        for curve, name in zip((cdfs[0], cdfs[1]), ("CDF model", "CDF mean shape")):
            ax.plot(x, curve, label=name)
        ax.grid(True)
        ax.set_xlabel("NRMSE")
        ax.set_ylabel("Data proportion")
        ax.legend(loc=4, prop={'size': 8}, fancybox=True, shadow=True)
        model_part = ("Model: CDF0.1: " + str(prec2 % cdf0_1s[0]) +
                      " . AUC:" + str(prec2 % aucs[0]))
        mean_part = (". MSh: CDF0.1: " + str(prec2 % cdf0_1s[1]) +
                     " . AUC:" + str(prec2 % aucs[1]))
        ax.set_title("CDF curve: " + tag + ". " + model_part + ".\n" + mean_part + ".\n")
        return fig
项目:droppy    作者:BV-DR    | 项目源码 | 文件源码
def mapFunction( x , y , func , ax = None, arrayInput = False, n = 10, title = None, **kwargs ) :
   """
      Plot function on a regular grid as filled contours.
        x : 1d array
        y : 1d array
        func : function to map
        ax : axes to draw on; a new figure/axes is created when None
        arrayInput : False if func(x,y) , True if func( [x,y] )
        n : passed to ``contourf`` as its levels argument
        title : optional axes title
        kwargs : forwarded to ``contourf``
      Returns the axes that were drawn on.
   """

   if ax is None :
      _ , ax = plt.subplots()

   X , Y = np.meshgrid( x , y )

   if not arrayInput :
      Z = func( X.flatten() , Y.flatten() )
   else :
      # func receives a (2, N) array of the flattened x / y coordinates.
      Z = func( np.stack( [ X.flatten() , Y.flatten() ]) )
   # Both call styles yield one value per grid point; restore the grid shape,
   # which contourf requires (the array-input branch previously skipped this).
   Z = np.reshape( Z , X.shape )

   ax.contourf( X , Y , Z , n , **kwargs)

   if title is not None : ax.set_title(title)

   return ax
项目:prysm    作者:brandondube    | 项目源码 | 文件源码
def share_fig_ax(fig=None, ax=None, numax=1, sharex=False, sharey=False):
    ''' Returns the given figure and/or axis if given one.  If both are None, creates a new fig/ax.

    Args:
        fig (`pyplot.figure`): figure.

        ax (`pyplot.axis`): axis or array of axes.

        numax (`int`): number of axes in the desired figure.
                     1 for most plots, 3 for plot_fourier_chain.

        sharex (`bool`): passed to ``plt.subplots`` when a new figure is made.

        sharey (`bool`): passed to ``plt.subplots`` when a new figure is made.

    Returns:
        pyplot.figure:  A figure object.

        pyplot.axis:  An axis object.

    '''
    if fig is not None and ax is not None:
        return fig, ax
    if ax is not None:
        # Only an axis was given: recover its parent figure.
        return ax.get_figure(), ax
    if fig is not None:
        # Only a figure was given: use its current axis.
        return fig, fig.gca()
    new_fig, new_ax = plt.subplots(nrows=1, ncols=numax, sharex=sharex, sharey=sharey)
    return new_fig, new_ax
项目:autonomio    作者:autonomio    | 项目源码 | 文件源码
def accuracy(data):
    """Save and show side-by-side train/test accuracy and loss curves.

    ``data`` must provide 'train_acc', 'test_acc', 'train_loss' and
    'test_loss' sequences.  Both panels share a y axis limited to [0, 1].
    The figure is saved to ``train.png`` at 300 dpi.
    """
    fig, (acc_ax, loss_ax) = plt.subplots(1, 2, sharey=True)

    panels = ((acc_ax, 'train_acc', 'test_acc', 'accuracy'),
              (loss_ax, 'train_loss', 'test_loss', 'loss'))
    for axis, train_key, test_key, title in panels:
        axis.plot(data[train_key])
        axis.plot(data[test_key])
        axis.set_title(title)
        axis.set_xlabel('epoch')

    # sharey=True, so this limit applies to both panels.
    loss_ax.set_ylim((0, 1))

    fig.set_size_inches(20, 5)
    fig.savefig('train.png', dpi=300, bbox_inches='tight')
    fig.show()
项目:DeblurGAN    作者:KupynOrest    | 项目源码 | 文件源码
def __plot_canvas(self, show, save):
        """Draw every PSF side by side, then optionally save and/or show them.

        Raises:
            Exception: if ``self.PSFs`` is empty (``fit()`` not run yet), or
                if ``save`` is requested while ``self.path_to_save`` is None.
        """
        if len(self.PSFs) == 0:
            raise Exception("Please run fit() method first.")
        plt.close()
        # squeeze=False keeps ``axes`` 2-D, so indexing also works when
        # PSFnumber == 1 (a bare Axes object is not subscriptable).
        fig, axes = plt.subplots(1, self.PSFnumber, figsize=(10, 10), squeeze=False)
        for i in range(self.PSFnumber):
            axes[0, i].imshow(self.PSFs[i], cmap='gray')
        if save:
            if self.path_to_save is None:
                raise Exception('Please create Trajectory instance with path_to_save')
            plt.savefig(self.path_to_save)
        if show:
            plt.show()
项目:DeblurGAN    作者:KupynOrest    | 项目源码 | 文件源码
def __plot_canvas(self, show, save):
        """Display the blurred result image(s), then optionally save and/or show.

        Every image in ``self.result`` gets its own axis-less panel.  Only
        ``self.result[0]`` (scaled by 255) is written, via ``cv2.imwrite``, to
        ``<path_to_save>/<last '/'-component of image_path>``.

        Raises:
            Exception: if ``self.result`` is empty (``blur_image()`` not run
                yet), or if ``save`` is requested while ``self.path_to_save``
                is None.
        """
        if len(self.result) == 0:
            raise Exception('Please run blur_image() method first.')
        plt.close()
        # Axes are switched off per panel: calling plt.axis('off') before
        # plt.subplots would create an extra, orphaned figure.  squeeze=False
        # keeps ``axes`` 2-D for the single-image case.
        fig, axes = plt.subplots(1, len(self.result), figsize=(10, 10), squeeze=False)
        for ax, image in zip(axes[0], self.result):
            ax.axis('off')
            ax.imshow(image)
        if save:
            if self.path_to_save is None:
                raise Exception('Please create Trajectory instance with path_to_save')
            cv2.imwrite(os.path.join(self.path_to_save, self.image_path.split('/')[-1]), self.result[0] * 255)
        if show:
            plt.show()
项目:acdc_segmenter    作者:baumgach    | 项目源码 | 文件源码
def boxplot_metrics(df, eval_dir):
    """
    Create summary boxplots of all geometric measures.

    One row per metric ('dice', 'hd', 'assd'), with 'struc' on the x axis
    and boxes split by 'phase'.

    :param df: DataFrame with columns 'struc', 'phase', 'dice', 'hd', 'assd'.
    :param eval_dir: directory in which 'boxplots.eps' is written.
    :return: 0
    """
    boxplots_file = os.path.join(eval_dir, 'boxplots.eps')

    fig, axes = plt.subplots(3, 1)
    fig.set_figheight(14)
    fig.set_figwidth(7)

    for axis, metric in zip(axes, ('dice', 'hd', 'assd')):
        sns.boxplot(x='struc', y=metric, hue='phase', data=df, palette="PRGn", ax=axis)

    fig.savefig(boxplots_file)
    plt.close(fig)

    return 0
项目:em_examples    作者:geoscixyz    | 项目源码 | 文件源码
def WaveVelandSkindWidget(epsr, sigma):
    """Plot wave velocity and skin depth versus frequency on log-log axes.

    Uses 61 log-spaced frequencies from 1e1 to 1e9 Hz.  ``sigma`` is a
    base-10 exponent: ``10**sigma`` is what gets passed to ``WaveVelSkind``
    (presumably a conductivity — confirm against that helper).
    """
    frequency = np.logspace(1, 9, 61)
    vel, skind = WaveVelSkind(frequency, epsr, 10**sigma)
    _, ax = plt.subplots(1, 2, figsize = (10, 4))

    panels = (
        (ax[0], vel, 'b', (1e6, 1e9), 'Velocity (m/s)'),
        (ax[1], skind, 'r', (1e-1, 1e7), 'Skin Depth (m)'),
    )
    for axis, values, colour, (y_lo, y_hi), ylabel in panels:
        axis.loglog(frequency, values, colour, lw=3)
        axis.set_ylim(y_lo, y_hi)
        axis.set_xlabel('Frequency (Hz)')
        axis.set_ylabel(ylabel)
        axis.grid(True)

    plt.show()
    return
项目:DenseNet    作者:kevinzakka    | 项目源码 | 文件源码
def plot_images(images, cls_true, name):
    """Show 9 images on a 3x3 grid, each labelled with its true class.

    ``name == 'cifar10'`` selects ``cifar10_label_names``; any other value
    selects ``cifar100_label_names``.  Labels read ``"<class name> (<id>)"``.
    """
    assert len(images) == len(cls_true) == 9

    label_names = cifar10_label_names if name == 'cifar10' else cifar100_label_names

    fig, axes = plt.subplots(3, 3)

    for i, ax in enumerate(axes.flat):
        ax.imshow(images[i, :, :, :], interpolation='spline16')
        ax.set_xlabel("{0} ({1})".format(label_names[cls_true[i]], cls_true[i]))
        ax.set_xticks([])
        ax.set_yticks([])

    plt.show()
项目:mitre    作者:gerberlab    | 项目源码 | 文件源码
def minimal_show_rule(dataset,variable_name,window_start,window_end,average=None,slope=None):
    """
    Show how a PrimitiveRule applies to a dataset.

    Like show_rule, but with less labeling.  The time window is shaded, the
    data is drawn by ``_alternate_show``, and the threshold is overlaid in
    red.  The threshold is a flat line at ``average`` when that is given.
    Otherwise it is a line of gradient ``slope`` through the window centre
    at the vertical middle of the axes, so ``slope`` is required when
    ``average`` is None.

    Returns ``(figure, axes)``.
    """
    f, ax = plt.subplots()
    ax.axvspan(xmin=window_start, xmax=window_end, color='k', alpha=0.3)
    _alternate_show(ax, dataset, variable_name)
    ymin, ymax = ax.get_ylim()
    threshold_x = np.linspace(window_start, window_end)
    if average is not None:
        threshold_y = average * np.ones(len(threshold_x))
    else:
        y_center = 0.5 * (ymin + ymax)
        window_center = 0.5 * (window_start + window_end)
        threshold_y = y_center + slope * (threshold_x - window_center)
    ax.plot(threshold_x, threshold_y, 'r')
    return (f, ax)
项目:Supply-demand-forecasting    作者:LevinJ    | 项目源码 | 文件源码
def disp_district_by_district_type(self):
        """Bar-plot district-type values for up to 64 start districts.

        Each row of ``get_district_type_table()`` gets one panel on an 8x8
        grid of shared axes, filled row-major; rows beyond the 64th are not
        drawn.
        """
        df = self.get_district_type_table()
        dt_list = self.get_district_type_list()
        col_len = 8
        row_len = 8

        _, axarr = plt.subplots(row_len, col_len, sharex=True, sharey=True)
        x_locations = np.arange(len(dt_list))
        n_panels = min(df.shape[0], row_len * col_len)
        # axarr.flat is row-major, i.e. panel k is (k // col_len, k % col_len).
        for index, axis in zip(range(n_panels), axarr.flat):
            item = df.iloc[index]
            axis.bar(x_locations, item[dt_list])
            axis.set_xlabel('start_district_' + str(item['start_district_id']))
        return
项目:Supply-demand-forecasting    作者:LevinJ    | 项目源码 | 文件源码
def disp_gap_bydate(self):
        """Show a bar chart of the mean ``gap`` for each ``time_date``."""
        mean_gap_per_date = self.gapdf.groupby('time_date')['gap'].mean()
        mean_gap_per_date.plot(kind='bar')
        plt.ylabel('Mean of gap')
        plt.title('Date/Gap Correlation')
        plt.show()
        return

#     def drawGapDistribution(self):
#         self.gapdf[self.gapdf['gapdf'] < 10]['gapdf'].hist(bins=50)
# #         sns.distplot(self.gapdf['gapdf']);
# #         sns.distplot(self.gapdf['gapdf'], hist=True, kde=False, rug=False)
# #         plt.hist(self.gapdf['gapdf'])
#         plt.show()
#         return
#     def drawGapCorrelation(self):
#         _, (ax1, ax2) = plt.subplots(nrows=2, ncols=1)
#         res = self.gapdf.groupby('start_district_id')['gapdf'].sum()
#         ax1.bar(res.index, res.values)
#         res = self.gapdf.groupby('time_slotid')['gapdf'].sum()
#         ax2.bar(res.index.map(lambda x: x[11:]), res.values)
#         plt.show()
#         return
项目:discretize    作者:simpeg    | 项目源码 | 文件源码
def run(plotIt=True):
    """Build tensor, tree and curvilinear 16x16 meshes and plot their grids.

    When ``plotIt`` is False the function returns right after building the
    meshes; otherwise the three grids are drawn side by side.
    """
    shape = [16, 16]
    tensor_mesh = discretize.TensorMesh(shape)
    tree_mesh = discretize.TreeMesh(shape)

    def refine(cell):
        # 4 for cells centred within 0.4 of (0.5, 0.5), 3 elsewhere.
        distance = np.sqrt(((np.r_[cell.center] - 0.5)**2).sum())
        return 4 if distance < 0.4 else 3

    tree_mesh.refine(refine)
    curvi_mesh = discretize.CurvilinearMesh(discretize.utils.exampleLrmGrid(shape, 'rotate'))

    if not plotIt:
        return
    fig, axes = plt.subplots(1, 3, figsize=(14, 5))
    meshes = ((tensor_mesh, 'TensorMesh'),
              (tree_mesh, 'TreeMesh'),
              (curvi_mesh, 'CurvilinearMesh'))
    for axis, (mesh, title) in zip(axes, meshes):
        mesh.plotGrid(ax=axis)
        axis.set_title(title)
项目:SelfDrivingCar    作者:aguijarro    | 项目源码 | 文件源码
def draw_images(img, undistorted, title, cmap):
    """Show ``img`` and ``undistorted`` side by side with large titles.

    ``cmap`` is forwarded to the second ``imshow``; None keeps matplotlib's
    default colormap.
    """
    f, (left, right) = plt.subplots(1, 2, figsize=(24, 9))
    f.tight_layout()
    left.imshow(img)
    left.set_title('Original Image', fontsize=50)
    right.imshow(undistorted, cmap=cmap)
    right.set_title(title, fontsize=50)
    plt.subplots_adjust(left=0., right=1, top=0.9, bottom=0.)
    plt.show()


# TODO: Write a function that takes an image, object points, and image points
# performs the camera calibration, image distortion correction and
# returns the undistorted image
项目:hsmm4acc    作者:wadpac    | 项目源码 | 文件源码
def plot_boxplots(data, hidden_states):
    """
    Plot boxplots for all variables in the dataset, per state

    One subplot row per column of ``data``; each row holds one box per
    distinct hidden state that occurs.

    Parameters
    ------
    data : pandas DataFrame
        Data to plot
    hidden_states: iterable
        the hidden states corresponding to the timesteps (compared with
        ``==``, so presumably a numpy array aligned with ``data``'s rows)
    """
    column_names = data.columns
    # squeeze=False keeps ``axes`` 2-D, so a single column still works
    # (a bare Axes object is not subscriptable).
    figs, axes = plt.subplots(len(column_names), figsize=(15, 15), squeeze=False)
    states = set(hidden_states)
    for ax, var in zip(axes[:, 0], column_names):
        ax.set_title(var)
        vals = data[var]
        data_to_plot = []
        labels = []
        for i in states:
            mask = hidden_states == i
            if (sum(mask) > 0):
                labels.append(str(i))
                values = np.array(vals[mask])
                data_to_plot.append(values)
        ax.boxplot(data_to_plot, sym='', labels=labels)
项目:hsmm4acc    作者:wadpac    | 项目源码 | 文件源码
def plot_perstate(data, hidden_states):
    '''
    Make, for each state, a plot of the data

    One subplot row per state ``0 .. max(hidden_states)``.  In row ``i``,
    every row of ``data`` whose state is not ``i`` is set to 0 before
    plotting.

    Parameters
    ----------
    data : pandas DataFrame
        Data to plot
    hidden_states: iterable
        the hidden states corresponding to the timesteps (compared
        element-wise with ``!=``, so presumably a numpy array aligned with
        ``data``'s rows)
    '''
    num_states = max(hidden_states) + 1
    # squeeze=False keeps ``axs`` 2-D, so a single state still yields an
    # iterable column of axes.
    fig, axs = plt.subplots(
        num_states, sharex=True, sharey=True, figsize=(15, 15), squeeze=False)
    for i, ax in enumerate(axs[:, 0]):
        # Zero out every timestep that does not belong to state i.
        data_to_plot = data.copy()
        data_to_plot[hidden_states != i] = 0
        data_to_plot.plot(ax=ax, legend=False)
        ax.set_title("{0}th hidden state".format(i))
        ax.grid(True)
    plt.legend(bbox_to_anchor=(0, -1, 1, 1), loc='lower center')
    plt.show()
项目:trend_ml_toolkit_xgboost    作者:raymon-tian    | 项目源码 | 文件源码
def fea_plot(xg_model, feature, label, type = 'weight', max_num_features = None):
    """Plot XGBoost feature importances next to per-class feature means.

    Left panel: ``xgb.plot_importance`` for ``xg_model``, using the
    importance type given by ``type``.

    Right panel: one pair of bars per feature selected from the ranked
    scores:

    * the mean over samples with ``label == 1`` ('Malware'), drawn upward;
    * the mean over all other samples ('Normal'), drawn downward.

    Each pair is divided by the larger of its two means.

    Parameters
    ----------
    xg_model : model with ``get_score(importance_type=...)``
        Presumably an ``xgboost.Booster``; it is also passed to
        ``xgb.plot_importance``.
    feature : numpy.ndarray
        2-D sample-by-feature matrix (indexed as ``feature[:, idx]``).
    label : array-like
        Per-sample labels; 1 marks the 'Malware' class.
    type : str
        Importance type used for both panels. It shadows the builtin name,
        but is kept because callers pass it by keyword.
    max_num_features : int or None
        Maximum number of features to show.

    Returns
    -------
    matplotlib.figure.Figure
    """
    fig, AX = plt.subplots(nrows=1, ncols=2)
    # Use the requested importance type. Hard-coding 'weight' here made the
    # left panel disagree with its own x-label and with the right panel.
    xgb.plot_importance(xg_model, xlabel=type, importance_type=type,
                        ax=AX[0], max_num_features=max_num_features)

    fscore = xg_model.get_score(importance_type=type)
    fscore = sorted(fscore.items(), key=itemgetter(1), reverse=True)  # sort scores
    # As an array so ``fea_index + 1`` below works even if a list is returned.
    fea_index = np.asarray(get_fea_index(fscore, max_num_features))
    feature = feature[:, fea_index]
    X = np.arange(1, len(fea_index) + 1)
    is_positive = np.asarray(label) == 1
    Yp = np.mean(feature[is_positive], axis=0)
    Yn = np.mean(feature[~is_positive], axis=0)

    # Scale each feature by the larger of its two class means. Where that
    # scale is 0, leave the bars at 0 instead of producing 0/0 = NaN.
    scale = np.fmax(Yp, Yn)
    nonzero = scale != 0
    Yp = np.divide(Yp, scale, out=np.zeros_like(Yp), where=nonzero)
    Yn = np.divide(Yn, scale, out=np.zeros_like(Yn), where=nonzero)

    p1 = AX[1].bar(X, +Yp, facecolor='#ff9999', edgecolor='white')
    p2 = AX[1].bar(X, -Yn, facecolor='#9999ff', edgecolor='white')
    AX[1].legend((p1, p2), ('Malware', 'Normal'))
    AX[1].set_title('Comparison of selected features by their means')
    AX[1].set_xlabel('Feature Index')
    AX[1].set_ylabel('Mean Value')
    AX[1].set_ylim(-1.1, 1.1)
    # Set ticks on the comparison axes explicitly rather than relying on
    # whichever axes pyplot considers "current".
    AX[1].set_xticks(X)
    AX[1].set_xticklabels(fea_index + 1, rotation=80)
    fig.suptitle('Feature Selection results')
    return fig
项目:psyplot    作者:Chilipp    | 项目源码 | 文件源码
def test_export_03_append(self):
        """Append to a pdf file

        The project has three arrays spread over two figures. The test:

        1. Exports it to a new pdf with ``close_pdf=False``; expects 2 pages.
        2. Exports again into the returned, still-open pdf object; expects
           the page count to reach 4.

        The pdf is always closed afterwards.
        """
        import tempfile
        self._register_export_plotter()
        fig1, ax1 = plt.subplots(1, 2)
        fig2, ax2 = plt.subplots()
        axes = list(ax1) + [ax2]
        sp = psy.plot.test_plotter(bt.get_file('test-t2m-u-v.nc'),
                                   name='t2m', time=[1, 2, 3], z=0, y=0,
                                   ax=axes)
        self.assertEqual(len(sp), 3, msg=sp)

        # Only the unique file name is used. The temporary file object itself
        # is discarded (and thereby deleted) right away.
        fname = tempfile.NamedTemporaryFile(
            suffix='.pdf', prefix='psyplot_').name
        self._created_files.add(fname)

        pdf = sp.export(fname, close_pdf=False)
        # Close the pdf even if an assertion fails, so the open file handle
        # is not leaked into later tests.
        try:
            self.assertEqual(pdf.get_pagecount(), 2)

            sp.export(pdf)

            self.assertEqual(pdf.get_pagecount(), 4)
        finally:
            pdf.close()
项目:psyplot    作者:Chilipp    | 项目源码 | 文件源码
def test_filter_7_fig(self):
        """Test filtering the ArrayList by the figure an array is plotted in"""
        import matplotlib.pyplot as plt
        from psyplot.plotter import Plotter
        arr_list = self.list_class.from_dataset(
            self._filter_test_ds, ydim=[0, 1], name='v0')
        # Each array gets its own figure with a single axes.
        first_fig, first_ax = plt.subplots()
        second_fig, second_ax = plt.subplots()
        figs = (first_fig, second_fig)
        axes = (first_ax, second_ax)
        for idx, arr in enumerate(arr_list):
            Plotter(arr, ax=axes[idx])
        # Filtering by a figure must return exactly the array drawn in it.
        for idx, fig in enumerate(figs):
            self.assertEqual(
                [arr.psy.arr_name for arr in arr_list(fig=fig)],
                [arr_list[idx].psy.arr_name])
项目:planetplanet    作者:rodluger    | 项目源码 | 文件源码
def plot(self, ax=None, show=True):
        '''
        Plots the filter throughput curve.

        Draws ``self.throughput`` against ``self.wl``.

        :param ax: An axis instance to draw into. If :py:obj:`None`, a new \
            10x6 inch figure is created and its axes are labelled \
            "Wavelength" and "Throughput".
        :type ax: :py:obj:`axis`
        :param show: Whether to call ``matplotlib.pyplot.show()`` after \
            drawing. Defaults to :py:obj:`True` (the previous behaviour); \
            pass :py:obj:`False` to keep customizing the plot, e.g. when \
            drawing into a caller-supplied axis.
        :type show: bool

        :returns: The axis the curve was drawn on.

        '''
        import matplotlib.pyplot as plt
        if ax is None:
            _, axi = plt.subplots(figsize=(10, 6))
            axi.set_xlabel("Wavelength")
            axi.set_ylabel("Throughput")
        else:
            axi = ax

        axi.plot(self.wl, self.throughput)

        if show:
            plt.show()
        return axi
项目:ward-metrics    作者:phev8    | 项目源码 | 文件源码
def plot_twoset_metrics(results, startangle=120):
    """Draw two side-by-side pie charts from a two-set metrics mapping.

    The left pie shows "tpr", "us", "ue", "fr" and "dr". The right pie shows
    "1-fpr" (computed as ``1 - results["fpr"]``), "os", "oe", "mr" and
    "ir". Each wedge is labelled with its key and its share as a
    whole-number percentage. The figure is then shown with ``plt.show()``.

    results: mapping containing every key listed above
    startangle: rotation, in degrees, of the first wedge of both pies
    """
    _, axarr = plt.subplots(1, 2)

    panels = (
        ["tpr", "us", "ue", "fr", "dr"],     # positive rates
        ["1-fpr", "os", "oe", "mr", "ir"],   # negative rates
    )
    for ax, names in zip(axarr, panels):
        values = [1 - results["fpr"] if name == "1-fpr" else results[name]
                  for name in names]
        ax.pie(values, labels=names, autopct='%1.0f%%', startangle=startangle)
        ax.axis('equal')  # equal aspect ratio keeps each pie circular
        # TODO: add title

    plt.show()
项目:ward-metrics    作者:phev8    | 项目源码 | 文件源码
def plot_segment_counts(results):
    """Show a pie chart with one wedge per key of ``results``, then display it.

    Wedges are labelled with their key. They are annotated with the absolute
    count, recovered from the percentage matplotlib passes to ``autopct``,
    rather than with a percentage.
    """
    # TODO: add title
    labels = list(results.keys())
    values = [results[label] for label in labels]
    total = sum(values)

    def absolute_count(pct):
        # matplotlib supplies the wedge's percentage; convert it back to a count.
        return '{:.0f}'.format(pct * total / 100)

    _, ax1 = plt.subplots()
    ax1.pie(values, labels=labels, autopct=absolute_count, startangle=90)
    ax1.axis('equal')  # equal aspect ratio keeps the pie circular
    plt.show()
项目:soccerstan    作者:Torvaney    | 项目源码 | 文件源码
def plot_team_parameter(data, title, alpha=0.05, axes_colour='dimgray'):
    """Plot a per-team parameter as medians with central intervals.

    ``data`` has one column of samples per team (presumably posterior draws;
    confirm against the caller). Teams are drawn one per row, ordered by
    median. Each row shows:

    * the median as a black dot;
    * the range between the ``alpha/2`` and ``1 - alpha/2`` percentiles as
      a horizontal line.

    Returns the created figure.
    """
    fig, ax = plt.subplots(figsize=(8, 6))

    lower_pct = 100 * (alpha / 2)
    upper_pct = 100 * (1 - (alpha / 2))

    # Rows are ordered by each team's median value.
    teams = data.median().sort_values().keys()

    for row, team in enumerate(teams):
        samples = data[team]
        centre = np.median(samples)
        low = np.percentile(samples, lower_pct)
        high = np.percentile(samples, upper_pct)

        ax.scatter(centre, row, alpha=1, color='black', s=25)
        ax.hlines(row, low, high, color='black')

    n_teams = len(teams)
    ax.set_ylim([-1, n_teams])
    ax.set_yticks(list(range(n_teams)))
    ax.set_yticklabels(list(teams))

    fig.suptitle(title, ha='left', x=0.125, fontsize=18, color='k')

    # Recolour the frame and the ticks.
    ax.spines["bottom"].set_color(axes_colour)
    ax.spines["left"].set_color(axes_colour)
    ax.tick_params(colors=axes_colour)

    # Hide the top, right and left spines; only the bottom axis line remains.
    for side in ("top", "right", "left"):
        ax.spines[side].set_visible(False)

    return fig
项目:genomedisco    作者:kundajelab    | 项目源码 | 文件源码
def plot_dds(dd_list,dd_names,out,approximation=10000):
    """Plot contact probability against distance on log-log axes; save as PNG.

    Parameters
    ----------
    dd_list : list of dict
        One mapping per curve, from distance key to contact probability.
        Keys are multiplied by ``approximation`` to obtain distances in bp.
    dd_names : list of str
        Legend label for each entry of ``dd_list``; must match its length.
    out : str
        Output path without extension; ``'.png'`` is appended.
    approximation : int
        Factor converting the distance keys to base pairs.

    Notes
    -----
    This function updates several global ``rcParams`` (figure size and
    font sizes) as a side effect. The first two curves are drawn red and
    blue; any further curves use matplotlib's default colour cycle.
    """
    assert len(dd_list)==len(dd_names)

    rcParams['figure.figsize'] = 7,7
    rcParams['font.size']= 30
    rcParams['xtick.labelsize'] = 20
    rcParams['ytick.labelsize'] = 20
    fig, plots = plt.subplots(nrows=1, ncols=1)
    fig.set_size_inches(7, 7)
    colors=['red','blue']
    for dd_idx, (dd, dd_name) in enumerate(zip(dd_list, dd_names)):
        # Points ordered by increasing distance key.
        points = sorted(dd.items())
        x_plot = [distance * approximation for distance, _ in points]
        dd_plot = [prob for _, prob in points]
        style = {'label': dd_name}
        # Fixed colours for the first curves; the colour cycle beyond that
        # (indexing ``colors`` directly used to raise IndexError).
        if dd_idx < len(colors):
            style['c'] = colors[dd_idx]
        # The shortest-distance point is skipped (presumably distance 0,
        # which a log axis cannot show; confirm against the data source).
        plots.plot(x_plot[1:], dd_plot[1:], **style)
    # Log scale on both axes (base 10 is the default). The old
    # ``basex``/``basey`` keywords were deprecated and later removed from
    # matplotlib, so they are no longer passed.
    plots.set_yscale('log')
    plots.set_xscale('log')
    plots.set_xlabel('distance (bp)')
    plots.set_ylabel('contact probability')
    plots.legend(loc=3,fontsize=20)
    # Leave room at the bottom and left for the large tick and axis labels.
    adj=0.2
    fig.subplots_adjust(bottom=adj, left=adj)

    fig.savefig(out+'.png')
项目:deep-learning    作者:ljanyst    | 项目源码 | 文件源码
def gen_sample_summary(samples):
    """Render up to 15 images in a borderless 3x5 grid and return it as an array.

    Each image is min-max rescaled to the full 0-255 ``uint8`` range before
    drawing. The figure is converted with ``figure_to_numpy`` and then
    closed, so repeated calls do not accumulate open pyplot figures.
    """
    fig, axes = plt.subplots(figsize=(5,3), nrows=3, ncols=5,
                             sharey=True, sharex=True)
    fig.subplots_adjust(wspace=0, hspace=0)

    for ax, img in zip(axes.flatten(), samples):
        ax.axis('off')
        img = np.asarray(img)
        # Rescale in floating point: with integer input (e.g. uint8) the
        # ``* 255`` step would otherwise wrap around.
        if not np.issubdtype(img.dtype, np.floating):
            img = img.astype(np.float64)
        low = img.min()
        span = img.max() - low
        if span > 0:
            img = (img - low) * 255 / span
        else:
            # Constant image: avoid 0/0 = NaN, whose uint8 cast is undefined.
            img = np.zeros_like(img)
        img = img.astype(np.uint8)
        # 'box-forced' is no longer accepted by matplotlib. 'box' is its
        # replacement and works with shared axes.
        ax.set_adjustable('box')
        ax.imshow(img, aspect='equal')

    arr = figure_to_numpy(fig)
    # ``del fig`` alone leaves the figure registered with pyplot; close it.
    plt.close(fig)
    return arr
项目:deep-learning    作者:ljanyst    | 项目源码 | 文件源码
def gen_sample_summary(samples):
    """Render up to 15 samples as 28x28 images in a 3x5 grid; return as array.

    Each sample is reshaped to ``(28, 28)``, so it must hold exactly 784
    values. It is drawn with the ``'Greys_r'`` colormap and its axes are
    hidden. The figure is converted with ``figure_to_numpy`` and then
    closed, so repeated calls do not accumulate open pyplot figures.
    """
    fig, axes = plt.subplots(figsize=(5,3), nrows=3, ncols=5,
                             sharey=True, sharex=True)
    for ax, img in zip(axes.flatten(), samples):
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
        ax.imshow(img.reshape((28,28)), cmap='Greys_r')

    arr = figure_to_numpy(fig)
    # Release the pyplot-managed figure now that it has been rasterized.
    plt.close(fig)
    return arr
项目:guided-filter    作者:lisabug    | 项目源码 | 文件源码
def plot_multiple(imgs, main_title='', titles=''):
    """Show images two per row in one figure, then wait for a key/mouse press.

    Parameters
    ----------
    imgs : sized sequence of numpy.ndarray
        Images to show. Each is cast to ``uint8`` and drawn with the
        ``'gray'`` colormap. Nothing is drawn if it is empty.
    main_title : str
        Title for the whole figure.
    titles : iterable of str
        Per-image titles. Missing entries, including the default ``''``,
        leave the corresponding subplot untitled.
    """
    num_img = len(imgs)
    if num_img == 0:
        return
    # Integer division: ``/`` gives a float on Python 3, which neither
    # plt.subplots nor array indexing accept.
    rows = (num_img + 1) // 2
    # squeeze=False keeps ``axarr`` 2-D even when there is a single row.
    f, axarr = plt.subplots(rows, 2, squeeze=False)
    # Title the figure that actually holds the images. A preceding
    # ``plt.figure()`` would put it on a separate, empty figure.
    f.suptitle(main_title)
    # Pad titles so zip() does not drop images; the default '' dropped all.
    titles = list(titles)
    titles += [''] * (num_img - len(titles))
    for i, (img, title) in enumerate(zip(imgs, titles)):
        ax = axarr[i // 2, i % 2]
        ax.imshow(img.astype(np.uint8), cmap='gray')
        ax.set_title(title)
    if num_img % 2:
        # Odd count: hide the unused right-hand subplot in the last row.
        axarr[-1, 1].axis('off')
    plt.waitforbuttonpress()
项目:HandDetection    作者:YunqiuXu    | 项目源码 | 文件源码
def vis_detections(im, class_name, dets, thresh=0.5, ax=None):
    """Draw detected bounding boxes.

    Detections whose score (last column of ``dets``) is at least ``thresh``
    are drawn as red rectangles with a "<class> <score>" caption. Nothing
    is drawn when no detection passes the threshold.

    Parameters
    ----------
    im : numpy.ndarray
        Image of shape (H, W, 3). Its channel order is reversed before
        display (presumably BGR -> RGB; confirm against the loader).
    class_name : str
        Label shown in the captions and the title.
    dets : numpy.ndarray
        One row per detection, ``[x1, y1, x2, y2, ..., score]``.
    thresh : float
        Minimum score for a detection to be drawn.
    ax : matplotlib Axes, optional
        Axes to draw into. Defaults to the current axes, so repeated calls
        (e.g. one per class) draw onto the same image.
    """
    inds = np.where(dets[:, -1] >= thresh)[0]
    if len(inds) == 0:
        return

    # The figure-creating line was commented out, which left ``ax``
    # undefined here; fall back to the current axes instead.
    if ax is None:
        ax = plt.gca()

    im = im[:, :, (2, 1, 0)]
    ax.imshow(im, aspect='equal')
    for i in inds:
        bbox = dets[i, :4]
        score = dets[i, -1]

        ax.add_patch(
            plt.Rectangle((bbox[0], bbox[1]),
                          bbox[2] - bbox[0],
                          bbox[3] - bbox[1], fill=False,
                          edgecolor='red', linewidth=3.5)
            )
        ax.text(bbox[0], bbox[1] - 2,
                '{:s} {:.3f}'.format(class_name, score),
                bbox=dict(facecolor='blue', alpha=0.5),
                fontsize=14, color='white')

    ax.set_title(('{} detections with '
                  'p({} | box) >= {:.1f}').format(class_name, class_name,
                                                  thresh),
                  fontsize=14)
    ax.axis('off')
    ax.figure.tight_layout()
    plt.draw()