Python matplotlib.gridspec 模块,GridSpec() 实例源码

我们从Python开源项目中,提取了以下50个代码示例,用于说明如何使用matplotlib.gridspec.GridSpec()

项目:structured-output-ae    作者:sbelharbi    | 项目源码 | 文件源码
def plot_x_y_yhat(x, y, y_hat, xsz, ysz, binz=False):
    """Plot x, y and y_hat side by side in a 1x3 grid.

    Args:
        x: image shown in the first panel.
        y: image shown in the second panel.
        y_hat: image shown in the third panel.
        xsz: size used in the first title ("x:<xsz>x<xsz>").
        ysz: size used in the second and third titles.
        binz: if True, y_hat is thresholded at 0.5 to 0./1. before plotting.

    Returns:
        The created matplotlib figure.
    """
    # Close every open figure before creating the new one.
    plt.close("all")
    f = plt.figure(figsize=(15, 10.8), dpi=300)
    gs = gridspec.GridSpec(1, 3)
    if binz:
        y_hat = (y_hat > 0.5) * 1.
    ims = [x, y, y_hat]
    tils = [
        "x:" + str(xsz) + "x" + str(xsz),
        "y:" + str(ysz) + "x" + str(ysz),
        "yhat:" + str(ysz) + "x" + str(ysz)]
    # All panels are drawn the same way (grey-scale image plus title).
    for n, (im, ti) in enumerate(zip(ims, tils)):
        f.add_subplot(gs[n])
        plt.imshow(im, cmap=cm.Greys_r)
        plt.title(ti)

    return f
项目:klineyes    作者:tenstone    | 项目源码 | 文件源码
def mfi(df):
    """Plot the close price above the MFI indicator and show the figure.

    Expects ``df`` to have ``date``, ``close`` and ``mfi`` columns.  Note that
    ``df['date']`` is converted to datetimes in place, so the caller's frame
    is modified.
    """
    df['date'] = pd.to_datetime(df.date)

    fig = plt.figure(figsize=(16, 9))
    gs = GridSpec(3, 1)  # 3 rows, 1 column: price uses rows 0-1, MFI row 2
    fig.suptitle(df['date'][-1:].values[0])  # title: last date in the frame
    fig.set_label('MFI')
    price = fig.add_subplot(gs[:2, 0])
    price.plot(df['date'], df['close'], color='blue')

    n_points = len(df['date'])
    indicator = fig.add_subplot(gs[2, 0], sharex=price)
    indicator.plot(df['date'], df['mfi'], c='pink')
    # Constant reference lines at 20 and 80 (presumably the usual
    # oversold/overbought MFI levels).
    indicator.plot(df['date'], [20.] * n_points, c='green')
    indicator.plot(df['date'], [80.] * n_points, c='orange')

    price.grid(True)
    indicator.grid(True)
    plt.tight_layout()
    plt.show()
项目:klineyes    作者:tenstone    | 项目源码 | 文件源码
def atr(df):
    '''
    Plot the close price above the ATR (Average True Range) and show it.

    :param df: frame with ``date``, ``close`` and ``atr`` columns;
        ``df['date']`` is converted to datetimes in place.
    :return: None
    '''
    df['date'] = pd.to_datetime(df.date)

    fig = plt.figure(figsize=(16, 9))
    gs = GridSpec(3, 1)  # 3 rows, 1 column: price uses rows 0-1, ATR row 2
    fig.suptitle(df['date'][-1:].values[0])  # title: last date in the frame
    fig.set_label('ATR')
    price = fig.add_subplot(gs[:2, 0])
    price.plot(df['date'], df['close'], color='blue')

    indicator = fig.add_subplot(gs[2, 0], sharex=price)
    indicator.plot(df['date'], df['atr'], c='pink')

    price.grid(True)
    indicator.grid(True)
    plt.tight_layout()
    plt.show()
项目:klineyes    作者:tenstone    | 项目源码 | 文件源码
def rocr(df):
    '''
    Plot the close price above the ROCR indicator and show it.

    ROCR is presumably the rate-of-change ratio; the values are read as-is
    from the ``rocr`` column.

    :param df: frame with ``date``, ``close`` and ``rocr`` columns;
        ``df['date']`` is converted to datetimes in place.
    :return: None
    '''
    df['date'] = pd.to_datetime(df.date)

    fig = plt.figure(figsize=(16, 9))
    gs = GridSpec(3, 1)  # 3 rows, 1 column: price uses rows 0-1, ROCR row 2
    fig.suptitle(df['date'][-1:].values[0])  # title: last date in the frame
    fig.set_label('ROCR')
    price = fig.add_subplot(gs[:2, 0])
    price.plot(df['date'], df['close'], color='blue')

    indicator = fig.add_subplot(gs[2, 0], sharex=price)
    indicator.plot(df['date'], df['rocr'], c='pink')

    price.grid(True)
    indicator.grid(True)
    plt.tight_layout()
    plt.show()
项目:structured-output-ae    作者:sbelharbi    | 项目源码 | 文件源码
def plot_x_x_yhat(x, x_hat):
    """Plot an input image ``x`` and its output ``x_hat`` side by side.

    Each panel is titled with the shape of the image it shows
    ("xin:<rows>x<cols>" and "xout:<rows>x<cols>") and drawn without axes.

    Returns:
        The created matplotlib figure.
    """
    plt.close("all")
    f = plt.figure()
    gs = gridspec.GridSpec(1, 2)
    ims = [x, x_hat]
    tils = [
        "xin:" + str(x.shape[0]) + "x" + str(x.shape[1]),
        # Both dimensions of the output title come from x_hat itself.
        "xout:" + str(x_hat.shape[0]) + "x" + str(x_hat.shape[1])]
    for n, (im, ti) in enumerate(zip(ims, tils)):
        ax = f.add_subplot(gs[n])
        plt.imshow(im, cmap=cm.Greys_r)
        plt.title(ti)
        ax.set_axis_off()

    return f
项目:matplotlib-hep    作者:ibab    | 项目源码 | 文件源码
def make_split(ratio, gap=0.12):
    """Split the area of the current axes into two vertically stacked axes.

    The new axes are laid out inside the bounding box of ``plt.gca()``; the
    current axes itself is not removed.

    Args:
        ratio: fraction of the height given to the upper axes (the lower
            axes gets ``1 - ratio``).
        gap: vertical spacing between the two axes (GridSpec ``hspace``).

    Returns:
        (ax, bx): upper and lower axes.  ``bx`` shares its x axis with
        ``ax``, and the x tick labels of ``ax`` are hidden.
    """
    import matplotlib.pyplot as plt
    from matplotlib.gridspec import GridSpec

    box = plt.gca().get_position()
    gs = GridSpec(2, 1, height_ratios=[ratio, 1 - ratio],
                  left=box.xmin, right=box.xmax,
                  bottom=box.ymin, top=box.ymax,
                  hspace=gap)

    ax = plt.subplot(gs[0])
    plt.setp(ax.get_xticklabels(), visible=False)
    bx = plt.subplot(gs[1], sharex=ax)

    return ax, bx
项目:segyviewer    作者:Statoil    | 项目源码 | 文件源码
def set_plot_layout(self, layout_spec):
        """Rebuild the data axes and the colormap axes from ``layout_spec``.

        ``layout_spec['dims']`` is ``(rows, columns)`` of the data grid and
        ``layout_spec['grid']`` lists, for every data axes, the grid-spec
        index/slice it occupies.  One extra column with relative width 0.025
        is appended on the right and holds the colormap axes, spanning every
        row.  Previously created axes are removed first, and ``layout_spec``
        is stored as the current layout.
        """
        n_rows, n_cols = layout_spec['dims']
        cbar_width = 0.025
        col_width = (1.0 - cbar_width) / float(n_cols)
        width_ratios = [col_width for _ in range(n_cols)] + [cbar_width]

        spec = gridspec.GridSpec(n_rows, n_cols + 1, width_ratios=width_ratios)

        for old_axes in self._axes:
            self.delaxes(old_axes)
        new_axes = []
        for cell in layout_spec['grid']:
            new_axes.append(self.add_subplot(spec[cell]))
        self._axes = new_axes

        if self._colormap_axes is not None:
            self.delaxes(self._colormap_axes)
        self._colormap_axes = self.add_subplot(spec[:, n_cols])

        self._current_layout = layout_spec
项目:tomato    作者:sertansenturk    | 项目源码 | 文件源码
def _create_figure():
        """Create the figure and axes for the combined melody plot.

        Layout: a 2x2 GridSpec with width ratios 6:1 and height ratios 4:1,
        with no spacing between cells.

        Returns:
            (fig, ax1, ax2, ax3, ax4, ax5) where
            - ax1 (top-left): predominant melody and performed notes
            - ax2 (top-right): pitch distribution and note models; shares the
              y axis with ax1
            - ax3: ``twiny`` twin of ax4 for the melodic progression; its x
              axis is shared with ax1
            - ax4 (bottom-left): sections
            - ax5 (bottom-right): makam, tempo, tonic, ahenk annotations
        """
        fig = plt.figure()
        gs = gridspec.GridSpec(2, 2, width_ratios=[6, 1], height_ratios=[4, 1])
        ax1 = fig.add_subplot(gs[0])  # pitch and notes
        ax2 = fig.add_subplot(gs[1], sharey=ax1)  # pitch dist. and note models
        ax4 = fig.add_subplot(gs[2])  # sections
        ax5 = fig.add_subplot(gs[3])  # makam, tempo, tonic, ahenk annotations
        ax3 = plt.twiny(ax4)  # melodic progression
        # Share the x axis of ax1 and ax3.  Axes.sharex (Matplotlib >= 3.3) is
        # the supported API; joining through get_shared_x_axes() is deprecated
        # since 3.6 and fails on recent releases, so use it only as a fallback.
        if hasattr(ax3, 'sharex'):
            ax3.sharex(ax1)
        else:
            ax1.get_shared_x_axes().join(ax1, ax3)
        fig.subplots_adjust(hspace=0, wspace=0)
        return fig, ax1, ax2, ax3, ax4, ax5
项目:good-semi-bad-gan    作者:christiancosgrove    | 项目源码 | 文件源码
def plot(samples):
    """Draw the samples in a square grid and return the figure.

    The grid side is ``int(sqrt(len(samples)))`` capped at 12, and only the
    first side*side samples are shown.  Each sample is rescaled with
    ``x * 0.5 + 0.5`` (presumably from [-1, 1] to [0, 1]).  Its axes are then
    permuted with ``(1, 2, 0)``, presumably channel-first to channel-last,
    before ``imshow``.
    """
    side = min(12, int(np.sqrt(len(samples))))
    fig = plt.figure(figsize=(side, side))
    grid = gridspec.GridSpec(side, side)
    grid.update(wspace=0.05, hspace=0.05)

    for idx, img in zip(range(side * side), samples):
        ax = plt.subplot(grid[idx])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(np.transpose(img * 0.5 + 0.5, (1, 2, 0)))

    return fig
项目:tefla    作者:litan    | 项目源码 | 文件源码
def subplots(img, row_count, col_count, crop_img):
    """Show a 2x2 diagnostic grid and display it with ``plt.show()``.

    Panels in row-major order: ``img`` (via the module's ``imshow`` helper),
    a line plot of ``row_count``, a line plot of ``col_count``, and
    ``crop_img``.  The left column is 4/3 as wide as the right column.
    """
    grid = gridspec.GridSpec(2, 2, width_ratios=[4, 3])

    plt.subplot(grid[0])
    imshow(plt, img)

    for cell, counts in ((1, row_count), (2, col_count)):
        plt.subplot(grid[cell])
        plt.plot(np.array(counts))

    plt.subplot(grid[3])
    plt.imshow(crop_img)
    plt.show()

#crops an image from both dark and light background
#works best on a single color background
项目:MDT    作者:cbclab    | 项目源码 | 文件源码
def get_gridspec(self, figure, nmr_plots):
        """Return a grid layout holding ``nmr_plots`` plots on ``figure``.

        With neither ``self.rows`` nor ``self.cols`` set, the layout is
        delegated to ``AutoGridLayout``.  Otherwise any missing dimension is
        ``ceil(nmr_plots / other_dimension)``.  If the resulting grid is still
        too small, the column count is recomputed from the row count.
        """
        n_rows, n_cols = self.rows, self.cols

        if n_rows is None and n_cols is None:
            auto = AutoGridLayout(spacings=self.spacings)
            return auto.get_gridspec(figure, nmr_plots)

        if n_rows is None:
            n_rows = int(np.ceil(nmr_plots / n_cols))
        if n_cols is None:
            n_cols = int(np.ceil(nmr_plots / n_rows))

        # Enlarge the column count when rows * cols cannot hold every plot.
        too_small = n_rows * n_cols < nmr_plots
        if too_small:
            n_cols = int(np.ceil(nmr_plots / n_rows))

        grid = GridSpec(n_rows, n_cols, **self.spacings)
        return GridLayoutSpecifier(grid, figure)
项目:3DGAN-Pytorch    作者:rimchang    | 项目源码 | 文件源码
def SavePloat_Voxels(voxels, path, iteration):
    """Save 3-D scatter plots of up to 8 voxel grids, plus a pickle of them.

    The first 8 entries of ``voxels`` are binarized with ``>= 0.5``.  The
    non-zero coordinates of each entry are scattered in red in a 2x4 grid.
    The figure goes to ``<path>/<iteration zero-padded to 3>.png``, and the
    binarized voxels are pickled to ``<path>/<same stem>.pkl``.
    """
    stem = str(iteration).zfill(3)
    occupied = voxels[:8] >= 0.5

    plt.figure(figsize=(32, 16))
    grid = gridspec.GridSpec(2, 4)
    grid.update(wspace=0.05, hspace=0.05)

    for idx, volume in enumerate(occupied):
        xs, ys, zs = volume.nonzero()
        ax = plt.subplot(grid[idx], projection='3d')
        ax.scatter(xs, ys, zs, zdir='z', c='red')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
    plt.savefig(path + '/{}.png'.format(stem), bbox_inches='tight')
    plt.close()

    with open(path + '/{}.pkl'.format(stem), "wb") as fh:
        pickle.dump(occupied, fh, protocol=pickle.HIGHEST_PROTOCOL)
项目:chxanalys    作者:yugangzhang    | 项目源码 | 文件源码
def plot_aged_g2( g2_aged, tau=None,timeperframe=1, ylim=None, xlim=None):
    '''Plot g2 curves calculated from two-time correlation, one panel per age.

    Args:
        g2_aged: dict mapping an age to its g2 curve; panels follow the
            sorted ages.
        tau: optional dict mapping the same ages to their lag values.  When
            None, lags are ``arange(len(g2)) * timeperframe``.
        timeperframe: seconds per frame, used for the default lags and for
            the "age = ... s" label (``age * timeperframe``).
        ylim: optional y-limits applied to every panel.
        xlim: optional x-limits applied to every panel.

    The line style and color of panel ``n`` come from the module-level
    ``markers[n]`` and ``colors[n]``.
    '''
    plt.figure(figsize=(8,10))
    age_center = sorted(g2_aged.keys())
    gs = gridspec.GridSpec(len(age_center), 1)
    for n, i in enumerate(age_center):
        ax = plt.subplot(gs[n])
        if tau is None:
            gx = np.arange(len(g2_aged[i])) * timeperframe
        else:
            gx = tau[i]
        marker = markers[n]
        c = colors[n]
        ax.plot(gx, g2_aged[i], '-%s' % marker, c=c, label=r"$age= %.1f s$" % (i*timeperframe))
        ax.set_xscale('log')
        ax.legend(fontsize='large', loc='best')
        ax.set_xlabel(r"$\tau $ $(s)$", fontsize=18)
        ax.set_ylabel("g2")
        if ylim is not None:
            ax.set_ylim(ylim)
        if xlim is not None:
            # xlim limits the x (lag) axis.
            ax.set_xlim(xlim)

#####################################
#get four-time
项目:smp_base    作者:x75    | 项目源码 | 文件源码
def make_axes_from_grid(fig, gs):
    """Create one subplot per cell of a gridspec, returned as a 2D list.

    Args:
     - fig(matplotlib.figure.Figure): figure on which the axes are created
     - gs(matplotlib.gridspec.GridSpec): grid whose geometry ``(rows, cols)``
       defines the cells

    Returns:
     - list of ``rows`` lists with ``cols`` axes each; ``axes[r][c]`` is the
       subplot created from ``gs[r, c]``.  Axes are created row by row.
    """
    n_rows, n_cols = gs.get_geometry()
    return [[fig.add_subplot(gs[r, c]) for c in range(n_cols)]
            for r in range(n_rows)]
项目:fenapack    作者:blechta    | 项目源码 | 文件源码
def _create_figure():
        """Create a figure with two stacked log-x axes.

        Layout: a 3-row GridSpec with height ratios 2:2:1 and ``hspace`` 0.05.
        ``ax1`` uses row 0 and ``ax2`` row 1; row 2 is left empty.  ``ax1``
        shares its x axis with ``ax2`` and shows its x label and x tick labels
        at the top.  ``ax2`` hides its x tick labels.

        Returns:
            (fig, (ax1, ax2)): ax1 for GMRES iteration counts and ax2 for CPU
            time (log y), both against the number of dofs (log x).
        """
        fig = pyplot.figure()
        gs = gridspec.GridSpec(3, 1, height_ratios=[2, 2, 1], hspace=0.05)
        ax2 = fig.add_subplot(gs[1])
        ax1 = fig.add_subplot(gs[0], sharex=ax2)
        ax1.xaxis.set_label_position('top')
        # Real booleans are required: Matplotlib no longer converts the
        # strings 'on'/'off', and any non-empty string (even 'off') is truthy.
        ax1.xaxis.set_tick_params(labeltop=True, labelbottom=False)
        pyplot.setp(ax2.get_xticklabels(), visible=False)
        ax1.set_xscale('log')
        ax2.set_xscale('log')
        ax2.set_yscale('log')
        ax1.set_xlabel('Number dofs')
        ax1.set_ylabel('Number GMRES iterations')
        ax2.set_ylabel('CPU time')
        ax1.set_ylim(0, None, auto=True)
        # NOTE(review): a lower limit of 0 is not valid on ax2's log-scaled y
        # axis (Matplotlib ignores it with a warning) — confirm intent.
        ax2.set_ylim(0, None, auto=True)
        return fig, (ax1, ax2)
项目:CNN_UCMerced-LandUse_Caffe    作者:yangxue0827    | 项目源码 | 文件源码
def show_labes(image, probs, lables, true_label):
    """Show a horizontal bar chart of ``probs`` next to ``image``.

    ``lables`` and ``probs`` are reversed, so their first entry ends up as the
    top bar.  Each bar is annotated with ``'%5.2f%%' % value``, so ``probs``
    are presumably percentages.  The chart uses column 1 and the image
    (titled ``true_label``) column 2 of a 1x3 grid; column 0 stays empty.
    The figure is displayed with ``plt.show()``.
    """
    gs = gridspec.GridSpec(1, 3)
    ax1 = plt.subplot(gs[1])
    x = list(reversed(lables))
    y = list(reversed(probs))
    colors = ['#edf8fb', '#ccece6', '#99d8c9', '#66c2a4', '#41ae76']
    width = 0.4  # the height of each horizontal bar
    ind = np.arange(len(y))  # the y locations for the bars
    ax1.barh(ind, y, width, align='center', color=colors)
    # With align='center' each bar is centered on ind, so the labels go there
    # too (an offset of width / 2 would put them on the bar edges).
    ax1.set_yticks(ind)
    ax1.set_yticklabels(x, minor=False)
    for i, v in enumerate(y):
        ax1.text(v, i, '%5.2f%%' % v, fontsize=14)
    ax1.set_title('Probability Output', fontsize=20)
    ax2 = plt.subplot(gs[2])
    ax2.axis('off')
    ax2.imshow(image)
    ax2.set_title(true_label, fontsize=20)
    plt.show()
项目:CNN_UCMerced-LandUse_Caffe    作者:yangxue0827    | 项目源码 | 文件源码
def show_labes(image, probs, lables, true_label):
    """Show a horizontal bar chart of ``probs`` next to ``image`` in a new figure.

    ``lables`` and ``probs`` are reversed, so their first entry ends up as the
    top bar.  Each bar is annotated just right of its end (``v + 1``) with
    ``'%5.2f%%' % v``, so ``probs`` are presumably percentages.  The chart
    uses column 1 and the image (titled ``true_label``) column 2 of a 1x3
    grid; column 0 stays empty.  The figure is displayed with ``plt.show()``.
    """
    plt.figure()
    gs = gridspec.GridSpec(1, 3)
    ax1 = plt.subplot(gs[1])
    x = list(reversed(lables))
    y = list(reversed(probs))
    colors = ['#edf8fb', '#ccece6', '#99d8c9', '#66c2a4', '#41ae76']
    width = 0.4  # the height of each horizontal bar
    ind = np.arange(len(y))  # the y locations for the bars
    ax1.barh(ind, y, width, align='center', color=colors)
    # With align='center' each bar is centered on ind, so the labels go there
    # too (an offset of width / 2 would put them on the bar edges).
    ax1.set_yticks(ind)
    ax1.set_yticklabels(x, minor=False)
    for i, v in enumerate(y):
        ax1.text(v + 1, i, '%5.2f%%' % v, fontsize=14)
    ax1.set_title('Probability Output', fontsize=20)
    ax2 = plt.subplot(gs[2])
    ax2.axis('off')
    ax2.imshow(image)
    ax2.set_title(true_label, fontsize=20)
    plt.show()
项目:mriqc    作者:poldracklab    | 项目源码 | 文件源码
def plot_dist(
        main_file, mask_file, xlabel, distribution=None, xlabel2=None,
        figsize=DINA4_LANDSCAPE):
    """Plot the distribution of the values of ``main_file`` inside ``mask_file``.

    The top axes shows a 100-bin histogram (no KDE) of the masked values.
    When ``distribution`` is given, the bottom axes shows it with
    ``sns.distplot``, plus a vertical line at the median of the masked
    values labelled with that median.

    Args:
        main_file, mask_file: passed to ``_get_values_inside_a_mask``.
        xlabel: x label of the top axes.
        distribution: optional reference values for the bottom axes.  When
            None, the bottom axes is not drawn.
        xlabel2: x label of the bottom axes.
        figsize: figure size passed to ``plt.Figure``.

    Returns:
        The ``Figure``, created with ``plt.Figure`` (not registered with
        pyplot) and attached to a ``FigureCanvas``.
    """
    data = _get_values_inside_a_mask(main_file, mask_file)

    fig = plt.Figure(figsize=figsize)
    FigureCanvas(fig)

    gsp = GridSpec(2, 1)
    ax = fig.add_subplot(gsp[0, 0])
    sns.distplot(data.astype(np.double), kde=False, bins=100, ax=ax)
    ax.set_xlabel(xlabel)

    if distribution is not None:
        ax = fig.add_subplot(gsp[1, 0])
        sns.distplot(np.array(distribution).astype(np.double), ax=ax)
        cur_val = np.median(data)
        # ':g' is a format spec; '!g' is not a valid conversion and raises.
        label = "{0:g}".format(cur_val)
        plot_vline(cur_val, label, ax=ax)
        ax.set_xlabel(xlabel2)

    return fig
项目:Land_Use_CNN    作者:BUPTLdy    | 项目源码 | 文件源码
def show_labes(image,probs,lables,true_label):
    """Show a horizontal bar chart of ``probs`` next to ``image`` (1x2 grid).

    ``lables`` and ``probs`` are reversed, so their first entry ends up as the
    top bar.  Each bar is annotated with ``'%5.2f%%' % value``, so ``probs``
    are presumably percentages.  The image panel is titled ``true_label``.
    The figure is displayed with ``plt.show()``.
    """
    # One row only: height_ratios must have one entry per row, so it is not
    # passed (a two-element list raises ValueError for a 1-row grid).
    gs = gridspec.GridSpec(1, 2, width_ratios=[1, 1])
    ax1 = plt.subplot(gs[0])
    x = list(reversed(lables))
    y = list(reversed(probs))
    colors = ['#edf8fb', '#b2e2e2', '#66c2a4', '#2ca25f', '#006d2c']
    width = 0.4  # the height of each horizontal bar
    ind = np.arange(len(y))  # the y locations for the bars
    ax1.barh(ind, y, width, align='center', color=colors)
    # With align='center' each bar is centered on ind, so the labels go there.
    ax1.set_yticks(ind)
    ax1.set_yticklabels(x, minor=False)
    for i, v in enumerate(y):
        ax1.text(v, i, '%5.2f%%' % v, fontsize=14)
    ax1.set_title('Probability Output', fontsize=20)
    ax2 = plt.subplot(gs[1])
    ax2.axis('off')
    ax2.imshow(image)
    ax2.set_title(true_label, fontsize=20)
    plt.show()
项目:vae_vpflows    作者:jmtomczak    | 项目源码 | 文件源码
def plot_images( args, x_sample, dir, file_name, size_x=3, size_y=3):
    """Save up to ``size_x * size_y`` samples as a grey-scale grid image.

    Args:
        args: object whose ``input_size[1]`` and ``input_size[2]`` give the
            height and width each sample is reshaped to.
        x_sample: iterable of samples.  Only the first ``size_x * size_y``
            are drawn; a GridSpec of that size has no further cells.
        dir: output directory, concatenated directly with ``file_name`` (so
            it should end with a path separator).
        file_name: output file stem; ``.png`` is appended.
        size_x, size_y: grid rows and columns, also the figure size.
    """
    from itertools import islice

    fig = plt.figure(figsize=(size_x, size_y))
    gs = gridspec.GridSpec(size_x, size_y)
    gs.update(wspace=0.05, hspace=0.05)

    for i, sample in enumerate(islice(x_sample, size_x * size_y)):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(sample.reshape(args.input_size[1], args.input_size[2]), cmap='Greys_r')

    plt.savefig(dir + file_name + '.png', bbox_inches='tight')
    plt.close(fig)

#=======================================================================================================================
项目:vae_vpflows    作者:jmtomczak    | 项目源码 | 文件源码
def plot_real( args, x_sample, dir, size_x=3, size_y=3):
    """Save the first ``size_x * size_y`` real samples to ``<dir>real.png``.

    ``x_sample`` is converted with ``.data.cpu().numpy()`` (presumably a
    torch tensor/Variable).  Each sample is reshaped to
    ``(args.input_size[1], args.input_size[2])`` and drawn in grey scale.
    ``dir`` is concatenated directly, so it should end with a separator.
    """
    tiles = x_sample.data.cpu().numpy()[:size_x*size_y]

    fig = plt.figure(figsize=(size_x, size_y))
    grid = gridspec.GridSpec(size_x, size_y)
    grid.update(wspace=0.05, hspace=0.05)

    for pos, tile in enumerate(tiles):
        cell_ax = plt.subplot(grid[pos])
        plt.axis('off')
        cell_ax.set_xticklabels([])
        cell_ax.set_yticklabels([])
        cell_ax.set_aspect('equal')
        image = tile.reshape(args.input_size[1], args.input_size[2])
        plt.imshow(image, cmap='Greys_r')

    plt.savefig(dir + 'real.png', bbox_inches='tight')
    plt.close(fig)

#=======================================================================================================================
项目:vae_vpflows    作者:jmtomczak    | 项目源码 | 文件源码
def plot_reconstruction( args, samples, c, dir , size_x=3, size_y=3):
    """Save the first ``size_x * size_y`` reconstructions as a grid image.

    The image is written to ``<dir>reconstruction/<c zero-padded to 3>.png``.
    The ``reconstruction/`` subdirectory is created when missing.  ``dir`` is
    concatenated directly, so it should end with a separator.
    """
    tiles = samples.data.cpu().numpy()[:size_x * size_y]

    fig = plt.figure(figsize=(size_x, size_y))
    grid = gridspec.GridSpec(size_x, size_y)
    grid.update(wspace=0.05, hspace=0.05)

    for pos, tile in enumerate(tiles):
        cell_ax = plt.subplot(grid[pos])
        plt.axis('off')
        cell_ax.set_xticklabels([])
        cell_ax.set_yticklabels([])
        cell_ax.set_aspect('equal')
        image = tile.reshape(args.input_size[1], args.input_size[2])
        plt.imshow(image, cmap='Greys_r')

    out_dir = dir + 'reconstruction/'
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    plt.savefig(dir + 'reconstruction/{}.png'.format(str(c).zfill(3)), bbox_inches='tight')
    plt.close(fig)

#=======================================================================================================================
项目:vae_vpflows    作者:jmtomczak    | 项目源码 | 文件源码
def plot_generation( args, samples_mean, dir , size_x=3, size_y=3):
    """Save the first ``size_x * size_y`` generated samples to ``<dir>generation.png``.

    ``samples_mean`` is converted with ``.data.cpu().numpy()``.  Each sample
    is reshaped to ``(args.input_size[1], args.input_size[2])`` and drawn in
    grey scale.  ``dir`` is concatenated directly, so it should end with a
    separator.
    """
    tiles = samples_mean.data.cpu().numpy()[:size_x*size_y]

    fig = plt.figure(figsize=(size_x, size_y))
    grid = gridspec.GridSpec(size_x, size_y)
    grid.update(wspace=0.05, hspace=0.05)

    for pos, tile in enumerate(tiles):
        cell_ax = plt.subplot(grid[pos])
        plt.axis('off')
        cell_ax.set_xticklabels([])
        cell_ax.set_yticklabels([])
        cell_ax.set_aspect('equal')
        image = tile.reshape(args.input_size[1], args.input_size[2])
        plt.imshow(image, cmap='Greys_r')

    plt.savefig(dir + 'generation.png', bbox_inches='tight')
    plt.close(fig)

#=======================================================================================================================
项目:tschdata    作者:tum-lkn    | 项目源码 | 文件源码
def plot_int_buf_delay():
    """Show intercepting-path delays in two vertically stacked panels.

    Top panel: ``plot_intercepting_path_delays(ax, shared=False)`` with tick
    labels for data sets I, III and IV.  Bottom panel: ``shared=True`` with
    data sets V, VII and VIII.  Each panel labels x positions 1-6.  Only the
    bottom panel gets the "Data sets" x label; both get the delay y label.
    """
    fig = plt.figure(figsize=(7.5, 5.225))
    rows = gridspec.GridSpec(2, 1, height_ratios=[1, 1])
    tick_positions = list(range(1, 7))

    panels = [
        (False, ['I (l)', 'III (l)', 'IV (l)', 'I (b)', 'III (b)', 'IV (b)']),
        (True, ['V (l)', 'VII (l)', 'VIII (l)', 'V (b)', 'VII(b)', 'VIII (b)']),
    ]
    last = len(panels) - 1
    for idx, (shared, tick_labels) in enumerate(panels):
        panel_ax = fig.add_subplot(rows[idx])
        plot_intercepting_path_delays(panel_ax, shared=shared)
        # pyplot calls act on the axes just added (the current axes).
        plt.xticks(tick_positions, tick_labels)
        if idx == last:
            plt.xlabel('Data sets')
        plt.ylabel('Delay, s')
    plt.show()
项目:AAE_pytorch    作者:fducau    | 项目源码 | 文件源码
def grid_plot2d(Q, P, data_loader, params):
    """Decode a grid of 2-D latent codes with ``P`` and tile the outputs.

    Both latent coordinates take the values ``np.arange(-10, 10, 1.5)``.
    For the grid cell at (row, col), the code ``(z1[row], z2[col])`` is
    decoded and the result is reshaped to a 28x28 image.

    Args:
        Q, P: models switched to eval mode; only ``P`` (presumably the
            decoder) is called.
        data_loader: unused.
        params: dict; ``params['cuda']`` moves the latent values to the GPU.
    """
    Q.eval()
    P.eval()

    cuda = params['cuda']

    z1 = Variable(torch.from_numpy(np.arange(-10, 10, 1.5).astype('float32')))
    z2 = Variable(torch.from_numpy(np.arange(-10, 10, 1.5).astype('float32')))
    if cuda:
        z1, z2 = z1.cuda(), z2.cuda()

    nx, ny = len(z1), len(z2)
    plt.subplot()
    gs = gridspec.GridSpec(nx, ny, hspace=0.05, wspace=0.05)

    for i, g in enumerate(gs):
        # GridSpec cells are enumerated row-major over nx rows x ny columns,
        # so cell i is (i // ny, i % ny).  Integer division is required:
        # i / ny is a float in Python 3 and cannot index a tensor.
        row, col = divmod(i, ny)
        # stack (unlike cat) also accepts the 0-dim tensors that integer
        # indexing returns in recent PyTorch versions.
        z = torch.stack((z1[row], z2[col])).view(1, 2)
        x = P(z)

        ax = plt.subplot(g)
        img = np.array(x.data.tolist()).reshape(28, 28)
        ax.imshow(img)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_aspect('auto')
项目:generative-models    作者:wiseodd    | 项目源码 | 文件源码
def plot(samples, size, name):
    """Save up to 16 samples as a 4x4 grey-scale grid to ``out/<name>.png``.

    Args:
        samples: iterable of arrays, each reshapeable to ``(size, size)``.
        size: side length of one sample; converted with ``int``.
        name: file stem of the saved image.

    Only the first 16 samples are drawn, one per cell; the 4x4 GridSpec has
    no cell beyond index 15.  The ``out/`` directory must already exist.
    """
    from itertools import islice

    size = int(size)
    fig = plt.figure(figsize=(4, 4))
    gs = gridspec.GridSpec(4, 4)
    gs.update(wspace=0.05, hspace=0.05)

    for i, sample in enumerate(islice(samples, 16)):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(sample.reshape(size, size), cmap='Greys_r')

    plt.savefig('out/{}.png'.format(name), bbox_inches='tight')
    plt.close(fig)
项目:generative-models    作者:wiseodd    | 项目源码 | 文件源码
def plot(samples, size, name):
    """Save up to 16 samples as a 4x4 grey-scale grid to ``out/<name>.png``.

    Args:
        samples: iterable of arrays, each reshapeable to ``(size, size)``.
        size: side length of one sample; converted with ``int``.
        name: file stem of the saved image.

    Only the first 16 samples are drawn, one per cell; the 4x4 GridSpec has
    no cell beyond index 15.  The ``out/`` directory must already exist.
    """
    from itertools import islice

    size = int(size)
    fig = plt.figure(figsize=(4, 4))
    gs = gridspec.GridSpec(4, 4)
    gs.update(wspace=0.05, hspace=0.05)

    for i, sample in enumerate(islice(samples, 16)):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(sample.reshape(size, size), cmap='Greys_r')

    plt.savefig('out/{}.png'.format(name), bbox_inches='tight')
    plt.close(fig)
项目:generative-models    作者:wiseodd    | 项目源码 | 文件源码
def plot(samples, size, name):
    """Save up to 16 samples as a 4x4 grey-scale grid to ``out/<name>.png``.

    Args:
        samples: iterable of arrays, each reshapeable to ``(size, size)``.
        size: side length of one sample; converted with ``int``.
        name: file stem of the saved image.

    Only the first 16 samples are drawn, one per cell; the 4x4 GridSpec has
    no cell beyond index 15.  The ``out/`` directory must already exist.
    """
    from itertools import islice

    size = int(size)
    fig = plt.figure(figsize=(4, 4))
    gs = gridspec.GridSpec(4, 4)
    gs.update(wspace=0.05, hspace=0.05)

    for i, sample in enumerate(islice(samples, 16)):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(sample.reshape(size, size), cmap='Greys_r')

    plt.savefig('out/{}.png'.format(name), bbox_inches='tight')
    plt.close(fig)
项目:Contractive_Autoencoder_in_Pytorch    作者:avijit9    | 项目源码 | 文件源码
def samples_write(self, x, epoch):
        """Save the first 16 outputs of ``self.forward(x)`` as a 4x4 grid.

        ``self.forward`` must return a pair, and the second element is used.
        Each sample is reshaped to 28x28 and drawn in grey scale.  The image
        is written to ``out/<epoch zero-padded to 3>.png``; ``out/`` is
        created when missing.
        """
        _, outputs = self.forward(x)
        tiles = outputs.data.cpu().numpy()[:16]

        fig = plt.figure(figsize=(4, 4))
        grid = gridspec.GridSpec(4, 4)
        grid.update(wspace=0.05, hspace=0.05)

        for pos, tile in enumerate(tiles):
            cell_ax = plt.subplot(grid[pos])
            plt.axis('off')
            cell_ax.set_xticklabels([])
            cell_ax.set_yticklabels([])
            cell_ax.set_aspect('equal')
            plt.imshow(tile.reshape(28, 28), cmap='Greys_r')

        if not os.path.exists('out/'):
            os.makedirs('out/')
        stem = str(epoch).zfill(3)
        plt.savefig('out/{}.png'.format(stem), bbox_inches='tight')
        plt.close(fig)
项目:learning-class-invariant-features    作者:sbelharbi    | 项目源码 | 文件源码
def plot_x_y_yhat(x, y, y_hat, xsz, ysz, binz=False):
    """Plot x, y and y_hat side by side in a 1x3 grid.

    Args:
        x: image shown in the first panel.
        y: image shown in the second panel.
        y_hat: image shown in the third panel.
        xsz: size used in the first title ("x:<xsz>x<xsz>").
        ysz: size used in the second and third titles.
        binz: if True, y_hat is thresholded at 0.5 to 0./1. before plotting.

    Returns:
        The created matplotlib figure.
    """
    # Close every open figure before creating the new one.
    plt.close("all")
    f = plt.figure(figsize=(15, 10.8), dpi=300)
    gs = gridspec.GridSpec(1, 3)
    if binz:
        y_hat = (y_hat > 0.5) * 1.
    ims = [x, y, y_hat]
    tils = [
        "x:" + str(xsz) + "x" + str(xsz),
        "y:" + str(ysz) + "x" + str(ysz),
        "yhat:" + str(ysz) + "x" + str(ysz)]
    # All panels are drawn the same way (grey-scale image plus title).
    for n, (im, ti) in enumerate(zip(ims, tils)):
        f.add_subplot(gs[n])
        plt.imshow(im, cmap=cm.Greys_r)
        plt.title(ti)

    return f
项目:Adversarial_Autoencoder    作者:Naresh1318    | 项目源码 | 文件源码
def generate_image_grid(sess, op):
    """
    Generates a grid of images by passing a set of numbers to the decoder and getting its output.
    :param sess: Tensorflow Session required to get the decoder output
    :param op: Operation that needs to be called in order to get the decoder output
    :return: None, displays a matplotlib window with all the merged images.
    """
    # NOTE(review): np.arange(0, 1, 1.5) yields only 0.0, so this draws a
    # single-cell grid — confirm that is intended.
    x_points = np.arange(0, 1, 1.5).astype(np.float32)
    y_points = np.arange(0, 1, 1.5).astype(np.float32)

    nx, ny = len(x_points), len(y_points)
    plt.subplot()
    gs = gridspec.GridSpec(nx, ny, hspace=0.05, wspace=0.05)

    for i, g in enumerate(gs):
        # GridSpec cells are enumerated row-major over nx rows x ny columns,
        # so cell i is at (i // ny, i % ny).
        row, col = divmod(i, ny)
        z = np.concatenate(([x_points[row]], [y_points[col]]))
        z = np.reshape(z, (1, 2))
        # `decoder_input` is a module-level name, presumably the decoder's
        # input placeholder.
        x = sess.run(op, feed_dict={decoder_input: z})
        ax = plt.subplot(g)
        img = np.array(x.tolist()).reshape(28, 28)
        ax.imshow(img, cmap='gray')
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_aspect('auto')
    plt.show()
项目:Adversarial_Autoencoder    作者:Naresh1318    | 项目源码 | 文件源码
def generate_image_grid(sess, op):
    """
    Generates a grid of images by passing a set of numbers to the decoder and getting its output.
    :param sess: Tensorflow Session required to get the decoder output
    :param op: Operation that needs to be called in order to get the decoder output
    :return: None, displays a matplotlib window with all the merged images.
    """
    x_points = np.arange(-10, 10, 1.5).astype(np.float32)
    y_points = np.arange(-10, 10, 1.5).astype(np.float32)

    nx, ny = len(x_points), len(y_points)
    plt.subplot()
    gs = gridspec.GridSpec(nx, ny, hspace=0.05, wspace=0.05)

    for i, g in enumerate(gs):
        # GridSpec cells are enumerated row-major over nx rows x ny columns,
        # so cell i is at (i // ny, i % ny).
        row, col = divmod(i, ny)
        z = np.concatenate(([x_points[row]], [y_points[col]]))
        z = np.reshape(z, (1, 2))
        # `decoder_input` is a module-level name, presumably the decoder's
        # input placeholder.
        x = sess.run(op, feed_dict={decoder_input: z})
        ax = plt.subplot(g)
        img = np.array(x.tolist()).reshape(28, 28)
        ax.imshow(img, cmap='gray')
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_aspect('auto')
    plt.show()
项目:bmlingam    作者:taku-y    | 项目源码 | 文件源码
def plot_gendata():
    """Plot artificial data with n_confounders=[0, 1, 6, 12].

    For each setting, 200 samples are drawn with ``gendata_latents``.  All
    settings share one ``RandomState(0)``, used in that order.  Column 0 is
    scattered against column 1 in a 2x2 grid with both axes fixed to
    [-10, 10].  Used to visually check the artificial data.
    """
    n_samples = 200
    rng = np.random.RandomState(0)
    plt.figure(figsize=(10, 10))
    panels = gridspec.GridSpec(2, 2)

    for cell, n_conf in zip(range(4), (0, 1, 6, 12)):
        xs = gendata_latents(n_conf, n_samples, rng)

        ax = plt.subplot(panels[cell])
        ax.scatter(xs[:, 0], xs[:, 1])
        ax.set_xlim(-10, 10)
        ax.set_ylim(-10, 10)
        ax.set_title('n_confounders=%d' % n_conf)
项目:hco-experiments    作者:zooniverse    | 项目源码 | 文件源码
def plot_pred_vs_image(img,preds_df,out_name):
    """Save a figure with ``img`` on top and a bar plot of predictions below.

    Args:
        img: image shown in the upper axes (ticks removed).
        preds_df: frame with ``Score`` (bar length) and ``Species`` (bar
            label) columns.
        out_name: output path stem; ``.png`` is appended.
    """
    f, axarr = plt.subplots(2, 1)
    plt.suptitle("ResNet50- PreTrained on ImageNet")
    axarr[0].imshow(img)
    sns.set_style("whitegrid")
    # Draw the bar plot once, explicitly on the lower axes.
    sns.barplot(data=preds_df, x='Score', y='Species', ax=axarr[1])
    axarr[0].autoscale(enable=False)
    axarr[0].get_xaxis().set_ticks([])
    axarr[0].get_yaxis().set_ticks([])
    axarr[1].autoscale(enable=False)
    plt.tight_layout()
    plt.savefig(out_name + '.png')


#########################
# Models
#########################

# load model
项目:icinco-code    作者:jacobnzw    | 项目源码 | 文件源码
def bot_demo():
    """Simulate bearings-only tracking and plot the simulated state.

    ``x`` is indexed as ``x[state, step, run]`` (presumably; confirm against
    ``simulate``).  The top half of a 4-row grid plots state 0 against state
    2 for every run.  The next two rows plot states 0 and 2 of the first run
    over the steps.
    """
    steps = 100
    mc_simulations = 1
    ssm = BearingsOnlyTracking(dt=0.1)
    x, z = ssm.simulate(steps, mc_sims=mc_simulations)
    plt.figure()
    layout = gridspec.GridSpec(4, 1)

    plt.subplot(layout[:2, 0])
    for run in range(mc_simulations):
        plt.plot(x[0, :, run], x[2, :, run], alpha=0.85, color='b')

    for row, state in ((2, 0), (3, 2)):
        plt.subplot(layout[row, 0])
        plt.plot(x[state, :, 0])
    plt.show()
项目:siHMM    作者:Ardavans    | 项目源码 | 文件源码
def _get_axes(self,fig):
        """Return ``(feature_ax, stateseq_axs)`` for ``fig``, creating them once.

        The axes are cached on the figure as ``fig._feature_ax`` and
        ``fig._stateseq_axs`` and returned directly on later calls.
        With at most two state sequences, one column is used: the feature axes
        spans the first ``self._fig_sz`` rows, and each state sequence gets
        one row below it.  Otherwise the feature axes fills the left half and
        the state-sequence axes are stacked in the right half.  Grid lines
        are turned off on every state-sequence axes.
        """
        # TODO is attaching these to the figure a good idea? why not save them
        # here and reuse them if we recognize the figure being passed in
        sz = self._fig_sz

        if hasattr(fig,'_feature_ax') and hasattr(fig,'_stateseq_axs'):
            return fig._feature_ax, fig._stateseq_axs
        else:
            # NOTE(review): plt.subplot draws on the current figure, which is
            # assumed to be `fig` — confirm at the call sites.
            if len(self.states_list) <= 2:
                gs = GridSpec(sz+len(self.states_list),1)

                feature_ax = plt.subplot(gs[:sz,:])
                stateseq_axs = [plt.subplot(gs[sz+idx]) for idx in range(len(self.states_list))]
            else:
                gs = GridSpec(1,2)
                sgs = GridSpecFromSubplotSpec(len(self.states_list),1,subplot_spec=gs[1])

                feature_ax = plt.subplot(gs[0])
                stateseq_axs = [plt.subplot(sgs[idx]) for idx in range(len(self.states_list))]

            for ax in stateseq_axs:
                # A real boolean: grid('off') passes a non-empty, truthy string.
                ax.grid(False)

            fig._feature_ax, fig._stateseq_axs = feature_ax, stateseq_axs
            return feature_ax, stateseq_axs
项目:RiboCode    作者:xzt41    | 项目源码 | 文件源码
def plot_main(cds_start,cds_end,psites_array,orf_tstart,orf_tstop,outname):
    """
    The main plot function: draw an ORF's P-site profile with annotation tracks.

    The top axes is drawn by ``plot_ORF``.  Below it is the "Predicted" track
    (``orf_tstart``..``orf_tstop``).  When ``cds_start`` is not None, an
    "Annotated" track (``cds_start``..``cds_end``) is added underneath.  The
    figure is saved to ``<outname>.pdf`` and then closed, so repeated calls
    do not accumulate open figures.
    """
    fig = plt.figure(figsize=(8,4))
    if cds_start is not None:
        gs = gridspec.GridSpec(3,1,height_ratios=[10,1,1],hspace=0.6,left=0.2,right=0.95)
    else:
        gs = gridspec.GridSpec(2,1,height_ratios=[11,1],hspace=0.6,left=0.2,right=0.95)

    ax1 = plt.subplot(gs[0])
    ax2 = plt.subplot(gs[1])
    plot_ORF(ax1,psites_array,orf_tstart)
    plot_annotation(ax2,psites_array.size,orf_tstart,orf_tstop,"Predicted","#3994FF")

    if cds_start is not None:
        ax3 = plt.subplot(gs[2])
        plot_annotation(ax3,psites_array.size,cds_start,cds_end,"Annotated","#006DD5")
    plt.savefig(outname + ".pdf")
    plt.close(fig)
项目:attention_ocr    作者:lightcaster    | 项目源码 | 文件源码
def plot_alpha(alpha, x, y):
    """Show the input ``x`` above the attention weights ``alpha``.

    ``x`` is drawn transposed in grey scale on the top axes, and ``alpha`` on
    the bottom axes.  Each row of ``alpha`` is smoothed with a length-2
    moving average.  The argmax column of each smoothed row is the x tick
    position for the matching label in ``y``.
    """
    # Use the two axes from plt.subplots directly; building a second
    # GridSpec and calling plt.subplot(gs[0]) would stack a duplicate
    # axes on top of a0.
    f, (a0, a1) = plt.subplots(2)

    a0.matshow(x.T, cmap=plt.cm.Greys_r)

    probs = np.zeros_like(alpha)
    for i in range(len(alpha)):
        probs[i] = np.convolve(
            alpha[i], np.ones((2,))/2., mode='same')

    a1.matshow(alpha, interpolation='none', aspect='auto')
    xticks = np.argmax(probs, axis=1)

    a1.set_xticks(xticks)
    a1.set_xticklabels(y, fontsize=16)
    a1.grid(which='both')
    plt.subplots_adjust(top=None, bottom=None, wspace=0.05, hspace=0.05)

    plt.show()
项目:polo    作者:adrianveres    | 项目源码 | 文件源码
def make_figure():
    """Render the demo figure comparing default and optimal leaf ordering.

    The data come from ``get_random_data(100, 0)``.  The top two rows of a
    5-row grid show the dendrogram and a heat strip in the default leaf order.
    The bottom two rows show the same for the linkage returned by
    ``run_polo``.  Row 2 is an empty spacer.  The figure is saved to
    ``data/demo.png``.
    """
    data, Z, D = get_random_data(100, 0)
    order = leaves_list(Z)

    runtime, opt_Z = run_polo(Z, D)
    opt_order = leaves_list(opt_Z)

    layout = gridspec.GridSpec(5, 1,
                               height_ratios=[3, 1, 2, 3, 1],
                               hspace=0)
    fig = plt.figure(figsize=(5,5))

    panels = (
        (0, Z, order,
         "Random numbers, clustered using Ward's criterion, default linear ordering."),
        (3, opt_Z, opt_order,
         "The same hierarchical clustering, arranged for optimal linear ordering."),
    )
    for top_row, link_matrix, leaves, title in panels:
        tree_ax = fig.add_subplot(layout[top_row, 0])
        tree_ax.set_title(title, fontsize=9)
        dendrogram(link_matrix, ax=tree_ax, link_color_func=lambda k: 'k')
        tree_ax.set_xticklabels(data[leaves].reshape(-1))
        tree_ax.set_xticks([])
        tree_ax.set_yticks([])

        strip_ax = fig.add_subplot(layout[top_row + 1, 0])
        strip_ax.matshow(data[leaves].reshape((1,-1)), aspect='auto',
                         cmap='RdBu', vmin=0, vmax=10000)
        strip_ax.set_xticks([])
        strip_ax.set_yticks([])

    fig.savefig('data/demo.png', dpi=130)
项目:vi_vae_gmm    作者:wangg12    | 项目源码 | 文件源码
def plot_images_and_clusters(images, clusters, epoch, save_path, ncol=10):
    '''Save every image in one grid, grouped by cluster label.

    Args:
        images: array whose last axis has size 1 (it is squeezed away), so
            that each ``images[k]`` is a 2-D image shown in grey scale.
        clusters: one cluster label per image.
        epoch: used in the output file name.
        save_path: directory of the output file
            ``plot_gmvae_epoch_<epoch>.png``.
        ncol: number of grid columns; rows = ceil(n_images / ncol).

    Cells are filled row-major, first with all images of the smallest label,
    then the next label, and so on.  Every label present in ``clusters`` is
    drawn, so each image gets a cell.  The figure is closed after saving.
    '''
    fig = plt.figure()
    images = np.squeeze(images, -1)
    clusters = np.asarray(clusters)

    nrow = int(np.ceil(images.shape[0] / float(ncol)))
    gs = gridspec.GridSpec(nrow, ncol,
                           width_ratios=[1]*ncol, height_ratios=[1]*nrow)
    gs.update(wspace=0, hspace=0)
    n = 0
    # np.unique returns the labels present, sorted ascending.
    for label in np.unique(clusters):
        images_i = images[clusters == label, :, :]
        for j in range(images_i.shape[0]):
            ax = plt.subplot(gs[n])
            n += 1
            plt.imshow(images_i[j, :], cmap='gray')
            plt.axis('off')
            ax.set_aspect('auto')
    plt.savefig(os.path.join(save_path, 'plot_gmvae_epoch_{}.png'.format(epoch)), dpi=fig.dpi)
    # Release the figure; otherwise every call keeps one figure open.
    plt.close(fig)
项目:POT    作者:rflamary    | 项目源码 | 文件源码
def plot1D_mat(a, b, M, title=''):
    """Plot matrix M together with the source and target 1D distributions.

    Lays out a 3x3 GridSpec:

    - The target distribution b is drawn along the top row (columns 1-2).
    - The source distribution a runs vertically down the left column
      (rows 1-2), with both of its axes inverted.
    - M is shown with imshow in the remaining lower-right 2x2 block. It
      shares its x-axis with the target plot and its y-axis with the source
      plot.

    Parameters
    ----------
    a : np.array, shape (na,)
        Source distribution
    b : np.array, shape (nb,)
        Target distribution
    M : np.array, shape (na,nb)
        Matrix to plot
    title : str, optional
        Title drawn above the target-distribution axes.
    """
    n_source, n_target = M.shape
    layout = gridspec.GridSpec(3, 3)

    # Top strip: target distribution over the matrix columns.
    ax_target = pl.subplot(layout[0, 1:])
    ax_target.plot(np.arange(n_target), b, 'r', label='Target distribution')
    ax_target.set_yticks(())
    ax_target.set_title(title)

    # Left strip: source distribution over the matrix rows, drawn sideways.
    ax_source = pl.subplot(layout[1:, 0])
    ax_source.plot(a, np.arange(n_source), 'b', label='Source distribution')
    ax_source.invert_xaxis()
    ax_source.invert_yaxis()
    ax_source.set_xticks(())

    # Main panel: the matrix itself, aligned with both strips.
    ax_matrix = pl.subplot(layout[1:, 1:], sharex=ax_target, sharey=ax_source)
    # Use pyplot's imshow so the image also becomes pyplot's current image.
    pl.imshow(M, interpolation='nearest')
    ax_matrix.axis('off')
    ax_matrix.set_xlim((0, n_target))

    pl.tight_layout()
    pl.subplots_adjust(wspace=0., hspace=0.2)
项目:phiplot    作者:grahamfindlay    | 项目源码 | 文件源码
def plot_concept_list(constellation, fig=None, **kwargs):
    """Vertically stack a constellation's concept plots (uses `plot_concept`).

    Examples:
        >>> big_mip = pyphi.compute.big_mip(sub)
        >>> plot_concept_list(big_mip.unpartitioned_constellation,
        ...                   title_fmt='MP', state_fmt='1')
        >>> matplotlib.pyplot.show()

    Args:
        constellation (list(pyphi.models.Concept)): A list of concepts to plot,
            one grid row each, top to bottom.

    Keyword args:
        fig (matplotlib.Figure): A figure on which to plot. If *None*, a new
            figure of 8 x (1.75 * len(constellation)) inches is created and
            used. Default *None*.
        Any unmatched kwargs are passed to `plot_concept`.
    """
    DEFAULT_WIDTH = 8  # in inches
    DEFAULT_CONCEPT_HEIGHT = 1.75  # in inches
    n_concepts = len(constellation)
    if fig is None:
        # Do not pass a figure number: plt.figure(1, ...) would return an
        # already-open figure 1 (ignoring figsize) instead of a new figure.
        fig = plt.figure(figsize=(DEFAULT_WIDTH, DEFAULT_CONCEPT_HEIGHT * n_concepts))

    # One row per concept, single column.
    gs = gridspec.GridSpec(n_concepts, 1)

    for row, concept in enumerate(constellation):
        plot_concept(concept,
                     fig=fig,
                     subplot_spec=gs[row, 0],
                     **kwargs)

    fig.tight_layout()
项目:CGAN    作者:theflashsean1    | 项目源码 | 文件源码
def _plot(samples, grid_shape=(10, 10), image_shape=(28, 28)):
    """Draw flattened image samples on a tightly spaced grid and return the figure.

    Args:
        samples: iterable of arrays, each reshapeable to ``image_shape``.
            At most ``grid_shape[0] * grid_shape[1]`` samples fit; indexing
            past the grid raises.
        grid_shape: (rows, cols) of the subplot grid. Defaults to the original
            10x10 layout.
        image_shape: shape each sample is reshaped to before display. Defaults
            to the original 28x28.

    Returns:
        The created matplotlib Figure (left open; the caller owns it).
    """
    nrows, ncols = grid_shape
    fig = plt.figure(figsize=(10, 10))
    gs = gridspec.GridSpec(nrows, ncols)
    gs.update(wspace=0.05, hspace=0.05)

    for i, sample in enumerate(samples):
        ax = fig.add_subplot(gs[i])
        ax.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        ax.imshow(sample.reshape(image_shape), cmap='Greys_r')
    return fig
项目:nelpy    作者:nelpy    | 项目源码 | 文件源码
def __enter__(self):
        """Create the figure and its grid of axes, returning ``(fig, ax)``.

        When ``self.skip`` is truthy, nothing is created and the sentinel pair
        ``(-1, -1)`` is returned instead.

        Otherwise:

        - A new figure is built from ``self.figsize``, ``self.dpi`` and
          ``self.kwargs``.
        - A ``self.nrows`` x ``self.ncols`` GridSpec is attached to it as
          ``fig.npl_gs``.
        - One axes is added per grid cell. ``ax`` is a numpy array of those
          axes, or the bare axes when there is exactly one cell.

        Raises:
            RuntimeError: if the new figure is not pyplot's current figure
                (``self.clear()`` is called before raising).
        """
        if self.skip:
            return -1, -1

        self.fig = plt.figure(figsize=self.figsize, dpi=self.dpi, **self.kwargs)
        self.fig.npl_gs = gridspec.GridSpec(nrows=self.nrows, ncols=self.ncols)
        grid = self.fig.npl_gs

        # One axes per cell, in the grid's iteration order.
        axes = np.array([self.fig.add_subplot(cell) for cell in grid])
        # A single-cell grid hands back the Axes itself rather than a 1-element array.
        self.ax = axes[0] if len(axes) == 1 else axes

        if self.tight_layout:
            grid.tight_layout(self.fig)

        # Sanity check: the figure just created must still be the active one.
        if self.fig != plt.gcf():
            self.clear()
            raise RuntimeError('Figure does not match active mpl figure')
        return self.fig, self.ax
项目:nelpy    作者:nelpy    | 项目源码 | 文件源码
def rastercountplot(spiketrain, nbins=50, **kwargs):
    """Draw a spike raster with a binned count histogram stacked above it.

    A new 14x6 figure is split into two rows with height ratio 0.2 : 0.8:

    - The top axes shows the flattened binned data as a filled step curve.
    - The bottom axes receives ``rasterplot(spiketrain, ...)``.
    - The x-limits of both axes are synchronized at the end.

    Args:
        spiketrain: object exposing ``support.start``/``support.stop`` and
            ``bin(ds=...)`` (presumably a nelpy spike train -- confirm).
        nbins: number of equal-width bins spanning the support.
        **kwargs: forwarded unchanged to ``rasterplot``. ``color`` (default
            ``'0.4'`` when absent or None) also colours the histogram fill.

    Returns:
        ``(ax_counts, ax_raster)``: the top histogram axes and the raster axes.
    """
    fig = plt.figure(figsize=(14, 6))
    layout = gridspec.GridSpec(2, 1, hspace=0.01, height_ratios=[0.2, 0.8])
    ax_counts = fig.add_subplot(layout[0])
    ax_raster = fig.add_subplot(layout[1])

    fill_color = '0.4' if kwargs.get('color', None) is None else kwargs['color']

    support = spiketrain.support
    bin_width = (support.stop - support.start) / nbins
    binned = spiketrain.bin(ds=bin_width).flatten()
    counts = np.squeeze(binned.data)
    bin_x = np.linspace(support.start, support.stop, num=binned.n_bins)

    # Leave a little headroom above the tallest bin.
    ax_counts.set_ylim([-0.5, np.max(counts) + 1])
    rasterplot(spiketrain, ax=ax_raster, **kwargs)

    utils.clear_left_right(ax_counts)
    utils.clear_top_bottom(ax_counts)
    utils.clear_top(ax_raster)

    ax_counts.fill_between(bin_x, counts, step='mid', color=fill_color)

    utils.sync_xlims(ax_counts, ax_raster)

    return ax_counts, ax_raster
项目:geepee    作者:thangbui    | 项目源码 | 文件源码
def newfig(width):
    """Clear the active figure, create a new one, and return three grid axes.

    The new figure has size ``figsize(width)`` and a 2x2 GridSpec with
    column width ratios 1:4 and row height ratios 4:1. The bottom-left cell
    is left empty.

    Returns:
        ``(fig, (ax_left, ax_main, ax_bottom))``:
        - ax_left: the narrow top-left cell.
        - ax_main: the large top-right cell.
        - ax_bottom: the short bottom-right cell.
    """
    plt.clf()
    fig = plt.figure(figsize=figsize(width))

    layout = gridspec.GridSpec(2, 2, width_ratios=[1, 4], height_ratios=[4, 1])
    # Flat (row-major) indices 0, 1, 3 = top-left, top-right, bottom-right.
    axes = tuple(plt.subplot(layout[cell]) for cell in (0, 1, 3))
    return fig, axes
项目:MDT    作者:cbclab    | 项目源码 | 文件源码
def __init__(self, gridspec, figure, positions=None):
        """Bind a GridSpec and a Figure so subplots can later be generated from them.

        Args:
            gridspec (GridSpec): the grid specification that subplots are placed on
            figure (Figure): the figure in which the subplots are created
            positions (:class:`list`): optional; one grid-spec index per requested
                axis. Indices may be logical (flat) indices or (x, y) coordinate
                pairs; pick one form and use it consistently.
        """
        # Stored as given; no validation happens here.
        self.gridspec, self.figure, self.positions = gridspec, figure, positions
项目:MDT    作者:cbclab    | 项目源码 | 文件源码
def get_gridspec(self, figure, nmr_plots):
        """Return a grid layout specifier whose size comes from ``self._get_square_size``.

        The row/column counts are obtained from ``self._get_square_size(nmr_plots)``.
        The grid is built with ``self.spacings`` as extra GridSpec keyword
        arguments, and no explicit positions are given.
        """
        n_rows, n_cols = self._get_square_size(nmr_plots)
        grid = GridSpec(n_rows, n_cols, **self.spacings)
        return GridLayoutSpecifier(grid, figure)
项目:MDT    作者:cbclab    | 项目源码 | 文件源码
def get_gridspec(self, figure, nmr_plots):
        """Return a grid layout specifier with explicit per-plot positions.

        ``self._get_size_and_position(nmr_plots)`` supplies the row and column
        counts plus the position list handed to the specifier. ``self.spacings``
        is forwarded to GridSpec as keyword arguments.
        """
        n_rows, n_cols, plot_positions = self._get_size_and_position(nmr_plots)
        grid = GridSpec(n_rows, n_cols, **self.spacings)
        return GridLayoutSpecifier(grid, figure, positions=plot_positions)
项目:MDT    作者:cbclab    | 项目源码 | 文件源码
def get_gridspec(self, figure, nmr_plots):
        """Return a grid layout specifier that stacks all plots in a single column.

        The grid has ``nmr_plots`` rows and one column, with ``self.spacings``
        forwarded to GridSpec as keyword arguments.
        """
        grid = GridSpec(nmr_plots, 1, **self.spacings)
        return GridLayoutSpecifier(grid, figure)