Python matplotlib.cm 模块,rainbow() 实例源码

我们从Python开源项目中,提取了以下29个代码示例,用于说明如何使用matplotlib.cm.rainbow()

项目:WebAppEx    作者:karlafej    | 项目源码 | 文件源码
def get_plot(x, y, k, iris=iris):
    """Cluster the iris data with k-means and return the scatter plot as base64 PNG.

    Args:
        x: column index of ``iris.data`` plotted on the horizontal axis.
        y: column index of ``iris.data`` plotted on the vertical axis.
        k: number of k-means clusters (also the number of rainbow colours).
        iris: dataset object exposing ``data`` and ``feature_names``; defaults
            to the module-level ``iris``.

    Returns:
        The PNG image, base64-encoded, as a ``str``.
    """
    k_means = KMeans(n_clusters=k)
    k_means.fit(iris.data)
    # One evenly spaced rainbow colour per cluster, indexed by cluster label.
    colormap = rainbow(np.linspace(0, 1, k))
    fig = plt.figure()
    try:
        splt = fig.add_subplot(1, 1, 1)
        splt.scatter(iris.data[:, x], iris.data[:, y], c=colormap[k_means.labels_], s=40)
        splt.scatter(k_means.cluster_centers_[:, x], k_means.cluster_centers_[:, y], c='black', marker='x')
        splt.set_xlabel(iris.feature_names[x])
        splt.set_ylabel(iris.feature_names[y])

        figfile = BytesIO()
        # Save this figure explicitly instead of pyplot's implicit "current" figure.
        fig.savefig(figfile, format='png')
    finally:
        # pyplot keeps every figure alive until it is closed; without this,
        # each call (e.g. each web request) leaks one figure.
        plt.close(fig)
    return base64.b64encode(figfile.getvalue()).decode()
项目:LIE    作者:EmbraceLife    | 项目源码 | 文件源码
def plot_scatter(values, cls):
    """Scatter-plot 2-D ``values`` coloured by class label ``cls``.

    Points are drawn in a random order, so no class is systematically drawn
    on top of the others, and at 50% opacity. The palette size comes from the
    module-level ``num_classes``.
    """
    import matplotlib.cm as cm

    # One evenly spaced rainbow colour per class.
    palette = cm.rainbow(np.linspace(0.0, 1.0, num_classes))

    # Random draw order so overlapping classes interleave visually.
    order = np.random.permutation(len(values))
    point_colors = palette[cls[order]]

    xs, ys = values[order, 0], values[order, 1]
    plt.scatter(xs, ys, color=point_colors, alpha=0.5)
    plt.show()


# Plot the transfer-values that have been reduced using PCA. There are 3 different colors for the different classes in the Knifey-Spoony data-set. The colors have very large overlap. This may be because PCA cannot properly separate the transfer-values.

# In[41]:
项目:LIE    作者:EmbraceLife    | 项目源码 | 文件源码
def plot_scatter(values, cls):
    """Scatter-plot 2-D ``values``, colouring each point by its class in ``cls``.

    The palette has ``num_classes`` evenly spaced rainbow colours, where
    ``num_classes`` is read from module scope.
    """
    import matplotlib.cm as cm

    palette = cm.rainbow(np.linspace(0.0, 1.0, num_classes))
    plt.scatter(values[:, 0], values[:, 1], color=palette[cls])
    plt.show()


# Plot the transfer-values that have been reduced using PCA. There are 10 different colors for the different classes in the CIFAR-10 data-set. The colors are grouped together but with very large overlap. This may be because PCA cannot properly separate the transfer-values.

# In[35]:
项目:motion-classification    作者:matthiasplappert    | 项目源码 | 文件源码
def _plot_proto_symbol_space(coordinates, target_names, name, args):
    """Embed ``coordinates`` in 2-D with t-SNE and plot one labelled dot per sample.

    Args:
        coordinates: per-sample feature vectors passed to ``TSNE().fit_transform``.
        target_names: one name per sample, used for annotations and the legend.
        name: file stem for the saved PDF.
        args: options object; when ``args.output_dir`` is not None the plot is
            saved to ``<output_dir>/<name>.pdf``, otherwise it is shown.
    """
    # Reduce to 2D so that we can plot it.
    coordinates_2d = TSNE().fit_transform(coordinates)

    n_samples = coordinates_2d.shape[0]
    x = coordinates_2d[:, 0]
    y = coordinates_2d[:, 1]
    colors = cm.rainbow(np.linspace(0, 1, n_samples))

    fig = plt.figure(1)
    plt.clf()
    ax = fig.add_subplot(111)
    dots = []
    # ``range`` works on Python 2 and 3; ``xrange`` is a NameError on Python 3.
    for idx in range(n_samples):
        dots.append(ax.plot(x[idx], y[idx], "o", c=colors[idx], markersize=15)[0])
        ax.annotate(target_names[idx], xy=(x[idx], y[idx]))
    lgd = ax.legend(dots, target_names, ncol=4, numpoints=1, loc='upper center', bbox_to_anchor=(0.5, -0.1))
    # Pass a real boolean: the documented argument type is bool, not a string like 'on'.
    ax.grid(True)

    if args.output_dir is not None:
        path = os.path.join(args.output_dir, name + '.pdf')
        # The legend sits below the axes, so include it explicitly in the tight bbox.
        fig.savefig(path, bbox_extra_artists=(lgd,), bbox_inches='tight')
        # Report only after the file has actually been written.
        print('Saved plot to file "%s"' % path)
    else:
        plt.show()
项目:Tensorflow-Tutorial    作者:MorvanZhou    | 项目源码 | 文件源码
def plot_with_labels(lowDWeights, labels):
    """Clear the current axes and draw each label as text at its 2-D position.

    Each label ``s`` gets a rainbow background colour looked up with the
    integer index ``int(255 * s / 9)``, so labels 0..9 map to indices 0..255.
    The axis limits are fitted to the data and the plot is shown.
    """
    plt.cla()
    xs = lowDWeights[:, 0]
    ys = lowDWeights[:, 1]
    for px, py, lab in zip(xs, ys, labels):
        shade = cm.rainbow(int(255 * lab / 9))
        plt.text(px, py, lab, backgroundcolor=shade, fontsize=9)
    plt.xlim(xs.min(), xs.max())
    plt.ylim(ys.min(), ys.max())
    plt.title('Visualize last layer')
    plt.show()
    plt.pause(0.01)
项目:DCN    作者:alexnowakvila    | 项目源码 | 文件源码
def plot_accuracies(self, accuracies, scales=None, mode='train', fig=0):
        """Plot one accuracy curve per entry of ``accuracies`` and save it as a PNG.

        Args:
            accuracies: iterable of per-iteration accuracy sequences, one per curve.
            scales: legend labels, one per curve (each passed through ``str``).
                ``None`` (the default) or an empty sequence means no labels.
            mode: ``'train'`` labels the x-axis "iterations"; any other value
                labels it "iterations x 1000". Also embedded in the file name.
            fig: matplotlib figure number to draw into (cleared first).

        The image is written to ``<self.path>/accuracies_<mode>.png``.
        """
        # None default instead of a shared mutable ``[]``.
        scales = [] if scales is None else scales
        accuracies = list(accuracies)
        plt.figure(fig)
        plt.clf()
        # Size the palette so every curve gets a colour even when fewer scales
        # than curves are given (sizing by len(scales) alone raised IndexError).
        colors = cm.rainbow(np.linspace(0, 1, max(len(scales), len(accuracies))))
        names = [str(sc) for sc in scales]
        lines = []
        for color, acc in zip(colors, accuracies):
            line, = plt.plot(range(len(acc)), acc, color=color)
            lines.append(line)
        plt.ylabel('accuracy')
        plt.legend(lines, names, loc=2, prop={'size': 6})
        plt.xlabel('iterations' if mode == 'train' else 'iterations x 1000')
        path = os.path.join(self.path, 'accuracies_{}.png'.format(mode))
        plt.savefig(path)
项目:DCN    作者:alexnowakvila    | 项目源码 | 文件源码
def plot_accuracies(self, accuracies, scales=None, mode='train', fig=0):
        """Plot one accuracy curve per entry of ``accuracies`` and save it as a PNG.

        Args:
            accuracies: iterable of per-iteration accuracy sequences, one per curve.
            scales: legend labels, one per curve (each passed through ``str``).
                ``None`` (the default) or an empty sequence means no labels.
            mode: ``'train'`` labels the x-axis "iterations"; any other value
                labels it "iterations x 1000". Also embedded in the file name.
            fig: matplotlib figure number to draw into (cleared first).

        The image is written to ``<self.path>/accuracies_<mode>.png``.
        """
        # None default instead of a shared mutable ``[]``.
        scales = [] if scales is None else scales
        accuracies = list(accuracies)
        plt.figure(fig)
        plt.clf()
        # Size the palette so every curve gets a colour even when fewer scales
        # than curves are given (sizing by len(scales) alone raised IndexError).
        colors = cm.rainbow(np.linspace(0, 1, max(len(scales), len(accuracies))))
        names = [str(sc) for sc in scales]
        lines = []
        for color, acc in zip(colors, accuracies):
            line, = plt.plot(range(len(acc)), acc, color=color)
            lines.append(line)
        plt.ylabel('accuracy')
        plt.legend(lines, names, loc=2, prop={'size': 6})
        plt.xlabel('iterations' if mode == 'train' else 'iterations x 1000')
        path = os.path.join(self.path, 'accuracies_{}.png'.format(mode))
        plt.savefig(path)
项目:DCN    作者:alexnowakvila    | 项目源码 | 文件源码
def plot_norm_points(self, Inputs_N, e, Perms, scales, fig=1):
        """Plot the first example's points at every scale, coloured by node id.

        Subplot ``i`` shows the points selected by ``Perms[i]``, coloured by
        their node id in ``e`` after ``i`` integer halvings, using
        ``2 ** (scales - i)`` colours. The figure is saved to
        ``<self.path>/visualize_example.png``.

        Args:
            Inputs_N: batch whose ``[0][0]`` is a torch tensor of points
                (assumed (n_points, >=2) — TODO confirm).
            e: torch tensor of node ids; sorted along dim 1, first row used.
            Perms: list of torch tensors; only positive entries are kept and
                they are treated as 1-based point indices.
            scales: number of scales; scale ``i`` uses ``2 ** (scales - i)`` colours.
            fig: matplotlib figure number to draw into (cleared first).
        """
        # Renamed from ``input`` to avoid shadowing the builtin.
        points_all = Inputs_N[0][0].data.cpu().numpy()
        e = torch.sort(e, 1)[0][0].data.cpu().numpy()
        Perms = [perm[0].data.cpu().numpy() for perm in Perms]
        plt.figure(fig)
        plt.clf()
        ee = e.copy()
        for i, perm in enumerate(Perms):
            plt.subplot(1, len(Perms), i + 1)
            n_nodes = 2 ** (scales - i)
            colors = cm.rainbow(np.linspace(0, 1, n_nodes))
            # Keep positive entries and shift them to 0-based indices
            # (presumably 0 marks padding — TODO confirm).
            perm = perm[np.where(perm > 0)[0]] - 1
            points = points_all[perm]
            e_scale = ee[perm]
            # ``range`` works on Python 2 and 3; ``xrange`` is a NameError on Python 3.
            for node in range(n_nodes):
                ind = np.where(e_scale == node)[0]
                pts = points[ind]
                plt.scatter(pts[:, 0], pts[:, 1], c=colors[node])
            # Halve the node ids so they index the next scale's (half as many) nodes.
            ee //= 2
        path = os.path.join(self.path, 'visualize_example.png')
        plt.savefig(path)
项目:ChainConsumer    作者:Samreay    | 项目源码 | 文件源码
def get_colormap(self, num, scale=0.7):  # pragma: no cover
        """Return ``num`` rainbow colours, each rescaled by a position-dependent factor.

        The raw ``cm.rainbow`` colours go through ``self.get_formatted``, then
        each one goes through ``self.scale_colour`` with a factor that is 1.0
        at both ends of the range and ``scale`` in the middle, varying
        linearly in between.
        """
        raw = cm.rainbow(np.linspace(0, 1, num))
        formatted = self.get_formatted(raw)
        # |1 - t| for t in [0, 2]: 1 at the ends, 0 in the middle.
        distance_from_centre = np.abs(1 - np.linspace(0, 2, num))
        factors = scale + (1 - scale) * distance_from_centre
        return [self.scale_colour(colour, factor) for colour, factor in zip(formatted, factors)]
项目:WebAppEx    作者:karlafej    | 项目源码 | 文件源码
def getPlot(self, params):
        """Build a k-means scatter plot of ``self.iris`` and return the Figure.

        ``params`` supplies ``'cluster'`` (k), ``'x_axis'`` and ``'y_axis'``
        (feature column indices); each value is converted with ``int``.
        """
        n_clusters, col_x, col_y = (int(params[key]) for key in ('cluster', 'x_axis', 'y_axis'))
        data = self.iris.data
        model = KMeans(n_clusters=n_clusters)
        model.fit(data)
        palette = rainbow(np.linspace(0, 1, n_clusters))
        centres = model.cluster_centers_

        fig = plt.figure()
        ax = fig.add_subplot(1, 1, 1)
        ax.scatter(data[:, col_x], data[:, col_y], c=palette[model.labels_], s=40)
        ax.scatter(centres[:, col_x], centres[:, col_y], c='black', marker='x')
        ax.set_xlabel(self.iris.feature_names[col_x])
        ax.set_ylabel(self.iris.feature_names[col_y])
        return fig
项目:SecuML    作者:ANSSI-FR    | 项目源码 | 文件源码
def colors(num):
    """Return ``num`` evenly spaced rainbow colours as hex strings.

    Returns a list, so the result can be indexed and iterated more than once.
    On Python 3, ``map`` would return a one-shot iterator.
    """
    return [rgb2hex(rgba) for rgba in cm.rainbow(np.linspace(0, 1, num))]
项目:accpy    作者:kramerfelix    | 项目源码 | 文件源码
def trackplot(ax, data, turns=False, xy=False, fs=[16, 9], showlost=False,
              everyxturn=[0, 1], ms=1):
    """Plot each particle's track on ``ax`` in its own rainbow colour.

    Args:
        ax: axes to draw on.
        data: mapping with ``'Particles'``, ``'allIDs'``, ``'lostIDs'`` and the
            arrays named by ``xy``, indexed ``[row, particle]``.
        turns, fs: accepted but not used here.
        xy: pair of keys into ``data`` for the horizontal/vertical coordinates.
        showlost: when False, ``delete(ids, data['lostIDs'])`` is applied first
            (positional removal if ``delete`` is ``numpy.delete``).
        everyxturn: ``(start, step)`` used to subsample rows.
        ms: marker size.

    Both coordinates are multiplied by 1e3 before plotting.
    """
    xkey, ykey = xy
    palette = rainbow(linspace(0, 1, data['Particles'][0]))
    ids = data['allIDs'].copy()
    if not showlost:
        ids = delete(ids, data['lostIDs'])
    for pid, colour in zip(ids, palette):
        start, step = everyxturn
        xs = data[xkey][start::step, pid]
        ys = data[ykey][start::step, pid]
        ax.plot(xs*1e3, ys*1e3, '.', color=colour, ms=ms)
    ax.set_xlabel(r'$x$ / (mm)')
    ax.set_ylabel(r'$x^\prime$ / (mrad)')
项目:accpy    作者:kramerfelix    | 项目源码 | 文件源码
def showbun(datadict):
    """Draw a 3x3 grid of per-particle scatter panels for a bunch.

    The first six panels plot one array against its row index (x-label
    ``'Pass'``). The last three plot one array against another. Each particle
    gets its own rainbow colour.
    """
    horizontal = ['', '', '', '', '', '', 'x', 'y', 't']
    vertical = ['x', 'y', 't', 'xp', 'yp', 'p', 'xp', 'yp', 'p']
    palette = cm.rainbow(linspace(0, 1, datadict['Particles'][0]))
    for panel, (hkey, vkey) in enumerate(zip(horizontal, vertical), start=1):
        subplot(3, 3, panel)
        for part, colour in enumerate(palette):
            if hkey:
                plot(datadict[hkey][part, ], datadict[vkey][part, ], '.', color=colour)
            else:
                plot(datadict[vkey][part, ], '.', color=colour)
        xlabel(hkey if hkey else 'Pass')
        ylabel(vkey)
    tight_layout()
项目:ml-deepranking    作者:urakozz    | 项目源码 | 文件源码
def plot_with_labels(low_dim_embs, labels, filename='tsne_c5s5r5.png'):
    """Scatter-plot 2-D embeddings, colour and annotate each point by its label, and save.

    Args:
        low_dim_embs: array whose row ``i`` holds the (x, y) position for
            ``labels[i]``. Extra rows beyond ``len(labels)`` are ignored.
        labels: one label per plotted point. Each distinct label is mapped to
            its rank in ``np.unique`` order to pick a colour. For integer
            labels ``0..k-1`` that rank equals the label itself.
        filename: output image path.
    """
    assert low_dim_embs.shape[0] >= len(labels), "More labels than embeddings"
    unique_labels = np.unique(np.array(labels))
    colors = cm.rainbow(np.linspace(0, 1, len(unique_labels)))
    # Map label -> colour index. Indexing ``colors[label]`` directly failed for
    # string labels and for integer labels with gaps (e.g. [0, 2]).
    color_index = {label: idx for idx, label in enumerate(unique_labels.tolist())}
    plt.figure(figsize=(18, 18))  # in inches
    for i, label in enumerate(labels):
        x, y = low_dim_embs[i, :]
        plt.scatter(x, y, color=colors[color_index[label]])
        plt.annotate(label,
                     xy=(x, y),
                     xytext=(5, 2),
                     textcoords='offset points',
                     ha='right',
                     va='bottom')

    plt.savefig(filename)
项目:dlsd    作者:ahartens    | 项目源码 | 文件源码
def create_colors_array_of_length(self, length):
        """Store ``length`` evenly spaced rainbow RGBA colours in ``self.colors``."""
        sample_points = np.linspace(0, 1, length)
        self.colors = cm.rainbow(sample_points)
项目:act-rte-inference    作者:DeNeutoy    | 项目源码 | 文件源码
def single_mean_with_variance(stats, title, save=False):
    """Plot mean ACT steps per epoch for the training and validation sets with ±1 std bands.

    ``stats`` must provide ``"epoch"``, ``"train_step_mean"``,
    ``"train_step_var"``, ``"val_step_mean"``, ``"val_step_var"`` and
    ``"config"``, whose first entry has a ``step_penalty`` attribute. 1.0 is
    added to every mean before plotting. When ``save`` is true the figure is
    written to ``mean_and_variance<step_penalty>.png``; otherwise it is shown.
    """
    series = [
        ("Training Set", "train_step_mean", "train_step_var"),
        ("Validation Set", "val_step_mean", "val_step_var"),
    ]
    plt.figure()
    palette = cm.rainbow(np.linspace(0, 1, 2))
    handles = []

    for (label, mean_key, var_key), colour in zip(series, palette):
        # NOTE(review): the +1.0 offset is unexplained here — presumably steps are counted from zero.
        means = np.array(stats[mean_key]) + 1.0
        # Renamed from ``vars`` to avoid shadowing the builtin.
        spread = np.sqrt(np.array(stats[var_key]))

        line, = plt.plot(stats["epoch"], means, label=label, color=colour)
        handles.append(line)
        plt.fill_between(stats["epoch"], means - spread, means + spread, alpha=0.3, color=colour)

    plt.legend(handles=handles)

    penalty = str(stats["config"][0].step_penalty)
    ax = plt.gca()
    ax.set_title(title + ": Step Penalty = " + penalty)
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Mean ACT Steps")
    ax.set_ylim([1, 6])
    if save:
        plt.savefig("mean_and_variance" + penalty + ".png")
    else:
        plt.show()
项目:act-rte-inference    作者:DeNeutoy    | 项目源码 | 文件源码
def mean_average_steps(all_loaded_stats, title):
    """Plot all runs lightly, coloured by step penalty, and each penalty's mean in bold.

    Args:
        all_loaded_stats: sequence of run dicts, each with ``"config"`` (whose
            first entry has ``step_penalty``), ``"epoch"`` and ``"val_step_mean"``.
        title: axes title.

    Step penalties are sorted to assign colours, so they must be mutually
    comparable (presumably numbers — TODO confirm).
    """
    # Sorted so each penalty gets the same colour on every call (set order is arbitrary).
    step_params = sorted(set(run["config"][0].step_penalty for run in all_loaded_stats))
    colours = cm.rainbow(np.linspace(0, 1, len(step_params)))
    colour_of = dict(zip(step_params, colours))
    plt.figure()
    mean_dict = defaultdict(list)
    epochs_of = {}

    # plot all runs lightly and accumulate runs per step_penalty parameter
    for run in all_loaded_stats:
        penalty = run["config"][0].step_penalty
        data = np.array(run["val_step_mean"]) + 1.0  # TODO: fix this
        plt.plot(run["epoch"], data, color=colour_of[penalty], alpha=0.2)
        mean_dict[penalty].append(data)
        epochs_of.setdefault(penalty, run["epoch"])

    # now plot the mean values in bold
    for penalty, runs in mean_dict.items():
        data = np.vstack(runs).mean(0)
        # Use this group's own epochs. The old code read an undefined global
        # ``loaded_stats`` here and raised NameError.
        plt.plot(epochs_of[penalty], data, color=colour_of[penalty], alpha=1.0, linewidth=2.0)

    ax = plt.gca()
    ax.set_title(title)
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Mean ACT steps")
    plt.show()
项目:hyperspectral-framework    作者:mihailoobrenovic    | 项目源码 | 文件源码
def vis_data(data, classes, n_classes=17):
    """Embed ``data`` in 2-D with t-SNE and scatter-plot it, coloured by class.

    Args:
        data: samples passed to ``TSNE(...).fit_transform``.
        classes: integer class id per sample (array-like; flattened).
        n_classes: number of classes drawn, ids ``0 .. n_classes - 1``.
            Defaults to 17, the previously hard-coded value.
    """
    X_embedded = TSNE(n_components=2, perplexity=40, verbose=2).fit_transform(data)
    plt.figure()
    colors = cm.rainbow(np.linspace(0, 1, n_classes))
    class_ids = np.asarray(classes).ravel()
    for i in range(n_classes):
        # A flat boolean mask selects this class's rows as 1-D arrays.
        mask = class_ids == i
        plt.scatter(X_embedded[mask, 0], X_embedded[mask, 1], color=colors[i], marker='x', label=i)
    plt.legend()

# Raw data
项目:Python-Data-Analysis-Learning-Notes    作者:Asurada2015    | 项目源码 | 文件源码
def plot_with_labels(lowDWeights, labels):
    """Redraw the current axes with every label written as text at its 2-D point.

    The text background is ``cm.rainbow(int(255 * label / 9))``: an integer
    lookup index, so labels 0..9 span indices 0..255. Axis limits are fitted
    to the data, then the plot is shown.
    """
    plt.cla()
    X = lowDWeights[:, 0]
    Y = lowDWeights[:, 1]
    for point_x, point_y, label in zip(X, Y, labels):
        plt.text(point_x, point_y, label,
                 backgroundcolor=cm.rainbow(int(255 * label / 9)),
                 fontsize=9)
    plt.xlim(X.min(), X.max())
    plt.ylim(Y.min(), Y.max())
    plt.title('Visualize last layer')
    plt.show()
    plt.pause(0.01)
项目:TWEleReceipt    作者:hsucw    | 项目源码 | 文件源码
def plot3D(data, taxid_list):
    """Draw a 3-D scatter of receipts, one rainbow colour per tax id.

    Args:
        data: mapping from tax id to an ``(x, y, z)`` triple of arrays. ``x``
            is plotted as ``np.log(x)``.
        taxid_list: tax ids to plot; colours are assigned in this order.

    Each tax id and its colour are printed as they are plotted.
    """
    colors = cm.rainbow(np.linspace(0, 1, len(taxid_list)))

    fig = plt.figure(figsize=(12, 9))
    ax = fig.add_subplot(111, projection='3d')
    fig.suptitle("NCTU receipts")

    for taxid, color in zip(taxid_list, colors):
        x, y, z = data[taxid]
        # Function-call form: the old print *statement* is a SyntaxError on
        # Python 3. With a single argument the output is identical on Python 2.
        print("{} {}".format(taxid, color))
        ax.scatter(np.log(x), y, z, c=color, marker='o',
                   s=20, alpha=0.2, edgecolors='none')

    ax.set_xlabel('num')
    ax.set_ylabel('med')
    ax.set_zlabel('ratio')
    # Show z tick labels as comma-grouped integers.
    ax.zaxis.set_major_formatter(plt.FuncFormatter(
        lambda x, loc: "{:,}".format(int(x))))
    plt.show()
项目:medium_posts    作者:leapingllamas    | 项目源码 | 文件源码
def show_kmeans(points, centers=None):
    """Scatter-plot clustered points and, if given, the cluster centres.

    Args:
        points: iterable of dicts with ``'coords'`` (x, y) and ``'c'``, the
            integer cluster index that selects the point's rainbow colour.
        centers: optional sequence of dicts with ``'coords'``; each is drawn
            as a large yellow star. The palette has
            ``max(len(centers), largest 'c' + 1)`` colours, so it also works
            when ``centers`` is omitted (previously ``len(None)`` raised TypeError).
    """
    # http://stackoverflow.com/questions/9401658/matplotlib-animating-a-scatter-plot
    points = list(points)
    n_colors = max(len(centers) if centers else 0,
                   max([p['c'] for p in points] or [-1]) + 1)
    colors = list(cm.rainbow(np.linspace(0, 1, n_colors)))

    xs = [p['coords'][0] for p in points]
    ys = [p['coords'][1] for p in points]
    c = [colors[p['c']] for p in points]
    wts = [3] * len(points)
    m = ['o'] * len(points)

    if centers:
        for cl in centers:
            xs.append(cl['coords'][0])
            ys.append(cl['coords'][1])
            c.append('yellow')
            wts.append(500)
            m.append('*')

    for _s, _c, _x, _y, _sz in zip(m, c, xs, ys, wts):
        pyplot.scatter(_x, _y, marker=_s, c=_c, s=_sz, lw=0)

    pyplot.show()
项目:DCN    作者:alexnowakvila    | 项目源码 | 文件源码
def plot_classes(self, points, clusters, e, fig=0):
        """Scatter the first example's points, coloured by assigned cluster, and save the plot.

        ``e[0]`` is a torch tensor of per-point cluster ids. ``points[0]``
        holds the matching points; the first two columns are plotted. Ids
        ``0 .. clusters - 1`` are drawn. The image is written to
        ``<self.path>/clustering_ex.png``.
        """
        assignments = e[0].data.cpu().numpy()
        example = points[0]
        plt.figure(fig)
        plt.clf()
        palette = cm.rainbow(np.linspace(0, 1, clusters))
        for cluster_id, colour in enumerate(palette):
            members = np.where(assignments == cluster_id)[0]
            cluster_points = example[members]
            plt.scatter(cluster_points[:, 0], cluster_points[:, 1], c=colour)
        plt.title('clustering')
        # The file name has no placeholder, so it does not depend on ``clusters``.
        plt.savefig(os.path.join(self.path, 'clustering_ex.png'))
项目:DCN    作者:alexnowakvila    | 项目源码 | 文件源码
def plot_example(self, x, y, clusters, length,
                 path='/home/anowak/DynamicProgramming/DP/plots/example.png'):
        """Scatter points ``x`` coloured by cluster label ``y`` into figure 0 and save it.

        Args:
            x: array of points; the first two columns are plotted.
            y: array of cluster labels, one per point; labels
                ``0 .. clusters - 1`` are drawn.
            clusters: number of clusters (and of rainbow colours).
            length: accepted but not used.
            path: output image path. Defaults to the previously hard-coded
                absolute path, so existing callers are unaffected.
        """
        plt.figure(0)
        plt.clf()
        colors = cm.rainbow(np.linspace(0, 1, clusters))
        for c in range(clusters):
            ind = np.where(y == c)[0]
            plt.scatter(x[ind, 0], x[ind, 1], c=colors[c])
        plt.savefig(path)
项目:simulator    作者:P2PSP    | 项目源码 | 文件源码
def draw_buffer(self):
        """Create the "Buffer Status" figure and initialise its plotting state.

        Creates a figure/axes pair and two placeholder marker-only line artists
        (IN in black, OUT in grey) with a legend. It also assigns one rainbow
        colour per peer, fixes the axis limits and x ticks to the peer count
        and buffer size, resets the ``buffer_order``/``buffer_index``
        bookkeeping, and draws the canvas once.
        """
        self.buffer_figure, self.buffer_ax = plt.subplots()
        # Two-point placeholder data; animated=True keeps matplotlib's normal
        # redraw from drawing these artists (presumably they are updated by
        # blitting elsewhere — TODO confirm).
        self.lineIN, = self.buffer_ax.plot([1]*2, [1]*2, color='#000000', ls="None", label="IN", marker='o', animated=True)
        self.lineOUT, = self.buffer_ax.plot([1]*2, [1]*2, color='#CCCCCC', ls="None", label="OUT", marker='o', animated=True)
        self.buffer_figure.suptitle("Buffer Status", size=16)
        plt.legend(loc=2, numpoints=1)
        # One colour per peer of every kind: monitors, ordinary peers and malicious peers.
        total_peers = self.number_of_monitors + self.number_of_peers + self.number_of_malicious
        self.buffer_colors = cm.rainbow(np.linspace(0, 1, total_peers))
        # x spans 0..total_peers+1 (one slot per peer); y spans the buffer size.
        plt.axis([0, total_peers+1, 0, self.get_buffer_size()])
        plt.xticks(range(0, total_peers+1, 1))
        self.buffer_order = {}
        self.buffer_index = 1
        # Tick values captured as a mutable list (presumably relabelled per peer later — TODO confirm).
        self.buffer_labels = self.buffer_ax.get_xticks().tolist()
        plt.grid()
        self.buffer_figure.canvas.draw()
项目:tf-example-models    作者:aakhundov    | 项目源码 | 文件源码
def plot_clustered_data(points, c_means, c_assignments, n_clusters=None):
    """Plots the cluster-colored data and the cluster means.

    Args:
        points: 2-D array of points; rows are selected per cluster and the
            first two columns are plotted.
        c_means: array of cluster means indexed ``[cluster, coordinate]``.
        c_assignments: array of cluster indices, one per row of ``points``.
        n_clusters: number of clusters to draw. Defaults to the module-level
            ``CLUSTERS`` constant that was previously hard-coded.
    """
    if n_clusters is None:
        n_clusters = CLUSTERS
    colors = cm.rainbow(np.linspace(0, 1, n_clusters))

    for cluster, color in enumerate(colors):
        c_points = points[c_assignments == cluster]
        # Points underneath (zorder 0); the cluster mean on top in black (zorder 1).
        plt.plot(c_points[:, 0], c_points[:, 1], ".", color=color, zorder=0)
        plt.plot(c_means[cluster, 0], c_means[cluster, 1], ".", color="black", zorder=1)

    plt.show()


# PREPARING DATA

# generating DATA_POINTS points from a GMM with CLUSTERS components
项目:accpy    作者:kramerfelix    | 项目源码 | 文件源码
def tuneplot(ax1, ax2, data, particleIDs='allIDs', integer=1, addsub=add,
             clipint=True, showlost=False, QQ='Qx', ms=1, clip=[0], showfit=False):
    """Plot x/x' phase space coloured by tune (``ax1``) and tune versus action (``ax2``).

    Args:
        ax1: axes for the phase-space plot (x and x' multiplied by 1e3; mm / mrad).
        ax2: axes for fractional tune versus action (action multiplied by 1e6).
        data: mapping providing ``data[particleIDs]``, ``'lost'``, ``data[QQ]``,
            ``'x'``, ``'xp'`` and ``'A'``.
        particleIDs: key into ``data`` selecting the particle IDs to plot.
        integer, addsub: the tune is ``addsub(integer, data[QQ][ids])``.
        clipint: drop particles whose tune is exactly 0.0, 0.5 or 1.0.
        showlost: when False, ``clip`` plus the lost IDs are removed from the
            selection with ``delete``. When True, lost particles are also
            drawn in gray on ``ax1``.
        QQ: key of the tune array in ``data``.
        ms: marker size on ``ax1`` (``ms + 1`` on ``ax2``).
        clip: extra entries passed to ``delete`` together with the lost IDs.
        showfit: also draw the linear tune-shift-with-action fit on ``ax2``.

    The fitted slope, scaled by ``1e-6*1250``, is printed.
    """
    particleIDs = data[particleIDs]
    # Read unconditionally: the ``showlost`` branch below also iterates over it
    # (previously it was only defined when showlost was False).
    lost = data['lost'][:, 0]
    if not showlost:
        clip = concatenate([clip, lost])
        particleIDs = delete(particleIDs, clip)
    Q = addsub(integer, data[QQ][particleIDs])
    if clipint:
        zeroQ = find(logical_or(logical_or(Q == 0.0, Q == 1.0), Q == 0.5))
        if len(zeroQ) > 0:  # trim reference particle with zero tune
            Q = delete(Q, zeroQ)
            particleIDs = delete(particleIDs, zeroQ)
    Qmin, Qmax = nanmin(Q), nanmax(Q)
    if Qmax - Qmin == 0.0:
        # Widen a degenerate range symmetrically so the colour normalisation
        # does not divide by zero. abs() keeps the range ordered for negative
        # tunes, and the fallback covers Qmin == 0. This matches the old
        # behaviour for positive tunes.
        pad = abs(Qmin)/1e4 or 1e-4
        Qmin -= pad
        Qmax += pad
    Qdif = Qmax - Qmin
    colors = cool((Q - Qmin) / Qdif)
    for i, ID in enumerate(particleIDs):
        ax1.plot(data['x'][:, ID]*1e3, data['xp'][:, ID]*1e3, '.', c=colors[i], ms=ms)
    if showlost:
        for ID in lost:
            ax1.plot(data['x'][:, ID]*1e3, data['xp'][:, ID]*1e3, '.', c='gray', ms=ms)
    ax1.set_xlabel(r'Position $x$ / (mm)')
    ax1.set_ylabel(r'Angle $x^\prime$ / (mrad)')
    emittance = data['A'][particleIDs]/pi
    action = emittance/2

    # tune shift with action: linear model Q = a + b * action
    def fitfun(x, a, b):
        return a + b*x

    popt, pcov = curve_fit(fitfun, action, Q)
    perr = sqrt(diag(pcov))
    print(popt[1]*1e-6*1250)

    for act, q, colour in zip(action, Q, colors):
        ax2.plot(act*1e6, q, 'o', c=colour, ms=ms + 1)
    if showfit:
        action2 = linspace(nanmin(action), nanmax(action), 1000)
        fit1 = fitfun(action2, *popt)
        ax2.plot(action2*1e6, fit1, '-k', lw=1, label=r'fit with $TSWA=${:.4}$\pm${:.1} (kHz mm$^-$$^2$mrad$^-$$^2$)'.format(popt[1]*1e-6*1250, perr[1]*1e-6*1250))
    ax2.set_ylim([Qmin, Qmax])
    ax2.set_ylabel(r'Fractional Tune $dQ$')
    ax2.set_xlabel(r'Action $J_x$ / (mm$\cdot$mrad)')
    tight_layout()
项目:accpy    作者:kramerfelix    | 项目源码 | 文件源码
def trackplot(datadict, turns=False, xy=False, fs=[16, 9], ax=False):
    """Plot particle tracks or watch-point data.

    With ``xy`` (a pair of keys): plot ``datadict[x]`` against ``datadict[y]``
    for every particle on ``ax``. A new figure and axes are created when
    ``ax`` is falsy. Only the first ``turns`` rows are used when ``turns`` is
    truthy. Particles whose plot call raises are collected and printed.

    Without ``xy``: draw a 3x3 grid with pyplot's ``subplot``. Centroid
    watch-point keys (``'Pass'``, ``'Cx'``, ...) are tried first; if that
    raises, per-particle coordinate keys (``'t'``, ``'x'``, ...) are used.
    """
    if not ax:
        fig = figure(figsize=fs)
    if xy:
        if not ax:
            ax = fig.add_subplot(111)
        x, y = xy
        colors = cm.rainbow(linspace(0, 1, datadict['Particles'][0]))
        if turns:
            for part, col in enumerate(colors):
                ax.plot(datadict[x][:turns, part], datadict[y][:turns, part], '.', color=col)
        else:
            # Collect across all particles. The list used to be reset inside
            # the loop, so at most the last particle was ever reported.
            missing = []
            for part, col in enumerate(colors):
                try:
                    ax.plot(datadict[x][:, part], datadict[y][:, part], '.', color=col)
                except Exception:
                    missing.append(part)
            print('missing particles: ', missing)
        ax.set_xlabel(x)
        ax.set_ylabel(y)
        return
    # Row selection shared by both layouts: first ``turns`` rows, or all rows.
    rows = slice(None, turns) if turns else slice(None)
    try:  # centroid watch point
        xkeys = ['Pass', 'Pass', 'Pass', 'Pass', 'Pass', 'Pass', 'Cx', 'Cy', 'dCt']
        ykeys = ['Cx', 'Cy', 'dCt', 'Cxp', 'Cyp', 'Cdelta', 'Cxp', 'Cyp', 'Cdelta']
        for i, (xk, yk) in enumerate(zip(xkeys, ykeys)):
            subplot(3, 3, i+1)
            plot(datadict[xk][rows], datadict[yk][rows], '.')
            xlabel(xk)
            ylabel(yk)
    except Exception:  # coordinate watch point (no longer swallows KeyboardInterrupt)
        xkeys = ['t', 't', 't', 't', 't', 't', 'x', 'y', 'dt']
        ykeys = ['x', 'y', 'dt', 'xp', 'yp', 'p', 'xp', 'yp', 'p']
        colors = cm.rainbow(linspace(0, 1, datadict['Particles'][0]))
        for i, (xk, yk) in enumerate(zip(xkeys, ykeys)):
            subplot(3, 3, i+1)
            for part, col in enumerate(colors):
                if xk == 't':
                    plot(datadict[yk][rows, part], '.', color=col)
                else:
                    plot(datadict[xk][rows, part], datadict[yk][rows, part], '.', color=col)
            xlabel('Pass' if xk == 't' else xk)
            ylabel(yk)
        # One layout pass after all panels exist (was repeated in every iteration).
        tight_layout()
项目:song-embeddings    作者:brad-ross-35    | 项目源码 | 文件源码
def plot_embedding(embed, labels, plot_type='t-sne', title="", tsne_params=None, save_path=None,
                   legend=True, label_dict=None, label_order=None, legend_outside=False, alpha=0.7):
    """
    Projects embedding onto two dimensions, colors according to given label
    @param embed:          embedding matrix
    @param labels:         array of labels for the rows of embed
    @param plot_type:      'pca' or 't-sne'
    @param title:          title of plot
    @param tsne_params:    keyword arguments forwarded to TSNE (None means none)
    @param save_path:      path of where to save (None: do not save)
    @param legend:         bool to show legend
    @param label_dict:     dict that maps labels to real names (eg. {0:'rock', 1:'edm'});
                           when None the labels themselves are used as names
    @param label_order:    optional list of real names fixing the colour/legend order
    @param legend_outside: put the legend outside the axes (forced when there are >= 10 labels)
    @param alpha:          marker opacity
    @raises ValueError:    if plot_type is neither 'pca' nor 't-sne'
    """
    if plot_type not in ('pca', 't-sne'):
        # Previously an unknown type failed later with a NameError on comp1.
        raise ValueError("plot_type must be 'pca' or 't-sne', got %r" % (plot_type,))
    if tsne_params is None:
        tsne_params = {}

    def display_name(lab):
        return lab if label_dict is None else label_dict[lab]

    plt.figure()
    unique_labels = list(set(labels))
    N = len(unique_labels)
    colors = cm.rainbow(np.linspace(0, 1, N))
    scaled_embed = scale(embed)

    if plot_type == 'pca':
        pca = PCA(n_components=2)
        pca.fit(scaled_embed)
        # note: will take a while if embedding is large
        comp1, comp2 = pca.components_
        # NOTE(review): the axes are fitted on the scaled embedding but the raw
        # ``embed`` is projected — confirm this is intended.
        comp1, comp2 = embed.dot(comp1), embed.dot(comp2)
    else:
        tsne = TSNE(**tsne_params)
        comp1, comp2 = tsne.fit_transform(scaled_embed).T

    if label_order is not None:
        unique_labels = sorted(unique_labels, key=lambda lab: label_order.index(display_name(lab)))
    # label -> row indices with that label, built in one pass instead of one scan per label
    rows_by_label = {}
    for j, lab in enumerate(labels):
        rows_by_label.setdefault(lab, []).append(j)
    for lab, color in zip(unique_labels, colors):
        rows = np.array(rows_by_label[lab])
        plt.scatter(comp1[rows], comp2[rows],
                    color=color, label=display_name(lab), alpha=alpha)

    plt.title(title)
    # Empty unless a legend is drawn. Previously ``lgd`` was undefined when
    # legend=False and save_path was given.
    extra_artists = ()
    if legend:
        if N >= 10 or legend_outside:
            lgd = plt.legend(bbox_to_anchor=(1.01, 1), loc='upper left')
        else:
            lgd = plt.legend(loc='best')
        extra_artists = (lgd,)
    if save_path is not None:
        plt.savefig(save_path, bbox_extra_artists=extra_artists, bbox_inches='tight')
项目:Smart-Meter-Experiment-ML-Revisited    作者:felgueres    | 项目源码 | 文件源码
def plot_behavior_cluster(centroids, num_clusters):
    '''
    Draw every cluster centroid as a coloured line on one set of axes and show it.

    Parameters
    ----------

    centroids : array
        Predicted centroids; ``centroids[i]`` is drawn as the curve labelled
        "Cluster i+1".

    num_clusters : int
        Number of clusters, and therefore of rainbow colours and curves.

    Returns
    -------

    None
        The figure is displayed with ``plt.show()``; nothing is returned.

    '''
    fig = plt.figure(figsize=(10, 7))
    ax = fig.add_subplot(1, 1, 1)

    palette = cm.rainbow(np.linspace(0, 1, num_clusters))
    for idx, colour in enumerate(palette):
        ax.plot(centroids[idx], c=colour, label="Cluster %d" % (idx + 1))

    ax.set_title("Centroids of consumption pattern of clusters, where k = %d" % num_clusters,
                 fontsize=14, fontweight='bold')
    ax.set_xlim([0, 24])
    ax.set_xticks(range(0, 25, 6))
    ax.set_xlabel("Time (h)")
    ax.set_ylabel("Consumption (kWh)")
    legend = plt.legend(frameon=True, loc='upper left', ncol=2, fontsize=12)
    legend.get_frame().set_edgecolor('b')

    plt.show()