Python mpl_toolkits.axes_grid1 模块,ImageGrid() 实例源码

我们从Python开源项目中,提取了以下9个代码示例,用于说明如何使用mpl_toolkits.axes_grid1.ImageGrid()

项目:IntroToDeepLearning    作者:robb-brown    | 项目源码 | 文件源码
def plotFields(layer,fieldShape=None,channel=None,figOffset=1,cmap=None,padding=0.01):
    """Plot every receptive field of a layer in an ImageGrid, plus a summary image.

    Figure number ``figOffset`` receives one grid cell per field with a single
    shared colorbar.  Figure ``figOffset + 1`` receives the sum over fields of
    the absolute weight values.

    :param layer: object exposing a ``W`` weight variable, or the weight
        variable itself.  The weights must provide ``.eval()`` and the object
        passed in must provide ``.name`` (used in the titles).
    :param fieldShape: shape (list or tuple) every field is reshaped to.
        Required when the transposed weights have fewer than 4 dimensions
        (fully connected layer).
    :param channel: for 4-D (convolutional) weights, plot only this input
        channel.  When None, every (feature, channel) pair gets its own cell.
    :param figOffset: matplotlib figure number for the grid; the summary image
        goes to ``figOffset + 1``.
    :param cmap: colormap forwarded to ``imshow``.
    :param padding: ``axes_pad`` between grid cells.
    :raises ValueError: if ``fieldShape`` is missing for non-4-D weights, or if
        there are no fields to plot.
    """
    # Accept either a layer wrapper (with .W) or the weight variable itself.
    try:
        W = layer.W
    except AttributeError:
        W = layer
    wp = W.eval().transpose()

    if np.ndim(wp) < 4:       # Fully connected layer, has no spatial shape
        if fieldShape is None:
            raise ValueError('fieldShape is required for weights with fewer than 4 dimensions')
        # list() so tuples are accepted as well as lists.
        fields = np.reshape(wp, list(wp.shape[:-1]) + list(fieldShape))
    else:                     # Convolutional layer already has shape
        features, channels, iy, ix = np.shape(wp)
        if channel is not None:
            fields = wp[:, channel, :, :]
        else:
            fields = np.reshape(wp, [features * channels, iy, ix])

    nFields = fields.shape[0]
    if nFields == 0:
        # Nothing to draw; also avoids a zero-row grid / division by zero below.
        raise ValueError('layer has no receptive fields to plot')
    perRow = int(math.floor(math.sqrt(nFields)))
    perColumn = int(math.ceil(nFields / float(perRow)))

    fig = mpl.figure(figOffset)
    mpl.clf()

    from mpl_toolkits.axes_grid1 import ImageGrid
    grid = ImageGrid(fig, 111, nrows_ncols=(perRow, perColumn), axes_pad=padding, cbar_mode='single')
    for i in range(nFields):
        im = grid[i].imshow(fields[i], cmap=cmap)
    # The grid is rounded up to a full rectangle; blank the unused cells.
    for ax in grid.axes_all[nFields:]:
        ax.axis('off')

    grid.cbar_axes[0].colorbar(im)
    # pyplot.title() targets the *current* axes, which after building the
    # ImageGrid is not reliably a visible grid cell; title the figure instead.
    fig.suptitle('%s Receptive Fields' % layer.name)

    mpl.figure(figOffset + 1)
    mpl.clf()
    mpl.imshow(np.sum(np.abs(fields), 0), cmap=cmap)
    mpl.title('%s Total Absolute Input Dependency' % layer.name)
    mpl.colorbar()
项目:IntroToDeepLearning    作者:robb-brown    | 项目源码 | 文件源码
def plotFields(layer,fieldShape=None,channel=None,maxFields=25,figName='ReceptiveFields',cmap=None,padding=0.01):
    """Plot up to ``maxFields`` receptive fields of a layer, plus a summary image.

    The figure named ``figName`` receives one ImageGrid cell per plotted field
    (the first ``maxFields`` fields) with a single shared colorbar.  The figure
    named ``figName + ' Total'`` receives the sum over *all* fields of the
    absolute weight values.

    :param layer: object exposing a ``W`` weight variable, or the weight
        variable itself.  The weights must provide ``.eval()`` and the object
        passed in must provide ``.name`` (used in the titles).
    :param fieldShape: shape (list or tuple) every field is reshaped to.
        Required when the transposed weights have fewer than 4 dimensions
        (fully connected layer).
    :param channel: for 4-D (convolutional) weights, plot only this input
        channel.  When None, every (feature, channel) pair is a field.
    :param maxFields: maximum number of fields drawn in the grid.
    :param figName: matplotlib figure label for the grid.
    :param cmap: colormap forwarded to ``imshow``.
    :param padding: ``axes_pad`` between grid cells.
    :raises ValueError: if ``fieldShape`` is missing for non-4-D weights, or if
        no fields would be plotted (e.g. ``maxFields <= 0``).
    """
    # Accept either a layer wrapper (with .W) or the weight variable itself.
    try:
        W = layer.W
    except AttributeError:
        W = layer
    wp = W.eval().transpose()

    if np.ndim(wp) < 4:       # Fully connected layer, has no spatial shape
        if fieldShape is None:
            raise ValueError('fieldShape is required for weights with fewer than 4 dimensions')
        # list() so tuples are accepted as well as lists.
        fields = np.reshape(wp, list(wp.shape[:-1]) + list(fieldShape))
    else:                     # Convolutional layer already has shape
        features, channels, iy, ix = np.shape(wp)
        if channel is not None:
            fields = wp[:, channel, :, :]
        else:
            fields = np.reshape(wp, [features * channels, iy, ix])

    fieldsN = min(fields.shape[0], maxFields)
    if fieldsN <= 0:
        # Nothing to draw; also avoids a zero-row grid / division by zero below.
        raise ValueError('no receptive fields to plot (check maxFields and the layer weights)')
    perRow = int(math.floor(math.sqrt(fieldsN)))
    perColumn = int(math.ceil(fieldsN / float(perRow)))

    fig = mpl.figure(figName)
    mpl.clf()

    from mpl_toolkits.axes_grid1 import ImageGrid
    grid = ImageGrid(fig, 111, nrows_ncols=(perRow, perColumn), axes_pad=padding, cbar_mode='single')
    for i in range(fieldsN):
        im = grid[i].imshow(fields[i], cmap=cmap)
    # The grid is rounded up to a full rectangle; blank the unused cells.
    for ax in grid.axes_all[fieldsN:]:
        ax.axis('off')

    grid.cbar_axes[0].colorbar(im)
    # pyplot.title() targets the *current* axes, which after building the
    # ImageGrid is not reliably a visible grid cell; title the figure instead.
    fig.suptitle('%s Receptive Fields' % layer.name)

    mpl.figure(figName + ' Total')
    mpl.clf()
    mpl.imshow(np.sum(np.abs(fields), 0), cmap=cmap)
    mpl.title('%s Total Absolute Input Dependency' % layer.name)
    mpl.colorbar()
项目:discgen    作者:vdumoulin    | 项目源码 | 文件源码
def plot_image_grid(images, num_rows, num_cols, save_path=None):
    """Plots images in a grid.

    Parameters
    ----------
    images : numpy.ndarray
        Images to display, with shape
        ``(num_rows * num_cols, num_channels, height, width)``.
        Single-channel images are drawn as 2-D arrays (default colormap).
    num_rows : int
        Number of rows for the image grid.
    num_cols : int
        Number of columns for the image grid.
    save_path : str, optional
        Where to save the image grid. Defaults to ``None``,
        which causes the grid to be displayed on screen. When a path is
        given, the figure is closed after saving.

    """
    figure = pyplot.figure()
    grid = ImageGrid(figure, 111, (num_rows, num_cols), axes_pad=0.1)

    for image, axis in zip(images, grid):
        if image.shape[0] == 1:
            # Older matplotlib versions reject (H, W, 1) arrays in imshow.
            pixels = image[0]
        else:
            pixels = image.transpose(1, 2, 0)
        axis.imshow(pixels, interpolation='nearest')
        # axis('off') already hides ticks and tick labels.
        axis.axis('off')

    if save_path is None:
        pyplot.show()
    else:
        figure.savefig(save_path, transparent=True, bbox_inches='tight')
        # Release the figure so repeated calls do not accumulate open figures.
        pyplot.close(figure)
项目:Neural-Photo-Editor    作者:ajbrock    | 项目源码 | 文件源码
def plot_image_grid(images, num_rows, num_cols, save_path=None, dpi=212):
    """Plots images in a grid.

    Parameters
    ----------
    images : numpy.ndarray
        Images to display, with shape
        ``(num_rows * num_cols, num_channels, height, width)``.
        Single-channel images are drawn as 2-D arrays (default colormap).
    num_rows : int
        Number of rows for the image grid.
    num_cols : int
        Number of columns for the image grid.
    save_path : str, optional
        Where to save the image grid. Defaults to ``None``,
        which causes the grid to be displayed on screen. When a path is
        given, the figure is closed after saving.
    dpi : int, optional
        Resolution used when saving. Defaults to ``212``.

    """
    figure = pyplot.figure()
    grid = ImageGrid(figure, 111, (num_rows, num_cols), axes_pad=0.1)

    for image, axis in zip(images, grid):
        if image.shape[0] == 1:
            # Older matplotlib versions reject (H, W, 1) arrays in imshow.
            pixels = image[0]
        else:
            pixels = image.transpose(1, 2, 0)
        axis.imshow(pixels, interpolation='nearest')
        # axis('off') already hides ticks and tick labels.
        axis.axis('off')

    if save_path is None:
        pyplot.show()
    else:
        figure.savefig(save_path, transparent=True, bbox_inches='tight', dpi=dpi)
        pyplot.close(figure)
项目:WassersteinGAN.tensorflow    作者:shekkizh    | 项目源码 | 文件源码
def save_imshow_grid(images, logs_dir, filename, shape):
    """
    Plot images in a grid of a given shape and save it to logs_dir/filename.

    :param images: indexable collection of images; the first
        ``shape[0] * shape[1]`` entries are drawn in grid order.
    :param logs_dir: directory the figure is written to.
    :param filename: file name of the saved figure inside ``logs_dir``.
    :param shape: ``(rows, cols)`` of the grid.
    """
    fig = plt.figure(1)
    # Figure 1 is reused across calls; clear it so each call does not stack
    # another full ImageGrid of axes on top of the previous ones.
    fig.clf()
    grid = ImageGrid(fig, 111, nrows_ncols=shape, axes_pad=0.05)

    size = shape[0] * shape[1]
    for i in trange(size, desc="Saving images"):
        grid[i].axis('off')
        grid[i].imshow(images[i])

    fig.savefig(os.path.join(logs_dir, filename))
项目:evaluation-toolkit    作者:lightfield-analysis    | 项目源码 | 文件源码
def _get_grids(fig, rows, cols, axes_pad=0):
    """Stack ``rows`` one-row ImageGrids vertically in ``fig``.

    Each grid has ``cols`` axes with ``share_all=True`` and a single colorbar
    on its right, so every row gets its own colorbar.

    :param fig: matplotlib figure to draw into.
    :param rows: number of stacked grids (one per row).
    :param cols: number of image axes in each grid.
    :param axes_pad: second element of the ImageGrid ``axes_pad`` pair; the
        first element is fixed at 0.05.
    :return: list of ``rows`` ImageGrid objects, in subplot-index order.
    """
    grids = []
    for row in range(rows):
        # (nrows, ncols, index) works for any row count, unlike the
        # three-digit integer shorthand, which breaks once rows >= 10.
        grid = ImageGrid(fig, (rows, 1, row + 1),
                         nrows_ncols=(1, cols),
                         axes_pad=(0.05, axes_pad),
                         share_all=True,
                         cbar_location="right",
                         cbar_mode="single",
                         cbar_size="10%",
                         cbar_pad="5%")
        grids.append(grid)
    return grids
项目:GAN    作者:kunrenzhilu    | 项目源码 | 文件源码
def save_imshow_grid(images, logs_dir, filename, shape):
    """
    Plot images in a grid of a given shape and save the grid and its inputs.

    Writes into ``logs_dir``:
      * ``image.pk`` -- the ``images`` object, pickled;
      * one JPEG per grid cell, named by its index (no file extension);
      * ``filename`` -- the rendered grid figure.

    :param images: indexable collection of images; the first
        ``shape[0] * shape[1]`` entries are drawn in grid order.
    :param logs_dir: output directory.
    :param filename: file name of the saved grid figure inside ``logs_dir``.
    :param shape: ``(rows, cols)`` of the grid.
    """
    # Close the pickle file deterministically instead of leaking the handle.
    with open(os.path.join(logs_dir, "image.pk"), "wb") as f:
        pickle.dump(images, f)

    fig = plt.figure(1)
    # Figure 1 is reused across calls; clear it so grids do not accumulate.
    fig.clf()
    grid = ImageGrid(fig, 111, nrows_ncols=shape, axes_pad=0.05)

    size = shape[0] * shape[1]
    for i in trange(size, desc="Saving images"):
        grid[i].axis('off')
        grid[i].imshow(images[i])
        # Save every cell individually (inside the loop, so not only the last one).
        # NOTE(review): Image.fromarray expects a uint8-compatible array -- confirm dtype of images.
        Image.fromarray(images[i]).save(os.path.join(logs_dir, str(i)), "jpeg")

    fig.savefig(os.path.join(logs_dir, filename))
项目:IllustrationGAN    作者:tdrussell    | 项目源码 | 文件源码
def main(argv=None):
    """Render a montage of generator samples, save it to montage.png, then show it.

    Draws ``GRID[0] * GRID[1]`` latent vectors from a standard normal, runs
    the generator template in test phase, rescales the outputs with
    ``(x + 1) / 2`` (presumably mapping [-1, 1] generator output to [0, 1] --
    confirm against the model) and lays them out in a ``GRID``-shaped
    ImageGrid.

    :param argv: unused.
    """
    input.init_dataset_constants()
    num_images = GRID[0] * GRID[1]
    FLAGS.batch_size = num_images
    with tf.Graph().as_default():
        g_template = model.generator_template()
        z = tf.placeholder(tf.float32, shape=[FLAGS.batch_size, FLAGS.z_size])
        # np.random.seed(1337)  # uncomment to generate the same samples every run
        noise = np.random.normal(size=(FLAGS.batch_size, FLAGS.z_size))
        with pt.defaults_scope(phase=pt.Phase.test):
            gen_images_op, _ = pt.construct_all(g_template, input=z)

        # The context manager closes the session even if init/run raises.
        with tf.Session() as sess:
            init_variables(sess)
            gen_images, = sess.run([gen_images_op], feed_dict={z: noise})
        gen_images = (gen_images + 1) / 2

        fig = plt.figure(1)
        grid = ImageGrid(fig, 111,
                         nrows_ncols=GRID,
                         axes_pad=0.1)
        for i in range(num_images):
            axis = grid[i]
            axis.axis('off')
            axis.imshow(gen_images[i])

        # Save before show(): with interactive backends the canvas may be
        # destroyed once the window is closed, producing a blank file.
        fig.savefig('montage.png', dpi=100, bbox_inches='tight')
        plt.show()
项目:ActiveBoundary    作者:MiriamHu    | 项目源码 | 文件源码
def generate_images_line_save(self, line_segment, query_id, image_original_space=None):
    """
    Save the decoded query-point image and the decoded images along a query line.

    The ID of the query point from which the query line was generated is
    added to both file names: ``pointquery_<n>_<query_id>.png`` and
    ``linequery_<n>_<query_id>.pdf`` in ``self.save_path_queries``, where
    ``<n>`` is ``self.n_queries + 1``.

    :param line_segment: points along the query line in the dataset's scaled
        space; mapped back with
        ``self.dataset.scaling_transformation.inverse_transform`` before decoding.
    :param query_id: row of ``self.dataset.data["features"]`` used as the
        query point (when ``image_original_space`` is None); also used in
        the file names.
    :param image_original_space: optional query point already in the
        original space; when given it is decoded instead of the dataset row.
    :return: None
    :raises Exception: any error is printed with its traceback to stdout and
        re-raised unchanged.
    """
    try:
        if image_original_space is not None:
            x = self.generative_model.decode(image_original_space.T)
        else:
            # Comes from dataset.data["features"], so is already in the
            # original space in which ALI operates.
            x = self.generative_model.decode(
                to_vector(self.dataset.data["features"][query_id]).T)
        save_path = os.path.join(self.save_path_queries,
                                 "pointquery_%d_%d.png" % (self.n_queries + 1, query_id))
        if x.shape[1] == 1:
            plt.imsave(save_path, x[0, 0, :, :], cmap=cm.Greys)
        else:
            plt.imsave(save_path, x[0, :, :, :].transpose(1, 2, 0), cmap=cm.Greys_r)

        # Transform to original space, in which ALI operates.
        decoded_images = self.generative_model.decode(
            self.dataset.scaling_transformation.inverse_transform(line_segment))
        figure = plt.figure()
        try:
            grid = ImageGrid(figure, 111, (1, decoded_images.shape[0]), axes_pad=0.1)
            for image, axis in zip(decoded_images, grid):
                if image.shape[0] == 1:
                    axis.imshow(image[0, :, :].squeeze(),
                                cmap=cm.Greys, interpolation='nearest')
                else:
                    axis.imshow(image.transpose(1, 2, 0).squeeze(),
                                cmap=cm.Greys_r, interpolation='nearest')
                # axis('off') already hides ticks and tick labels.
                axis.axis('off')
            save_path = os.path.join(self.save_path_queries,
                                     "linequery_%d_%d.pdf" % (self.n_queries + 1, query_id))
            figure.savefig(save_path, transparent=True, bbox_inches='tight')
        finally:
            # One new figure per call; close it so repeated queries do not leak figures.
            plt.close(figure)
    except Exception:
        # Same output as the former `print "EXCEPTION:", ...`, valid on Python 2 and 3.
        print("EXCEPTION: " + traceback.format_exc())
        # Bare raise preserves the original traceback (`raise e` loses it on Python 2).
        raise