Python matplotlib.pyplot 模块,pcolor() 实例源码

我们从 Python 开源项目中提取了以下 34 个代码示例，用于说明如何使用 matplotlib.pyplot.pcolor()。

项目:tap    作者:mfouesneau    | 项目源码 | 文件源码
def plot_density_map(x, y, xbins, ybins, Nlevels=4, cbar=True, weights=None):
    """Draw a 2D histogram of (x, y) as a pseudocolor map with contours.

    The histogram is drawn with ``plt.pcolor`` on the bin edges. Contours at
    the levels returned by ``nice_levels(Z, Nlevels)`` are overlaid on the
    bin centres.

    :param x, y: sample coordinates.
    :param xbins, ybins: bin edges along x and y.
    :param Nlevels: passed to ``nice_levels`` to choose contour levels.
    :param cbar: if true, attach a colorbar to the pcolor image.
    :param weights: optional per-sample weights for ``np.histogram2d``.
    :returns: ``(axes, colorbar)``; the colorbar is ``None`` when ``cbar`` is false.
    """
    counts = np.histogram2d(x, y, bins=(xbins, ybins), weights=weights)[0]
    Z = counts.astype(float).T  # transpose: pcolor expects rows along y

    center_x, center_y = np.meshgrid(get_centers_from_bins(xbins),
                                     get_centers_from_bins(ybins))
    edge_x, edge_y = np.meshgrid(xbins, ybins)

    image = plt.pcolor(edge_x, edge_y, Z, cmap=plt.cm.Blues)
    plt.contour(center_x, center_y, Z, levels=nice_levels(Z, Nlevels),
                cmap=plt.cm.Greys_r)

    colorbar = plt.colorbar(image) if cbar else None
    plt.xlim(xbins[0], xbins[-1])
    plt.ylim(ybins[0], ybins[-1])

    try:
        plt.tight_layout()
    except Exception as err:  # layout is cosmetic: report it and carry on
        print(err)
    return plt.gca(), colorbar
项目:pyballd    作者:Yurlungur    | 项目源码 | 文件源码
def plot_interpolation(orderx, ordery):
    """Plot the spectral interpolant of ``f`` on a uniform 100x100 grid.

    ``f`` is sampled on the 2D points returned by ``get_x2d()`` of a
    ``PseudoSpectralDiscretization2D`` with the given orders. The samples are
    turned into a callable with ``to_continuum``, which is evaluated on a
    uniform grid over [XMIN, XMAX] x [YMIN, YMAX]. The plot is saved as .png
    and .pdf (under ``figs/`` when ``USE_FIGS_DIR`` is set), and the figure
    is then cleared.
    """
    disc = PseudoSpectralDiscretization2D(orderx, XMIN, XMAX,
                                          ordery, YMIN, YMAX)
    x_coll, y_coll = disc.get_x2d()
    x_fine, y_fine = np.meshgrid(np.linspace(XMIN, XMAX, 100),
                                 np.linspace(YMIN, YMAX, 100),
                                 indexing='ij')
    interpolant = disc.to_continuum(f(x_coll, y_coll))
    plt.pcolor(x_fine, y_fine, interpolant(x_fine, y_fine))
    colorbar = plt.colorbar()
    colorbar.set_label('interpolated function', fontsize=16)
    plt.xlabel('x')
    plt.ylabel('y')
    for extension in ('.png', '.pdf'):
        target = 'orthopoly_interpolated_function' + extension
        if USE_FIGS_DIR:
            target = 'figs/' + target
        plt.savefig(target, bbox_inches='tight')
    plt.clf()
项目:snake    作者:rhinech    | 项目源码 | 文件源码
def load_data():
    """Draw the Mott lobes.

    Loads ``data_<GRID_SIZE>.npy``. Columns 0 and 1 of each row are used as
    the plot coordinates, and ``kinetic_energy(row[2:], -1.)`` gives the
    colour value. The result is shown as a GRID_SIZE x GRID_SIZE pcolor map.
    """
    res = np.load('data_%d.npy' % GRID_SIZE)
    shape = (GRID_SIZE, GRID_SIZE)
    x = res[:, 0]
    y = res[:, 1]
    z = [kinetic_energy(entry[2:], -1.) for entry in res]
    plt.pcolor(np.reshape(x, shape), np.reshape(y, shape), np.reshape(z, shape))
    plt.xlabel('$dt/U$')
    # Raw string: '\m' is an invalid escape sequence in a normal literal
    # (DeprecationWarning, SyntaxWarning since Python 3.12). Same value.
    plt.ylabel(r'$\mu/U$')
    plt.show()
项目:AnomalyDetection    作者:JayZhuCoding    | 项目源码 | 文件源码
def plot_training_parameters(self):
        """Plot a log-log heatmap of ``diff`` over a (nu, gamma) grid.

        Reads ``training_param.csv``, skips one header line, and parses each
        following line as ``nu,gamma,diff``. Rows are laid out as an
        ``n x n`` grid in row-major order: line ``k`` fills nu index
        ``k // n`` and gamma index ``k % n``.
        """
        n = 100
        with open("training_param.csv", "r") as fr:
            fr.readline()  # header
            lines = fr.readlines()
        nu = np.empty(n, dtype=np.float64)
        gamma = np.empty(n, dtype=np.float64)
        diff = np.empty([n, n], dtype=np.float64)
        for row, line in enumerate(lines):
            m = line.strip().split(",")
            # Integer division: ``row / n`` is a float on Python 3, which is
            # not a valid array index.
            i, j = divmod(row, n)
            nu[i] = float(m[0])
            gamma[j] = float(m[1])
            diff[i][j] = float(m[2])
        plt.pcolor(gamma, nu, diff, cmap="coolwarm")
        plt.title("The Difference of Guassian Classifier with Different nu, gamma")
        plt.xlabel("gamma")
        plt.ylabel("nu")
        plt.xscale("log")
        plt.yscale("log")
        plt.colorbar()
        plt.show()
项目:crypto-forcast    作者:7yl4r    | 项目源码 | 文件源码
def plotImage(dta, saveFigName):
    """Render a 2D array as an HSV pseudocolor image and save it.

    The colour scale is symmetric around zero: +/- the largest absolute value
    in ``dta``. The x axis is the column index ("index") and the y axis is the
    row index ("season length"). The current figure is cleared first and is
    not closed afterwards.
    """
    plt.clf()
    step = 1
    with np.errstate(invalid='ignore'):
        # Integer row/column index grids with the same shape as dta.
        rows, cols = np.mgrid[0:len(dta):step, 0:len(dta[0]):step]
        limit = np.abs(dta).max()
        mesh = plt.pcolormesh(cols, rows, dta, cmap='hsv', vmin=-limit, vmax=limit)
        plt.colorbar(mesh, orientation='vertical')
        plt.xlabel("index")
        plt.ylabel("season length")
        plt.savefig(str(saveFigName))
项目:vsmlib    作者:undertherain    | 项目源码 | 文件源码
def plot_heat(ax, m, xlabels, ylabels):
    """Draw matrix ``m`` on ``ax`` as a diverging heatmap centred at zero.

    Cells are outlined in black. Labels sit at cell centres, and x labels are
    rotated 90 degrees. Tick marks are hidden while the labels are kept. A
    horizontal colorbar is attached.

    :param ax: matplotlib Axes to draw on.
    :param m: 2D array-like; rows map to y, columns to x.
    :param xlabels: labels for the columns of ``m``.
    :param ylabels: labels for the rows of ``m``.
    """
    data = np.array(m)
    norm = MidpointNormalize(midpoint=0)
    ax.set_aspect('equal')
    ax.set_xticks(np.arange(data.shape[1]) + 0.5, minor=False)
    ax.set_yticks(np.arange(data.shape[0]) + 0.5, minor=False)
    # Rotate via ``ax`` itself; plt.xticks() targets whatever axes is current.
    ax.set_xticklabels(xlabels, minor=False, rotation=90)
    ax.set_yticklabels(ylabels, minor=False)
    # Hide both tick marks on each axis. Assigning Tick.tick1On/tick2On no
    # longer has any effect in current matplotlib.
    ax.tick_params(axis='both', which='major',
                   bottom=False, top=False, left=False, right=False)
    heatmap = ax.pcolor(data, norm=norm, cmap=mpl.cm.RdBu, edgecolors="black")
    plt.colorbar(heatmap, ax=ax, orientation='horizontal', shrink=1, aspect=40)
项目:iota    作者:amaneureka    | 项目源码 | 文件源码
def Pcolor(xs, ys, zs, pcolor=True, contour=False, **options):
    """Makes a pseudocolor plot.

    xs: x coordinates of the grid columns
    ys: y coordinates of the grid rows
    zs: 2D values on the grid; presumably shaped (len(ys), len(xs)) to match
        np.meshgrid(xs, ys) — confirm with callers
    pcolor: boolean, whether to draw a pseudocolor plot (pyplot.pcolormesh)
    contour: boolean, whether to draw a labelled contour plot
    options: keyword args passed to pyplot.pcolormesh and/or pyplot.contour;
        linewidth=3 and cmap=matplotlib.cm.Blues are added via _Underride
    """
    _Underride(options, linewidth=3, cmap=matplotlib.cm.Blues)

    grid_x, grid_y = np.meshgrid(xs, ys)

    # Show full x tick values instead of an offset.
    pyplot.gca().xaxis.set_major_formatter(
        matplotlib.ticker.ScalarFormatter(useOffset=False))

    if pcolor:
        pyplot.pcolormesh(grid_x, grid_y, zs, **options)

    if contour:
        contours = pyplot.contour(grid_x, grid_y, zs, **options)
        pyplot.clabel(contours, inline=1, fontsize=10)
项目:ThinkX    作者:AllenDowney    | 项目源码 | 文件源码
def Pcolor(xs, ys, zs, pcolor=True, contour=False, **options):
    """Makes a pseudocolor plot.

    xs: x coordinates of the grid columns
    ys: y coordinates of the grid rows
    zs: 2D values on the grid; presumably shaped (len(ys), len(xs)) to match
        np.meshgrid(xs, ys) — confirm with callers
    pcolor: boolean, whether to draw a pseudocolor plot (plt.pcolormesh)
    contour: boolean, whether to draw a labelled contour plot
    options: keyword args passed to plt.pcolormesh and/or plt.contour;
        linewidth=3 and cmap=matplotlib.cm.Blues are added via _Underride
    """
    _Underride(options, linewidth=3, cmap=matplotlib.cm.Blues)

    grid_x, grid_y = np.meshgrid(xs, ys)

    # Show full x tick values instead of an offset.
    plt.gca().xaxis.set_major_formatter(
        matplotlib.ticker.ScalarFormatter(useOffset=False))

    if pcolor:
        plt.pcolormesh(grid_x, grid_y, zs, **options)

    if contour:
        contours = plt.contour(grid_x, grid_y, zs, **options)
        plt.clabel(contours, inline=1, fontsize=10)
项目:pyballd    作者:Yurlungur    | 项目源码 | 文件源码
def plot_test_function(orderx, ordery):
    """Plot ``f`` evaluated on the collocation points of the discretization.

    Uses the 2D points from ``get_x2d()`` of a
    ``PseudoSpectralDiscretization2D`` with the given orders and limits the
    view to [XMIN, XMAX] x [YMIN, YMAX]. The plot is saved as .png and .pdf
    (under ``figs/`` when ``USE_FIGS_DIR`` is set), and the figure is then
    cleared.
    """
    disc = PseudoSpectralDiscretization2D(orderx, XMIN, XMAX,
                                          ordery, YMIN, YMAX)
    grid_x, grid_y = disc.get_x2d()
    plt.pcolor(grid_x, grid_y, f(grid_x, grid_y))
    plt.xlabel('x', fontsize=16)
    plt.ylabel('y', fontsize=16)
    plt.xlim(XMIN, XMAX)
    plt.ylim(YMIN, YMAX)
    colorbar = plt.colorbar()
    colorbar.set_label(label=r'$\cos(x)\sin(2 y)$', fontsize=16)
    for extension in ('.png', '.pdf'):
        target = 'test_function' + extension
        if USE_FIGS_DIR:
            target = 'figs/' + target
        plt.savefig(target, bbox_inches='tight')
    plt.clf()
项目:almond-nnparser    作者:Stanford-Mobisocial-IoT-Lab    | 项目源码 | 文件源码
def show_heatmap(x, y, attention):
    """Display an attention matrix as a labelled heatmap.

    ``attention`` is cropped to ``len(y)`` rows and ``len(x)`` columns.
    Columns are labelled with ``x`` (labels on top, rotated vertically) and
    rows with ``y``. The y axis is inverted so that row 0 is at the top.
    """
    data = attention[:len(y), :len(x)]

    ax = plt.axes()
    heatmap = plt.pcolor(data, cmap=plt.cm.Blues)

    plt.xticks(np.arange(len(x)) + 0.5, x, rotation='vertical')
    ax.set_yticks(np.arange(len(y)) + 0.5)
    ax.set_yticklabels(y)

    # Make it look less like a scatter plot and more like a colored table.
    ax.tick_params(axis='both', length=0)
    ax.invert_yaxis()
    ax.xaxis.tick_top()

    plt.colorbar(heatmap)
    plt.show()
项目:IRL-maxent    作者:harpribot    | 项目源码 | 文件源码
def main(grid_size, discount, n_trajectories, epochs, learning_rate, trajectory_length,
         trust, expert_type, random_start):
    """
    Run maximum entropy inverse reinforcement learning on the gridworld MDP.

    Prints the recovered reward grid and plots it next to the ground-truth
    reward.

    grid_size: Grid size. int.
    discount: MDP discount factor. float.
    n_trajectories: Number of sampled trajectories. int.
    epochs: Gradient descent iterations. int.
    learning_rate: Gradient descent learning rate. float.
    trajectory_length: Passed to generate_trajectories (presumably steps per
        trajectory). int.
    trust: The Gridworld is built with wind = 1 - trust. float.
    expert_type: Passed through to gridworld.Gridworld.
    random_start: Passed through to Gridworld.generate_trajectories.
    """

    wind = 1 - trust

    gw = gridworld.Gridworld(grid_size, wind, discount, expert_type)
    trajectories = gw.generate_trajectories(n_trajectories,
                                            trajectory_length,
                                            gw.optimal_policy, random_start=random_start)
    feature_matrix = gw.feature_matrix()
    ground_r = np.array([gw.reward(s) for s in range(gw.n_states)])
    r = maxent.irl(feature_matrix, gw.n_actions, discount,
                   gw.transition_probability, trajectories, epochs, learning_rate)

    # Function-call form: the Python 2 print statement is a SyntaxError on
    # Python 3 and made the module unimportable (same output on Python 2).
    print(r.reshape((grid_size, grid_size)))

    panels = [(ground_r, "Groundtruth reward"), (r, "Recovered reward")]
    for position, (reward, label) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.pcolor(reward.reshape((grid_size, grid_size)))
        plt.colorbar()
        plt.title(label)
    plt.show()
项目:KATE    作者:hugochan    | 项目源码 | 文件源码
def heatmap(data, save_file='heatmap.png'):
    """Save a jet-coloured pseudocolor plot of 2D ``data`` to ``save_file``.

    Major y ticks are placed every 5 units. The figure is closed after
    saving, so repeated calls do not accumulate open figures.
    """
    fig = plt.figure()
    ax = fig.gca()
    # One locator is enough. The MaxNLocator(integer=True) set previously was
    # replaced immediately by this one and never took effect.
    ax.yaxis.set_major_locator(MultipleLocator(5))
    ax.pcolor(data, cmap=plt.cm.jet)
    fig.savefig(save_file)
    plt.close(fig)
项目:nn4nlp-code    作者:neubig    | 项目源码 | 文件源码
def plot_attention(src_words, trg_words, attention_matrix, file_name=None):
  """This takes in source and target words and an attention matrix (in numpy format)
  and prints a visualization of this to a file.
  :param src_words: a list of words in the source
  :param trg_words: a list of target words
  :param attention_matrix: a two-dimensional numpy array of values between zero and one,
    where rows correspond to source words, and columns correspond to target words
  :param file_name: the name of the file to which we write the attention; if None
    the figure is shown interactively instead. The figure is closed afterwards.
  """
  fig, ax = plt.subplots()
  # A lazy, rough, approximate way of making the image large enough. Clamp to
  # at least 1 inch: with fewer than two target words int(n * .6) is 0, and a
  # zero-width figure cannot be drawn or saved.
  fig.set_figwidth(max(1, int(len(trg_words) * .6)))

  # put the major ticks at the middle of each cell
  ax.set_xticks(np.arange(attention_matrix.shape[1]) + 0.5, minor=False)
  ax.set_yticks(np.arange(attention_matrix.shape[0]) + 0.5, minor=False)
  ax.invert_yaxis()

  # label axes by words
  ax.set_xticklabels(trg_words, minor=False)
  ax.set_yticklabels(src_words, minor=False)
  ax.xaxis.tick_top()
  plt.setp(ax.get_xticklabels(), rotation=50, horizontalalignment='right')
  # draw the heatmap
  plt.pcolor(attention_matrix, cmap=plt.cm.Blues, vmin=0, vmax=1)
  plt.colorbar()

  if file_name is not None:
    plt.savefig(file_name, dpi=100)
  else:
    plt.show()
  plt.close()
项目:quoll    作者:LanguageMachines    | 项目源码 | 文件源码
def visualize_document_topics_heatmap(self, outfile, set_topics=False):
        """Save a heatmap of document/topic weights to ``outfile``.

        Calls ``sort_doctopics_groups()``, rotates ``document_topics_raw`` 90
        degrees with ``numpy.rot90``, and draws it with ``pyplot.pcolor`` with
        an inverted y axis.

        When ``self.group_names`` is set, x ticks are placed at every column
        centre. Each run of consecutive documents sharing a group (from
        ``document_group_dict``) gets its group name on one tick near the
        middle of the run; the other ticks are blank. Vertical lines are drawn
        at the stored group bounds. Topics are labelled either with
        ``self.topic_labels`` at every row centre or, when ``set_topics`` is
        given, only at those indices, with horizontal lines. The figure is
        saved and then cleared.

        :param outfile: path passed to ``pyplot.savefig``.
        :param set_topics: optional sequence of topic indices to mark.
        """
        self.sort_doctopics_groups()
        doctopics_raw_hm = numpy.rot90(self.document_topics_raw)
        rows, columns = doctopics_raw_hm.shape
        rownames = self.topic_labels
        pyplot.pcolor(doctopics_raw_hm, norm=None, cmap='Blues')
        pyplot.gca().invert_yaxis()
        # Also require at least one document: otherwise the loop never binds
        # ``i`` and the final label assignment raised NameError.
        if self.group_names and len(self.document_names) > 0:
            ticks_groups = []
            bounds = []
            # Unique sentinel. The old ``False`` compares equal to a group
            # named 0, so such a leading group was never detected.
            no_group = object()
            current_group = no_group
            start = 0
            for i, doc in enumerate(self.document_names):
                group = self.document_group_dict[doc]
                if group != current_group:
                    if i != 0:
                        # NOTE(review): pcolor column k spans [k, k+1], so the
                        # boundary between documents i-1 and i is at x=i;
                        # i-1 looks one column off — confirm before changing.
                        bounds.append(i - 1)
                        ticks_groups[start + (i - start) // 2] = current_group
                    current_group = group
                    start = i
                ticks_groups.append('')
            ticks_groups[start + (i - start) // 2] = current_group
            pyplot.xticks(numpy.arange(columns) + 0.5, ticks_groups, fontsize=11)
            if set_topics:
                for index in set_topics:
                    pyplot.axhline(y=index)
                topic_names = self.return_topic_names(set_topics)
                pyplot.yticks(set_topics, topic_names, fontsize=8)
            else:
                pyplot.yticks(numpy.arange(rows) + 0.5, rownames, fontsize=8)
            for bound in bounds:
                pyplot.axvline(x=bound)
        pyplot.colorbar(cmap='Blues')
        pyplot.savefig(outfile)
        pyplot.clf()
项目:Sisyphus    作者:davidbrandfonbrener    | 项目源码 | 文件源码
def show_W_rec(model, sess):
    """Show the recurrent weight matrix as a pcolor plot with a colorbar.

    When ``model.dale_ratio`` is truthy, the matrix shown is
    ``abs(W_rec) * recurrent_connectivity_mask`` right-multiplied by
    ``model.dale_rec``. Otherwise the raw ``W_rec`` values are shown.
    """
    apply_dale = model.dale_ratio
    weights = model.W_rec.eval(session=sess)
    if apply_dale:
        masked = abs(weights) * model.recurrent_connectivity_mask
        weights = np.matmul(masked, model.dale_rec)
    plt.pcolor(weights)
    plt.colorbar()
    plt.show()
项目:Sisyphus    作者:davidbrandfonbrener    | 项目源码 | 文件源码
def show_W_in(model, sess):
    """Show the input weight matrix as a pcolor plot with a colorbar.

    When ``model.dale_ratio`` is truthy, ``abs(W_in)`` is multiplied
    elementwise by ``model.input_connectivity_mask`` before plotting.
    Otherwise the raw ``W_in`` values are shown.
    """
    apply_dale = model.dale_ratio
    weights = model.W_in.eval(session=sess)
    if apply_dale:
        weights = abs(weights) * model.input_connectivity_mask
    plt.pcolor(weights)
    plt.colorbar()
    plt.show()
项目:Sisyphus    作者:davidbrandfonbrener    | 项目源码 | 文件源码
def show_W_out(model, sess):
    """Show the output weight matrix as a pcolor plot with a colorbar.

    When ``model.dale_ratio`` is truthy, the matrix shown is
    ``abs(W_out) * output_connectivity_mask`` right-multiplied by
    ``model.dale_out``. Otherwise the raw ``W_out`` values are shown.
    """
    apply_dale = model.dale_ratio
    weights = model.W_out.eval(session=sess)
    if apply_dale:
        masked = abs(weights) * model.output_connectivity_mask
        weights = np.matmul(masked, model.dale_out)
    plt.pcolor(weights)
    plt.colorbar()
    plt.show()
项目:Parser-v1    作者:tdozat    | 项目源码 | 文件源码
def savefigs(self, sess, optimizer=False):
    """Save a heatmap image for every variable in ``self.save_vars``.

    Selects the Agg backend. Each variable is evaluated with ``sess.run`` and
    drawn with ``plt.pcolor`` (RdBu colormap, colour limits [-1, 1], inverted
    y axis, variable name as title). The image goes to
    ``<self.save_dir>/matrices/<name>``, with '/' in the name replaced by '-'.
    1-D values are drawn as a single row. Variables whose name contains
    'Optimizer' are skipped unless ``optimizer`` is true. A ``ValueError``
    while plotting or saving is ignored for that variable.
    """

    import gc
    import matplotlib as mpl
    mpl.use('Agg')
    import matplotlib.pyplot as plt
    matdir = os.path.join(self.save_dir, 'matrices')
    if not os.path.isdir(matdir):
      os.mkdir(matdir)
    for var in self.save_vars:
      if not optimizer and 'Optimizer' in var.name:
        continue
      print(var.name)
      mat = sess.run(var)
      if len(mat.shape) == 1:
        mat = mat[None, :]  # promote to a one-row matrix
      plt.figure()
      try:
        plt.pcolor(mat, cmap='RdBu')
        plt.gca().invert_yaxis()
        plt.colorbar()
        plt.clim(vmin=-1, vmax=1)
        plt.title(var.name)
        plt.savefig(os.path.join(matdir, var.name.replace('/', '-')))
      except ValueError:
        pass  # best effort: skip values that cannot be drawn
      plt.close()
      del mat
      gc.collect()

  #=============================================================
项目:wub    作者:nanoporetech    | 项目源码 | 文件源码
def plot_pcolor(self, data, title="", xlab="", ylab="", xticks=None, yticks=None, invert_yaxis=False, colormap=plt.cm.Blues, tick_size=5, tick_rotation=90):
        """Plot square heatmap of data matrix.

        :param self: object.
        :param data: 2D array to be plotted.
        :param title: Figure title.
        :param xlab: X axis label.
        :param ylab: Y axis label.
        :param xticks: X axis tick labels, one per column; if None the default ticks are kept.
        :param yticks: Y axis tick labels, one per row; if None the default ticks are kept.
        :param invert_yaxis: Invert Y axis if true.
        :param colormap: matplotlib color map.
        :param tick_size: Font size on tick labels.
        :param tick_rotation: Rotation of X tick labels.
        :returns: None
        """
        fig, ax = plt.subplots()
        hm = plt.pcolor(data, cmap=colormap)
        if invert_yaxis:
            ax.invert_yaxis()
        ax.xaxis.tick_top()
        ax.xaxis.set_label_position('top')

        # Place labelled ticks at cell centres only when labels are given.
        # set_*ticklabels(None) raises TypeError, so the None defaults used
        # to crash.
        if xticks is not None:
            ax.set_xticks(np.arange(data.shape[1]) + 0.5, minor=False)
            ax.set_xticklabels(xticks, minor=False, fontsize=tick_size, rotation=tick_rotation)
        if yticks is not None:
            ax.set_yticks(np.arange(data.shape[0]) + 0.5, minor=False)
            ax.set_yticklabels(yticks, minor=False, fontsize=tick_size)
        plt.colorbar(hm)

        self._set_properties_and_close(fig, title, xlab, ylab)
项目:soinn    作者:fukatani    | 项目源码 | 文件源码
def draw_digit(data, n, row, col, title):
    """Draw one 28x28 greyscale digit into subplot ``n`` of a row x col grid.

    ``data`` is reshaped to 28x28 and flipped vertically. pcolor puts row 0
    at the bottom, so the flip makes the first image row appear at the top.
    Tick labels are hidden.
    """
    import matplotlib.pyplot as plt
    size = 28
    plt.subplot(row, col, n)
    Z = data.reshape(size, size)   # convert from vector to size x size matrix
    Z = Z[::-1, :]                 # flip vertical
    plt.xlim(0, size)
    plt.ylim(0, size)
    plt.pcolor(Z)
    plt.title("title=%s" % (title), size=8)
    plt.gray()
    # Booleans, not "off": current matplotlib expects bools here, and the
    # non-empty string "off" is truthy.
    plt.tick_params(labelbottom=False)
    plt.tick_params(labelleft=False)
项目:mplbplot    作者:pieterdavid    | 项目源码 | 文件源码
def pcolor(first, *args, **kwargs):
    """
    Wrapper around matplotlib.pyplot.pcolor that also takes TH2

    A TH2 is drawn with draw_th2.pcolor on the current axes. Anything else is
    forwarded unchanged to matplotlib.pyplot.pcolor.

    see mplbplot.draw_th2.pcolor or matplotlib.pyplot.pcolor for details
    """
    if isinstance(first, gbl.TH2):
        kwargs["axes"] = plt.gca()
        return draw_th2.pcolor(first, *args, **kwargs)
    # Forward the keyword arguments. The old ``**args`` tried to unpack the
    # positional tuple as a mapping and raised TypeError on every call.
    return plt.pcolor(first, *args, **kwargs)
项目:iota    作者:amaneureka    | 项目源码 | 文件源码
def Contour(obj, pcolor=False, contour=True, imshow=False, **options):
    """Makes a contour plot.

    obj: map from (x, y) to z, or object that provides GetDict; (x, y)
        combinations missing from the map are plotted as 0
    pcolor: boolean, whether to make a pseudocolor plot (pyplot.pcolormesh)
    contour: boolean, whether to make a contour plot
    imshow: boolean, whether to use pyplot.imshow
    options: keyword args passed to pyplot.pcolormesh, pyplot.contour and/or
        pyplot.imshow; linewidth=3 and cmap=matplotlib.cm.Blues are added via
        _Underride. imshow does not receive linewidth, and it uses
        origin='lower' unless an origin is given.
    """
    try:
        d = obj.GetDict()
    except AttributeError:
        d = obj

    _Underride(options, linewidth=3, cmap=matplotlib.cm.Blues)

    xs, ys = zip(*d.keys())
    xs = sorted(set(xs))
    ys = sorted(set(ys))

    X, Y = np.meshgrid(xs, ys)
    func = np.vectorize(lambda x, y: d.get((x, y), 0))
    Z = func(X, Y)

    x_formatter = matplotlib.ticker.ScalarFormatter(useOffset=False)
    axes = pyplot.gca()
    axes.xaxis.set_major_formatter(x_formatter)

    if pcolor:
        pyplot.pcolormesh(X, Y, Z, **options)
    if contour:
        cs = pyplot.contour(X, Y, Z, **options)
        pyplot.clabel(cs, inline=1, fontsize=10)
    if imshow:
        # Z[0] is the row for ys[0], the smallest y. With imshow's default
        # origin='upper' that row is drawn at the top, flipping the image
        # against the extent. AxesImage also rejects a 'linewidth' property.
        image_options = dict(options)
        image_options.pop('linewidth', None)
        image_options.setdefault('origin', 'lower')
        extent = xs[0], xs[-1], ys[0], ys[-1]
        pyplot.imshow(Z, extent=extent, **image_options)
项目:xnmt    作者:neulab    | 项目源码 | 文件源码
def plot_attention(src_words, trg_words, attention_matrix, file_name=None):
  """Visualize an attention matrix as a labelled heatmap.

  :param src_words: a list of words in the source, one per matrix row
  :param trg_words: a list of target words, one per matrix column
  :param attention_matrix: a two-dimensional numpy array of values between zero and one,
    where rows correspond to source words, and columns correspond to target words
  :param file_name: if given, the figure is saved there at 100 dpi; otherwise
    it is shown interactively. The figure is closed afterwards.
  """
  fig, ax = plt.subplots()
  n_rows = attention_matrix.shape[0]
  n_cols = attention_matrix.shape[1]
  # Major ticks at the middle of each cell, labelled by word. Source words
  # read top to bottom; target words sit along the top edge.
  ax.set_xticks(np.arange(n_cols) + 0.5, minor=False)
  ax.set_yticks(np.arange(n_rows) + 0.5, minor=False)
  ax.invert_yaxis()
  ax.set_xticklabels(trg_words, minor=False)
  ax.set_yticklabels(src_words, minor=False)
  ax.xaxis.tick_top()

  plt.pcolor(attention_matrix, cmap=plt.cm.Blues, vmin=0, vmax=1)
  plt.colorbar()

  if file_name is None:
    plt.show()
  else:
    plt.savefig(file_name, dpi=100)
  plt.close()
项目:ThinkX    作者:AllenDowney    | 项目源码 | 文件源码
def Contour(obj, pcolor=False, contour=True, imshow=False, **options):
    """Makes a contour plot.

    obj: map from (x, y) to z, or object that provides GetDict; (x, y)
        combinations missing from the map are plotted as 0
    pcolor: boolean, whether to make a pseudocolor plot (plt.pcolormesh)
    contour: boolean, whether to make a contour plot
    imshow: boolean, whether to use plt.imshow
    options: keyword args passed to plt.pcolormesh, plt.contour and/or
        plt.imshow; linewidth=3 and cmap=matplotlib.cm.Blues are added via
        _Underride. imshow does not receive linewidth, and it uses
        origin='lower' unless an origin is given.
    """
    try:
        d = obj.GetDict()
    except AttributeError:
        d = obj

    _Underride(options, linewidth=3, cmap=matplotlib.cm.Blues)

    xs, ys = zip(*d.keys())
    xs = sorted(set(xs))
    ys = sorted(set(ys))

    X, Y = np.meshgrid(xs, ys)
    func = np.vectorize(lambda x, y: d.get((x, y), 0))
    Z = func(X, Y)

    x_formatter = matplotlib.ticker.ScalarFormatter(useOffset=False)
    axes = plt.gca()
    axes.xaxis.set_major_formatter(x_formatter)

    if pcolor:
        plt.pcolormesh(X, Y, Z, **options)
    if contour:
        cs = plt.contour(X, Y, Z, **options)
        plt.clabel(cs, inline=1, fontsize=10)
    if imshow:
        # Z[0] is the row for ys[0], the smallest y. With imshow's default
        # origin='upper' that row is drawn at the top, flipping the image
        # against the extent. AxesImage also rejects a 'linewidth' property.
        image_options = dict(options)
        image_options.pop('linewidth', None)
        image_options.setdefault('origin', 'lower')
        extent = xs[0], xs[-1], ys[0], ys[-1]
        plt.imshow(Z, extent=extent, **image_options)
项目:sv_and_isoforms_from_RNAseq    作者:NCBI-Hackathons    | 项目源码 | 文件源码
def heatmap(hgrams, groups, title=None, fname=None, show=False, 
            order='spearman', srrs=None):
  """Plot histograms as heatmap rows, grouped by ``groups``.

  Each entry of ``hgrams`` becomes one row, zero-padded to the longest
  histogram. Rows are arranged group by group, in the order each group first
  appears in ``groups``, with a striped separator row between consecutive
  groups. Group indices and names are written along the bottom.

  Future versions will include some statistical ordering. Maybe.

  :param hgrams: sequence of 1-D histograms.
  :param groups: group id (int or str) for each histogram, indexed like hgrams.
  :param title: optional figure title.
  :param fname: if given, the figure is saved there.
  :param show: if true, the figure is shown.
  :param order: reserved for the future statistical ordering; currently unused.
  :param srrs: optional row labels, one per histogram.
  """
  rows, cols = len(hgrams), max(len(h) for h in hgrams)
  plt.figure(figsize=(6, 6))

  # Deterministic group order (first appearance). list(set(...)) iterates in
  # arbitrary order, which varies between runs for string ids.
  uni_groups = list(dict.fromkeys(groups))
  ngroups = len(uni_groups)
  separator = [1 if i % 2 == 0 else 0 for i in range(cols)]

  # One row per histogram plus one separator row between consecutive groups.
  # The old code wrote each separator at a stale loop index, overwriting data.
  arr = np.zeros((rows + ngroups - 1, cols))
  row_of = {}  # histogram index -> row of arr
  r = 0
  for gro, group in enumerate(uni_groups):
    if gro > 0:
      arr[r, :] = separator
      r += 1
    for hg, h in enumerate(hgrams):
      if groups[hg] == group:
        arr[r, :len(h)] = h
        row_of[hg] = r
        r += 1

  # Plotting stuff!
  plt.pcolor(arr)
  plt.xlabel('Base number')
  plt.ylabel('Accession number')
  if title is not None:
    plt.title(title)
  for gr in range(ngroups):
    plt.text(gr*20, 0, '%i: %s' % (gr, uni_groups[gr]))
  if srrs is not None:
    # The old code called the nonexistent ``plt.x`` here (AttributeError).
    # NOTE(review): assumes srrs[k] is the accession for hgrams[k] — confirm.
    labelled = [(row_of[k] + 0.5, label) for k, label in enumerate(srrs) if k in row_of]
    plt.yticks([pos for pos, _ in labelled], [label for _, label in labelled])
  if fname is not None:
    plt.savefig(fname)
  if show:
    plt.show()
  return
项目:spatial-reasoning    作者:JannerM    | 项目源码 | 文件源码
def vis_value_map(pred, targ, save_path, title='prediction', share=True):
    """Save side-by-side jet heatmaps of a predicted and a target value map.

    Both arrays are reshaped to ``dim x dim`` with
    ``dim = int(sqrt(pred.size))``. With ``share``, both panels use the joint
    min/max as colour limits. Otherwise each panel is scaled on its own and an
    extra colorbar is added for the prediction. A colorbar for the target is
    always drawn in a strip on the right. The panels share an inverted y
    axis. The figure is closed after saving.
    """
    side = int(math.sqrt(pred.size))
    lo = min(pred.min(), targ.min()) if share else None
    hi = max(pred.max(), targ.max()) if share else None

    plt.clf()
    fig, (ax_pred, ax_targ) = plt.subplots(1, 2, sharey=True)
    pred_mesh = ax_pred.pcolor(pred.reshape(side, side), vmin=lo, vmax=hi, cmap=cm.jet)
    ax_pred.set_title(title, fontsize=5)
    if not share:
        fig.colorbar(pred_mesh)
    targ_mesh = ax_targ.pcolor(targ.reshape(side, side), vmin=lo, vmax=hi, cmap=cm.jet)
    ax_targ.invert_yaxis()
    ax_targ.set_title('target')

    # Reserve the right-hand strip for the colorbar axes.
    fig.subplots_adjust(right=0.8)
    cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
    fig.colorbar(targ_mesh, cax=cbar_ax)

    plt.savefig(save_path, bbox_inches='tight')
    plt.close(fig)
项目:spatial-reasoning    作者:JannerM    | 项目源码 | 文件源码
def vis_fig(data, save_path, title=None, vmax=None, vmin=None, cmap=cm.jet):
    """Save a tick-free square heatmap of ``data``.

    ``data`` is reshaped to ``dim x dim`` with ``dim = int(sqrt(data.size))``
    and drawn with ``plt.pcolor`` using the given colour limits and colormap.
    The y axis is inverted so that row 0 is at the top. The figure is resized
    to 4x4 inches, saved tightly cropped with no padding, and closed.
    """
    side = int(math.sqrt(data.size))

    plt.clf()
    plt.pcolor(data.reshape(side, side), vmin=vmin, vmax=vmax, cmap=cmap)
    plt.xticks([])
    plt.yticks([])

    fig, ax = plt.gcf(), plt.gca()
    if title:
        ax.set_title(title)
    ax.invert_yaxis()
    fig.set_size_inches(4, 4)

    plt.savefig(save_path, bbox_inches='tight', pad_inches=0.0)
    plt.close(fig)
项目:spatial-reasoning    作者:JannerM    | 项目源码 | 文件源码
def visualize_values(mdp, values, policy, filename, title=None):
    """Save a jet heatmap of state values with policy arrows as ``filename.png``.

    The grid is ``m x n``, where ``m``/``n`` are one more than the largest first
    and second components of the states in ``mdp.states``. ``values`` is
    either a dict keyed by ``(i, j)`` (including dict subclasses) or a nested
    sequence indexed ``values[i][j]``. ``policy[(i, j)]`` gives the action
    drawn from each cell via ``arrow``. If it is a tuple, its first entry is
    used; if it is None, no arrow is drawn. The y axis is inverted. The
    figure is not closed.
    """
    states = mdp.states
    plt.clf()
    m = max(states, key=lambda x: x[0])[0] + 1
    n = max(states, key=lambda x: x[1])[1] + 1
    # isinstance so dict subclasses (defaultdict, OrderedDict) are indexed by
    # state too; ``type(values) == dict`` sent them down the [i][j] path.
    values_by_state = isinstance(values, dict)
    data = np.zeros((m, n))
    for i in range(m):
        for j in range(n):
            state = (i, j)
            data[i][j] = values[state] if values_by_state else values[i][j]
            action = policy[state]
            ## if using all_reachable actions, pick the best one
            if isinstance(action, tuple):
                action = action[0]
            if action is not None:
                x, y, w, h = arrow(i, j, action)
                plt.arrow(x, y, w, h, head_length=0.4, head_width=0.4, fc='k', ec='k')
    plt.pcolor(data, cmap=plt.get_cmap('jet'))
    plt.colorbar()
    plt.gca().invert_yaxis()

    if title:
        plt.title(title)
    plt.savefig(filename + '.png')
项目:snn4hrl    作者:florensacc    | 项目源码 | 文件源码
def log_diagnostics(self, paths):
        """Record forward-progress statistics for a batch of rollouts.

        Logs the mean/max/min/std over paths of:
        - observation[-3] at the last step minus at the first step
          (``*ForwardProgress``);
        - the Euclidean norm of the change in observation[-3:-1] between the
          first and last steps (``*ForwardProgress_norm``).
        Also prints the largest absolute value reached by observation
        columns -3 and -2 over all steps of all paths.

        :param paths: list of dicts whose "observations" entry is presumably a
            2-D array (steps x obs_dim); it is indexed as [:, -3:-1] below.
        """
        progs = [
            path["observations"][-1][-3] - path["observations"][0][-3]
            for path in paths
        ]
        logger.record_tabular('AverageForwardProgress', np.mean(progs))
        logger.record_tabular('MaxForwardProgress', np.max(progs))
        logger.record_tabular('MinForwardProgress', np.min(progs))
        logger.record_tabular('StdForwardProgress', np.std(progs))

        progs_norm = [
            np.linalg.norm(path["observations"][-1][-3:-1] - path["observations"][0][-3:-1])
            for path in paths
        ]
        logger.record_tabular('AverageForwardProgress_norm', np.mean(progs_norm))
        logger.record_tabular('MaxForwardProgress_norm', np.max(progs_norm))
        logger.record_tabular('MinForwardProgress_norm', np.min(progs_norm))
        logger.record_tabular('StdForwardProgress_norm', np.std(progs_norm))
        # now we will grid the space and check how much of it the policy is covering

        # Paths can have different lengths, so reduce per path and then across
        # paths. Take abs() *before* the max: the old abs(max(...)) ignored
        # rollouts that went far in the negative direction.
        furthest = np.ceil(np.max([np.max(np.abs(path["observations"][:, -3:-1]))
                                   for path in paths]))
        print('THE FUTHEST IT WENT COMPONENT-WISE IS', furthest)
        furthest = max(furthest, 10)

        # c_grid = furthest * 10 * 2
        # visitation = np.zeros((c_grid, c_grid))  # we assume the furthest it can go is 100, Check it!!
        # for path in paths:
        #     com_x = np.clip(((np.array(path['observations'][:, -3]) + furthest) * 10).astype(int), 0, c_grid - 1)
        #     com_y = np.clip(((np.array(path['observations'][:, -2]) + furthest) * 10).astype(int), 0, c_grid - 1)
        #     coms = zip(com_x, com_y)
        #     for com in coms:
        #         visitation[com] += 1
        #
        # # if you want to have a heatmap of the visitations
        # plt.figure()
        # plt.pcolor(visitation)
        # t = str(int(time.time()))
        # plt.savefig('data/local/visitation_regular_ant_trpo/visitation_map_' + t)
        #
        # total_visitation = np.count_nonzero(visitation)
        # logger.record_tabular('VisitationTotal', total_visitation)
项目:pyts    作者:johannfaouzi    | 项目源码 | 文件源码
def plot_dtw(x, y, dist='absolute', output_file=None):
    """Plot the optimal warping path between two time series.

    Parameters
    ----------
    x : np.array, shape = [n_features]
        first time series

    y : np.array, shape = [n_features]
        second time series

    dist : str or callable (default = 'absolute')
        cost distance between two real numbers. Possible values:

        - 'absolute' : absolute value of the difference
        - 'square' : square of the difference
        - callable : first two parameters must be real numbers
            and it must return a real number.

    output_file : str or None (default = None)
        if str, save the figure.

    Raises
    ------
    ValueError
        if x or y is not a 1-D np.ndarray, if their sizes differ, or if
        dist is not a callable or one of the accepted strings.
    """

    # Check input data
    if not (isinstance(x, np.ndarray) and x.ndim == 1):
        raise ValueError("'x' must be a 1-dimensional np.ndarray.")
    if not (isinstance(y, np.ndarray) and y.ndim == 1):
        raise ValueError("'y' must be a 1-dimensional np.ndarray.")
    if x.size != y.size:
        raise ValueError("'x' and 'y' must have the same size.")

    # Size of x
    x_size = x.size

    # Check parameters
    if not (callable(dist) or dist in ['absolute', 'square']):
        raise ValueError("'dist' must be a callable or 'absolute' or 'square'.")

    _, path = dtw(x, y, dist=dist, return_path=True)

    # One cell per (i, j) index pair: x_size + 1 edges per axis and an
    # x_size-by-x_size colour array. The old (x_size + 1)-square array did not
    # match 'flat' shading. Older matplotlib silently dropped its last row and
    # column; newer versions fall back to 'nearest' shading instead, which
    # adds an empty row/column and shifts the cells by half a unit.
    edges = np.arange(x_size + 1)
    cells = np.zeros([x_size, x_size])
    for i, j in path:
        # NOTE(review): the first index selects the pcolor row, i.e. the
        # vertical axis labelled 'y'. Confirm against dtw's path ordering.
        cells[i, j] = 1

    plt.pcolor(edges, edges, cells, edgecolors='k', cmap='Greys')
    plt.xlabel('x', fontsize=20)
    plt.ylabel('y', fontsize=20)

    if output_file is not None:
        plt.savefig(output_file)
项目:phdplot    作者:gz    | 项目源码 | 文件源码
def heatmap(name, data, title, text):
    """Render ``data`` as an annotated heatmap and save it as ``<name>.png``.

    ``data`` is a 2-D table (it must support ``.shape`` and ``np.asarray``).
    Each cell is annotated with its value to two decimals. Colours use the
    module-level ``colors`` colormap with ``MidpointNormalize(midpoint=1.0)``
    over the range [0.5, 2.5]. Axis labels are fixed to PR/HD/SSSP/SCC.

    :param name: output path without extension.
    :param title: bold figure title, top left.
    :param text: grey subtitle under the title.
    """
    fig, ax = plt.subplots()
    ticks_font = font_manager.FontProperties(family='Decima Mono')
    plt.style.use([os.path.join(sys.path[0], 'ethplot.mplstyle')])
    LEFT = 0.125
    fig.suptitle(title,
             horizontalalignment='left',
             weight='bold', fontsize=20,
             x=LEFT, y=1)
    fig.text(LEFT, 0.92, text,
             horizontalalignment='left',
             weight='medium', fontsize=16, color='#555555')

    labels1 = ['PR','HD','SSSP','SCC']
    labels2 = ['PR','HD','SSSP','SCC']
    # Fix tick positions before naming them so each label lands on its cell.
    ax.set_yticks(np.arange(data.shape[0]) + 0.5)
    ax.set_xticks(np.arange(data.shape[1]) + 0.5)
    ax.set_xticklabels(labels1)
    ax.set_yticklabels(labels2)

    ax.tick_params(pad=11)

    plt.setp(ax.get_xticklabels(), fontproperties=ticks_font)
    plt.setp(ax.get_yticklabels(), fontproperties=ticks_font)

    norm = MidpointNormalize(midpoint=1.0)
    # Put the colour limits on the norm itself. Newer matplotlib rejects a
    # norm passed together with vmin/vmax; older versions set these same
    # attributes on the norm anyway.
    # NOTE(review): assumes MidpointNormalize subclasses matplotlib's Normalize.
    norm.vmin = 0.5
    norm.vmax = 2.5
    c = plt.pcolor(data, cmap = colors, norm=norm)

    # DataFrame.as_matrix() was removed in pandas 1.0; np.asarray accepts
    # DataFrames and plain arrays alike.
    values = np.asarray(data)
    # y walks rows and x walks columns, so non-square tables work too.
    for y in range(values.shape[0]):
        for x in range(values.shape[1]):
            color = 'black'
            plt.text(x + 0.5, y + 0.5, '%.2f' % values[y][x],
                     horizontalalignment='center',
                     verticalalignment='center',
                     color=color,
                     fontproperties=ticks_font)

    colorbar = plt.colorbar(c)
    plt.setp(colorbar.ax.get_yticklabels(), fontproperties=ticks_font)

    plt.savefig(name + ".png", format='png')
    # pad_inches=0.08 here because otherwise it cuts off the numbers...
    #plt.savefig(name + ".pdf", format='pdf', pad_inches=0.08)
项目:phdplot    作者:gz    | 项目源码 | 文件源码
def heatmap(name, data, title, text):
    """Render ``data`` as an annotated greyscale heatmap saved as ``<name>.png``.

    ``data`` is a 2-D table (it must support ``.shape`` and ``np.asarray``).
    Each cell is annotated with its value to two decimals, in white when the
    value exceeds 2.3 and black otherwise. Colours use ``cm.Greys`` over
    [1.0, 2.5]. Axis labels are fixed to PR/HD/SSSP/SCC.

    :param name: output path without extension.
    :param title: bold figure title, top left.
    :param text: grey subtitle under the title.
    """
    fig, ax = plt.subplots()
    ticks_font = font_manager.FontProperties(family='Decima Mono')
    plt.style.use([os.path.join(sys.path[0], 'ethplot.mplstyle')])
    #savefig.pad_inches: 0.08

    LEFT = 0.125
    fig.suptitle(title,
             horizontalalignment='left',
             weight='bold', fontsize=20,
             x=LEFT, y=1)
    fig.text(LEFT, 0.92, text,
             horizontalalignment='left',
             weight='medium', fontsize=16, color='#555555')

    labels1 = ['PR','HD','SSSP','SCC']
    labels2 = ['PR','HD','SSSP','SCC']

    # Fix tick positions before naming them so each label lands on its cell.
    ax.set_yticks(np.arange(data.shape[0]) + 0.5)
    ax.set_xticks(np.arange(data.shape[1]) + 0.5)
    ax.set_xticklabels(labels1)
    ax.set_yticklabels(labels2)

    ax.tick_params(pad=11)

    plt.setp(ax.get_xticklabels(), fontproperties=ticks_font)
    plt.setp(ax.get_yticklabels(), fontproperties=ticks_font)

    c = plt.pcolor(data, cmap = cm.Greys, vmin=1.0, vmax=2.5)

    # DataFrame.as_matrix() was removed in pandas 1.0; np.asarray accepts
    # DataFrames and plain arrays alike.
    values = np.asarray(data)
    # y walks rows and x walks columns, so non-square tables work too.
    for y in range(values.shape[0]):
        for x in range(values.shape[1]):
            color = 'white' if values[y][x] > 2.3 else 'black'
            plt.text(x + 0.5, y + 0.5, '%.2f' % values[y][x],
                     horizontalalignment='center',
                     verticalalignment='center',
                     color=color,
                     fontproperties=ticks_font)

    colorbar = plt.colorbar(c)
    plt.setp(colorbar.ax.get_yticklabels(), fontproperties=ticks_font)

    plt.savefig(name + ".png", format='png')
    # pad_inches=0.08 here because otherwise it cuts off the numbers...
    #plt.savefig(name + ".pdf", format='pdf', pad_inches=0.08)
项目:spatial-reasoning    作者:JannerM    | 项目源码 | 文件源码
def simulate(model, sim_set):
    """Return the mean number of steps taken by policies derived from a model.

    For each batch in ``sim_set``, the model predicts one flat value map per
    goal. Each map is reshaped to ``dim x dim`` (``dim = int(sqrt(size))``)
    and turned into a policy with ``mdps[goal].get_policy``. ``mdp.simulate``
    is then run from every grid cell. The result is the mean of all
    simulated step counts.
    """
    steps_list = []
    for key in tqdm(range(len(sim_set))):
        state_obs, goal_obs, instruct_words, instruct_inds, targets, mdps = sim_set[key]
        state = Variable(torch.Tensor(state_obs).long().cuda())
        objects = Variable(torch.Tensor(goal_obs).long().cuda())
        instructions = Variable(torch.Tensor(instruct_inds).long().cuda())
        targets = torch.Tensor(targets)

        preds = model.forward(state, objects, instructions).data.cpu().numpy()

        # Average over all goals ...
        for goal in range(preds.shape[0]):
            mdp = mdps[goal]
            flat = preds[goal, :]
            dim = int(math.sqrt(flat.size))
            policy = mdp.get_policy(preds[goal, :].reshape(dim, dim))

            # ... and over all start positions.
            for row in range(dim):
                for col in range(dim):
                    steps_list.append(mdp.simulate(policy, (row, col)))
    return np.mean(steps_list)
项目:pymatgen-diffusion    作者:materialsvirtuallab    | 项目源码 | 文件源码
def get_3d_plot(self, figsize=(12, 8), type="distinct"):
        """
        Plot 3D self-part or distinct-part of van Hove function, which is specified
        by the input argument 'type'.

        Args:
            figsize: figure size in inches; tick and label font sizes scale
                with figsize[0].
            type: "distinct" plots ``self.gdrt`` with the colour scale capped
                at 4.0; "self" plots ``self.gsrt`` capped at 1.0. Any other
                value fails the assertion.

        Returns:
            the ``matplotlib.pyplot`` module, with the figure drawn.
        """

        assert type in ["distinct", "self"]

        if type == "distinct":
            grt = self.gdrt.copy()
            vmax = 4.0
            cb_ticks = [0, 1, 2, 3, 4]
            cb_label = "$G_d$($t$,$r$)"
        elif type == "self":
            grt = self.gsrt.copy()
            vmax = 1.0
            cb_ticks = [0, 1]
            # Raw string: '\p' is an invalid escape sequence in a normal
            # literal (SyntaxWarning since Python 3.12). Same value.
            cb_label = r"4$\pi r^2G_s$($t$,$r$)"

        # r values: grid index scaled by interval[-1] / (len(interval) - 1);
        # time values: grid index scaled by timeskip.
        y = np.arange(np.shape(grt)[1]) * self.interval[-1] / float(
            len(self.interval) - 1)
        x = np.arange(
            np.shape(grt)[0]) * self.timeskip
        X, Y = np.meshgrid(x, y, indexing="ij")

        ticksize = int(figsize[0] * 2.5)

        plt.figure(figsize=figsize, facecolor="w")
        plt.xticks(fontsize=ticksize)
        plt.yticks(fontsize=ticksize)

        labelsize = int(figsize[0] * 3)

        plt.pcolor(X, Y, grt, cmap="jet", vmin=grt.min(), vmax=vmax)
        plt.xlabel("Time (ps)", size=labelsize)
        # Raw string: '\A' is an invalid escape sequence. Same value.
        plt.ylabel(r"$r$ ($\AA$)", size=labelsize)
        plt.axis([x.min(), x.max(), y.min(), y.max()])

        cbar = plt.colorbar(ticks=cb_ticks)
        cbar.set_label(label=cb_label, size=labelsize)
        cbar.ax.tick_params(labelsize=ticksize)
        plt.tight_layout()

        return plt