Python matplotlib.pyplot 模块,rc() 实例源码

我们从 Python 开源项目中提取了以下 50 个代码示例,用于说明如何使用 matplotlib.pyplot.rc()。

项目:abcpy    作者:eth-cscs    | 项目源码 | 文件源码
def plot(samples, path = None, true_value = 5, title = 'ABC posterior'):
    """Plot a Gaussian KDE of the first column of ``samples``.

    The density curve is drawn together with two vertical markers: one at
    ``true_value`` and one at the column-wise mean of ``samples``.

    :param samples: 2-D array; only column 0 is used for the density.
    :param path: if not ``None``, the figure is saved to this path.
    :param true_value: x position of the "true value" marker.
    :param title: figure title.
    :return: the ``matplotlib.pyplot`` module, as in the original API.
    """
    # rc settings only influence artists created afterwards, so they must be
    # applied before the figure, axes and labels are built.
    plt.rc('axes', labelsize=15)
    plt.rc('font', size=15)

    first_column = samples[:, 0].reshape(samples.shape[0],)
    Bayes_estimate = np.mean(samples, axis=0)
    theta = true_value

    # Evaluate the KDE once on an ascending grid spanning the sample range.
    xmin, xmax = np.min(first_column), np.max(first_column)
    positions = np.linspace(xmin, xmax, samples.shape[0])
    gaussian_kernel = gaussian_kde(first_column)
    values = gaussian_kernel(positions)

    # Vertical markers extend 10% above the density peak.
    y_low = min(values)
    y_high = max(values) + .1 * (max(values) - min(values))

    plt.figure()
    plt.plot(positions, values, label='posterior density')
    plt.plot([theta, theta], [y_low, y_high], label='true value')
    plt.plot([Bayes_estimate, Bayes_estimate], [y_low, y_high],
             label='Bayes estimate')
    plt.ylim([y_low, y_high])
    plt.xlabel(r'$\theta$')
    plt.ylabel('density')
    # Every line now carries a label, so the legend is no longer empty.
    plt.legend(loc='best', frameon=False, numpoints=1)
    plt.title(title)
    if path is not None:
        plt.savefig(path)
    return plt
项目:activity-browser    作者:LCA-ActivityBrowser    | 项目源码 | 文件源码
def __init__(self, parent, mlca, width=6, height=6, dpi=100):
        """Build a stacked horizontal bar chart of the top process contributions.

        :param parent: widget passed to ``setParent``.
        :param mlca: object providing ``top_process_contributions``.
        :param width: figure width in inches.
        :param height: figure height in inches.
        :param dpi: figure resolution.
        """
        figure = Figure(figsize=(width, height), dpi=dpi, tight_layout=True)
        axes = figure.add_subplot(121)

        super(LCAProcessContributionPlot, self).__init__(figure)
        self.setParent(parent)

        method = 0  # TODO let user choose the LCIA method
        tc = mlca.top_process_contributions(method=method, limit=5, relative=True)
        df_tc = pd.DataFrame(tc)
        df_tc.columns = [format_activity_label(a) for a in tc.keys()]
        df_tc.index = [format_activity_label(a, style='pl') for a in df_tc.index]

        # The legend font size is read from rcParams when a legend is created,
        # so it has to be set before plotting rather than afterwards.
        plt.rc('legend', **{'fontsize': 8})
        plot = df_tc.T.plot.barh(
            stacked=True,
            # Honour the caller's size instead of resetting the figure to 6x6.
            figsize=(width, height),
            cmap=plt.cm.nipy_spectral_r,
            ax=axes
        )
        plot.tick_params(labelsize=8)
        axes.legend(loc='center left', bbox_to_anchor=(1, 0.5))
        self.setMinimumSize(self.size())
项目:activity-browser    作者:LCA-ActivityBrowser    | 项目源码 | 文件源码
def __init__(self, parent, mlca, width=6, height=6, dpi=100):
        """Build a stacked horizontal bar chart of the top elementary-flow contributions.

        :param parent: widget passed to ``setParent``.
        :param mlca: object providing ``top_elementary_flow_contributions``.
        :param width: figure width in inches.
        :param height: figure height in inches.
        :param dpi: figure resolution.
        """
        figure = Figure(figsize=(width, height), dpi=dpi, tight_layout=True)
        axes = figure.add_subplot(121)

        super(LCAElementaryFlowContributionPlot, self).__init__(figure)
        self.setParent(parent)

        method = 0  # TODO let user choose the LCIA method
        tc = mlca.top_elementary_flow_contributions(method=method, limit=5, relative=True)
        df_tc = pd.DataFrame(tc)
        df_tc.columns = [format_activity_label(a) for a in tc.keys()]
        df_tc.index = [format_activity_label(a, style='bio') for a in df_tc.index]

        # The legend font size is read from rcParams when a legend is created,
        # so it has to be set before plotting rather than afterwards.
        plt.rc('legend', **{'fontsize': 8})
        plot = df_tc.T.plot.barh(
            stacked=True,
            # Honour the caller's size instead of resetting the figure to 6x6.
            figsize=(width, height),
            cmap=plt.cm.nipy_spectral_r,
            ax=axes
        )
        plot.tick_params(labelsize=8)
        axes.legend(loc='center left', bbox_to_anchor=(1, 0.5))
        self.setMinimumSize(self.size())
项目:NuGridPy    作者:NuGrid    | 项目源码 | 文件源码
def lifetime(self,label=""):
        '''
            Plot the logarithmic stellar lifetime till the first TP
            against the initial stellar mass.

            label : legend label attached to the plotted line.

            The model index of the first TP for each run comes from
            self.set_find_first_TP(); the runs themselves are read from
            self.run_historydata.
        '''
        plt.rcParams.update({'font.size': 20})
        plt.rc('xtick', labelsize=20)
        plt.rc('ytick', labelsize=20)
        t0_model=self.set_find_first_TP()
        m=self.run_historydata
        age=[]
        mass=[]
        for i, case in enumerate(m):
            ##lifetime till first TP
            age.append(np.log10(case.get("star_age")[t0_model[i]]))
            mass.append(case.get("star_mass")[0])
        plt.plot(mass,age,"*",markersize=10,linestyle="-",label=label)
        # Raw string: "\o" is not a valid escape sequence in a normal literal.
        plt.xlabel(r'star mass $[M_{\odot}]$',fontsize=20)
        plt.ylabel('Logarithmic stellar lifetime',fontsize=20)
项目:scipyplot    作者:robertocalandra    | 项目源码 | 文件源码
def niceFigure(useLatex=True):
    """Apply global matplotlib defaults for outward, 1-wide major ticks.

    :param useLatex: when exactly ``True``, enable LaTeX text rendering and
        load ``amsmath`` in the LaTeX preamble.
    :return: always ``0``.
    """
    from matplotlib import rcParams
    import matplotlib.pyplot as plt
    # rcParams.update({'figure.autolayout': True})
    if useLatex is True:
        plt.rc('text', usetex=True)
        # The preamble is a plain string; newer matplotlib releases reject
        # the list form that older releases tolerated.
        plt.rcParams['text.latex.preamble'] = r"\usepackage{amsmath}"
    rcParams['xtick.direction'] = 'out'
    rcParams['ytick.direction'] = 'out'
    rcParams['xtick.major.width'] = 1
    rcParams['ytick.major.width'] = 1
    return 0
项目:xdesign    作者:tomography    | 项目源码 | 文件源码
def plot_mtf(faxis, MTF, labels=None):
    """Plots the MTF. Returns the figure reference.

    :param faxis: x values shared by every curve.
    :param MTF: 2-D array; each row is drawn as one curve against ``faxis``.
    :param labels: optional iterable; its items are stringified and used as
        legend entries.
    """
    fig_lineplot = plt.figure()
    # Must be set before the first plot call creates the axes.
    plt.rc('axes', prop_cycle=PLOT_STYLES)

    for row in MTF:
        plt.plot(faxis, row)

    plt.xlabel('spatial frequency [cycles/length]')
    plt.ylabel('Radial MTF')
    plt.gca().set_ylim([0, 1])

    if labels is not None:
        plt.legend([str(n) for n in labels])
    # Title spelling corrected ("Tansfer" -> "Transfer").
    plt.title("Modulation Transfer Function for various angles")

    return fig_lineplot
项目:OTC3D    作者:tiffanyts    | 项目源码 | 文件源码
def scatter3d(self,title='',size=40,model=[]):
        """Return a figure with a 3-D scatter plot of ``self.data``.

        Points are placed at (x, y, z) and coloured by ``v``. If ``model``
        can be converted to a vertex list, its vertices are overlaid as a
        wireframe; a ``TypeError`` during that step is ignored.
        """
        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')
        ax.pbaspect = [1, 1, 1] #always need pbaspect
        ax.set_title(title)

        xs = list(self.data.x)
        ys = list(self.data.y)
        zs = list(self.data.z)
        values = list(self.data.v)
        points = ax.scatter(xs, ys, zs, c=values, edgecolors='none',
                            s=size, marker=",", cmap='jet')
        ax.view_init(elev=90, azim=-89)
        ax.set_xlabel('X axis')
        ax.set_ylabel('Y axis')
        ax.set_zlabel('Z axis')
        fig.colorbar(points)

        plt.draw()
        try:
            fetch = pyliburo.py3dmodel.fetch
            topology = fetch.topos_frm_compound(model)
            vertex_points = fetch.vertex_list_2_point_list(topology["vertex"])
            vertices = [(vertex.X(), vertex.Y(), vertex.Z()) for vertex in vertex_points]
            V1, V2, V3 = zip(*vertices)
            ax.plot_wireframe(V1, V2, V3)
        except TypeError:
            # No usable model geometry: keep the plain scatter plot.
            pass

        return fig
项目:route-plotter    作者:perimosocordiae    | 项目源码 | 文件源码
def _setup_figure(bg_img, bg_extent, scale=1.0):
  """Create a frameless figure sized to ``bg_img`` and draw the image on it.

  Returns ``(fig, ax)`` where ``ax`` fills the whole figure, has its axis
  decorations switched off, and has autoscaling disabled.
  """
  # turn off tight_layout
  plt.rc('figure', autolayout=False)
  resolution = plt.rcParams.get('figure.dpi', 100.0)
  fig = plt.figure(dpi=resolution, frameon=False)

  # scale the figure to fit the bg image
  rows, cols = bg_img.shape[:2]
  width_in = cols / resolution * scale
  height_in = rows / resolution * scale
  fig.set_size_inches(width_in, height_in)

  ax = fig.add_axes([0, 0, 1, 1])
  ax.set_axis_off()
  for axis in (ax.xaxis, ax.yaxis):
    axis.set_major_locator(plt.NullLocator())
  ax.imshow(bg_img, zorder=0, extent=bg_extent, cmap='Greys_r', aspect='auto')
  ax.autoscale(False)
  ax.margins(0, 0)
  return fig, ax
项目:Efficient-Dynamic-Batching    作者:jsuarez5341    | 项目源码 | 文件源码
def prettyPlot(samps, dat, hid):
   """Show a log-scaled image annotated with the values of ``dat``.

   ``samps`` supplies the tick labels and the grid size; ``dat[i, j]`` is
   written (truncated to 4 characters) at position (i, j). ``hid`` is unused.
   """
   fig, ax = plt.subplots()
   sz = 18
   for tick_group in ('xtick', 'ytick'):
      plt.rc(tick_group, labelsize=sz)

   ax.set_xticklabels([1]+samps, fontsize=sz)
   ax.set_yticklabels([1]+samps[::-1], fontsize=sz)

   for axis in (ax.xaxis, ax.yaxis):
      axis.set_major_locator(ticker.MultipleLocator(1))

   label_size = sz+2
   ax.set_xlabel('Number of Experts', fontsize=label_size)
   ax.set_ylabel('Minibatch Size', fontsize=label_size)
   ax.set_title('MOE Cell Speedup Factor', fontsize=sz+4)

   #Show cell values
   count = len(samps)
   for col in range(count):
      for row in range(count):
         cell_text = str(dat[col,row])[:4]
         ax.text(col, row, cell_text, ha='center', va='center', fontsize=sz, color='white')

   # NOTE(review): `cellTimes` is not defined in this function; presumably a
   # module-level array (or it was meant to be `dat`) - confirm.
   log_norm = colors.LogNorm(vmin=cellTimes.min(), vmax=cellTimes.max())
   plt.imshow(cellTimes, cmap='viridis', norm=log_norm)
   plt.show()
项目:kernel-gof    作者:wittawatj    | 项目源码 | 文件源码
def set_default_matplotlib_options():
    """Set global matplotlib defaults: serif LaTeX text, large font, thick lines.

    Also sets the PDF/PS font type to 42.
    """
    # font options
    matplotlib.rc('font', family='serif', serif=['Computer Modern'])

    # LaTeX text rendering (set through both rc() and rcParams, as before)
    matplotlib.rc('text', usetex=True)
    matplotlib.rcParams['text.usetex'] = True

    plt.rc('font', size=30)
    plt.rc('lines', linewidth=3, markersize=10)

    matplotlib.rcParams.update({
        'pdf.fonttype': 42,
        'ps.fonttype': 42,
    })
项目:KATE    作者:hugochan    | 项目源码 | 文件源码
def plot_tsne(doc_codes, doc_labels, classes_to_visual, save_file):
    """Embed document codes in 2-D with t-SNE and plot one marker style per class.

    ``doc_codes`` / ``doc_labels`` are either two dicts keyed by document
    (only documents whose label is in ``classes_to_visual`` are kept) or two
    aligned sequences. The figure is saved to ``save_file`` as EPS and shown.
    """
    markers = ["o", "v", "8", "s", "p", "*", "h", "H", "+", "x", "D"]
    plt.rc('legend',**{'fontsize':30})
    classes_to_visual = list(set(classes_to_visual))
    C = len(classes_to_visual)
    # Repeat the marker list until every class has a marker.
    while C > len(markers):
        markers += markers

    class_ids = dict(zip(classes_to_visual, range(C)))

    both_dicts = isinstance(doc_codes, dict) and isinstance(doc_labels, dict)
    if both_dicts:
        selected = [(code, doc_labels[doc]) for doc, code in doc_codes.items() if doc_labels[doc] in classes_to_visual]
        codes, labels = zip(*selected)
    else:
        codes, labels = doc_codes, doc_labels

    X = np.r_[list(codes)]
    tsne = TSNE(perplexity=30, n_components=2, init='pca', n_iter=5000)
    np.set_printoptions(suppress=True)
    X = tsne.fit_transform(X)

    plt.figure(figsize=(10, 10), facecolor='white')

    label_array = np.array(labels)
    for c in classes_to_visual:
        idx = label_array == c
        marker = markers[class_ids[c]]
        plt.plot(X[idx, 0], X[idx, 1], linestyle='None', alpha=1, marker=marker,
                        markersize=10, label=c)
    plt.legend(loc='upper right', shadow=True)
    plt.savefig(save_file, format='eps', dpi=2000)
    plt.show()
项目:KATE    作者:hugochan    | 项目源码 | 文件源码
def plot_tsne_3d(doc_codes, doc_labels, classes_to_visual, save_file, maker_size=None, opaque=None):
    """Embed document codes in 3-D with t-SNE and scatter-plot them by class.

    ``maker_size`` and ``opaque`` are optional per-class sequences of marker
    sizes and alpha values; when falsy, 20 and 1 are used. The figure is
    saved to ``save_file`` and shown.
    """
    markers = ["D", "p", "*", "s", "d", "8", "^", "H", "v", ">", "<", "h", "|"]
    plt.rc('legend',**{'fontsize':20})
    colors = ['r', 'b', 'g', 'c', 'm', 'y', 'k']
    C = len(classes_to_visual)
    # Repeat the style lists until they cover every class.
    while C > len(markers):
        markers += markers
    while C > len(colors):
        colors += colors

    class_ids = dict(zip(classes_to_visual, range(C)))

    both_dicts = isinstance(doc_codes, dict) and isinstance(doc_labels, dict)
    if both_dicts:
        selected = [(code, doc_labels[doc]) for doc, code in doc_codes.items() if doc_labels[doc] in classes_to_visual]
        codes, labels = zip(*selected)
    else:
        codes, labels = doc_codes, doc_labels

    X = np.r_[list(codes)]
    tsne = TSNE(perplexity=30, n_components=3, init='pca', n_iter=5000)
    np.set_printoptions(suppress=True)
    X = tsne.fit_transform(X)

    fig = plt.figure(figsize=(10, 10), facecolor='white')
    ax = fig.add_subplot(111, projection='3d')

    # The legend function does not support the artist returned by a 3D scatter,
    # so a 2D proxy line with the same look is created for each class.
    label_array = np.array(labels)
    scatter_proxy = []
    for i in range(C):
        cls = classes_to_visual[i]
        idx = label_array == cls
        alpha = opaque[i] if opaque else 1
        point_size = maker_size[i] if maker_size else 20
        ax.scatter(X[idx, 0], X[idx, 1], X[idx, 2], c=colors[i], alpha=alpha, s=point_size, marker=markers[i], label=cls)
        proxy = mpl.lines.Line2D([0],[0], linestyle="none", c=colors[i], marker=markers[i], label=cls)
        scatter_proxy.append(proxy)
    ax.legend(scatter_proxy, classes_to_visual, numpoints=1)
    plt.savefig(save_file)
    plt.show()
项目:OASIS    作者:j-friedrich    | 项目源码 | 文件源码
def init_fig():
    """Change some global matplotlib defaults for plotting."""
    # (rc group, settings) pairs, applied in this order.
    rc_groups = (
        ('figure', {'facecolor': 'white', 'dpi': 90, 'frameon': False}),
        ('font', {'size': 30, 'family': 'sans-serif', 'sans-serif': ['Computer Modern']}),
        ('lines', {'lw': 2}),
        ('text', {'usetex': True}),
        ('legend', {'fontsize': 24, 'frameon': False, 'labelspacing': .3, 'handletextpad': .3}),
        ('axes', {'linewidth': 2}),
        ('xtick.major', {'size': 10, 'width': 1.5}),
        ('ytick.major', {'size': 10, 'width': 1.5}),
    )
    for group, settings in rc_groups:
        plt.rc(group, **settings)
项目:image_recognition    作者:tue-robotics    | 项目源码 | 文件源码
def plot_false_positive_true_positive_rates(labels, classifications_ground_truth_as_score_matrix,
                                            classifications_scores):
    """
    Plot the false positive and true positive rates per label against the threshold
    :param labels: Input labels
    :param classifications_ground_truth_as_score_matrix: Zero score matrix with ones on the ground truth
    :param classifications_scores: The classification scores per label
    """

    # setup consistent colors + markers for each label
    colors_per_label = dict(zip(labels, [plt.get_cmap('gist_rainbow')(i) for i in np.linspace(0, 1, len(labels))]))
    markers_per_label = dict(zip(labels, itertools.cycle([' ', 'o', 'x'])))

    plt.figure()
    plt.rc("axes", labelsize=15)

    # Compute the ROC curve for each class; column i of both matrices belongs to labels[i]
    false_positive_rate_per_label = {}
    true_positive_rate_per_label = {}
    unknown_thresholds_per_label = {}
    for i, label in enumerate(labels):
        fpr, tpr, thresholds = roc_curve(classifications_ground_truth_as_score_matrix[:, i],
                                         classifications_scores[:, i])
        false_positive_rate_per_label[label] = fpr
        true_positive_rate_per_label[label] = tpr
        unknown_thresholds_per_label[label] = thresholds

    for label in labels:
        plt.plot(unknown_thresholds_per_label[label], true_positive_rate_per_label[label],
                 color=colors_per_label[label], marker=markers_per_label[label], label='Tpr {}'.format(label))
        plt.plot(unknown_thresholds_per_label[label], false_positive_rate_per_label[label],
                 color=colors_per_label[label], marker=markers_per_label[label], linestyle='dashed',
                 label='Fpr {}'.format(label))

    plt.legend()
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    # The true positive rate is Tp / (Tp + Fn); the previous label showed
    # Tp / (Tp + Fp), which is precision.
    plt.ylabel(r'True  $\frac{T_p(t)}{T_p(t) + F_n(t)}$ & False $\frac{F_p(t)}{F_p(t) + T_n(t)}$ Positive Rate: ')
    plt.xlabel('Threshold ($t$)')
    plt.title('Threshold vs. True & False positive rate')
项目:scipyplot    作者:robertocalandra    | 项目源码 | 文件源码
def niceFigure():
    """Enable automatic figure layout and outward-pointing ticks globally."""
    rcParams.update({
        'figure.autolayout': True,
        'xtick.direction': 'out',
        'ytick.direction': 'out',
    })
项目:QDREN    作者:andreamad8    | 项目源码 | 文件源码
def plot_dist(train_y,dev_y,test_y):
    """Plot normalised histograms of the three label sets, one panel each.

    The panels are stacked vertically (training, validation, test), saved to
    ``checkpoints/label_dist.pdf`` and then shown.
    """
    import seaborn as sns
    import matplotlib.pyplot as plt
    plt.rc('text', usetex=True)
    plt.rc('font', family='Times-Roman')
    sns.set_style(style='white')
    color = sns.color_palette("Set2", 10)
    fig = plt.figure(figsize=(8,12))

    # (data, legend label, bar colour) for each panel, top to bottom.
    panels = (
        (train_y, 'Training', "blue"),
        (dev_y, 'Validation', "green"),
        (test_y, 'Test', "red"),
    )
    for position, (answers, name, shade) in enumerate(panels, start=1):
        axis = fig.add_subplot(3, 1, position)
        sns.distplot(answers,kde=False,label=name, hist=True, norm_hist=True,color=shade)
        axis.set_xlabel("Answer")
        axis.set_ylabel("Frequency")
        axis.set_xlim([0,500])
        plt.legend(loc='best')

    plt.savefig('checkpoints/label_dist.pdf', format='pdf', dpi=300)

    plt.show()
项目:Project101    作者:Wonjuseo    | 项目源码 | 文件源码
def plotting(epochs,val_acc):
    """Plot ``val_acc`` over ``range(epochs)``, save it as PNG and show it.

    :param epochs: number of epochs; x values are ``range(epochs)``.
    :param val_acc: sequence of y values, one per epoch.
    """
    plt.rc('font',family='serif')
    fig = plt.figure()
    plt.plot(range(epochs),val_acc,label='acc',color='black')
    # Save before show(): once the window opened by show() is closed the
    # figure is gone and a later savefig() would write an empty image.
    fig.savefig('mnist_BiRNN.png')
    plt.show()
项目:BayesVP    作者:cameronliang    | 项目源码 | 文件源码
def plot_model_comparison(self,redshift,dv,central_wave=None):
        """
        Plot best fit model onto spectrum for visual inspection.

        :param redshift: redshift used to shift ``config_param.wave`` by 1/(1+z).
        :param dv: half-width of the plotted velocity window [km/s].
        :param central_wave: wavelength mapped to zero velocity; defaults to
            ``config_param.transitions_params_array[0][0][0][1]``.

        The figure is written to ``<processed_product_path>/modelspec_<chain_short_fname>.pdf``.
        """
        c = 299792.458 # speed of light [km/s]

        if central_wave is None:
            # Use the first transition as the central wavelength
            central_wave = self.config_param.transitions_params_array[0][0][0][1]
        else:
            central_wave = float(central_wave)

        obs_spec_wave = self.config_param.wave / (1+redshift)
        obs_spec_dv = c*(obs_spec_wave - central_wave) / central_wave
        plt.rc('text', usetex=True)

        plt.figure(1)
        plt.step(obs_spec_dv,self.config_param.flux,'k',label=r'$\rm Data$')
        plt.step(obs_spec_dv,self.model_flux,'b',lw=2,label=r'$\rm Best\,Fit$')
        plt.step(obs_spec_dv,self.config_param.dflux,'r')
        plt.axhline(1,ls='--',c='g',lw=1.2)
        plt.axhline(0,ls='--',c='g',lw=1.2)
        plt.ylim([-0.1,1.4])
        plt.xlim([-dv,dv])
        plt.xlabel(r'$dv\,[\rm km/s]$')
        plt.ylabel(r'$\rm Normalized\,Flux$')
        plt.legend(loc=3)

        output_name = self.config_param.processed_product_path + '/modelspec_' + self.config_param.chain_short_fname + '.pdf'
        plt.savefig(output_name,bbox_inches='tight',dpi=100)
        # Clear figure 1 so the next call starts from an empty canvas.
        plt.clf()
        print('Written %s' % output_name)
项目:PonyGE2    作者:PonyGE    | 项目源码 | 文件源码
def save_box_plot(data, names, title):
    """
    Given an array of some data, and a list of names of that data, generate
    and save a box plot of that data.

    :param data: An array of some data to be plotted.
    :param names: A list of names of that data.
    :param title: The title of the plot.
    :return: Nothing
    """

    from algorithm.parameters import params

    import matplotlib.pyplot as plt
    plt.rc('font', family='Times New Roman')

    # Set up the figure.
    fig = plt.figure()
    ax1 = fig.add_subplot(1, 1, 1)

    # Plot the data (notched boxes).
    ax1.boxplot(np.transpose(data), notch=True)

    # Plot title.
    plt.title(title)

    # Box positions are 1-based: 1 .. len(data).
    nums = list(range(1, len(data) + 1))

    # Plot names for each data point.
    plt.xticks(nums, names, rotation='vertical', fontsize=8)

    # Fit the layout only now, so that the title and the vertical tick
    # labels are taken into account and not clipped.
    plt.tight_layout()

    # Save plot.
    plt.savefig(path.join(params['FILE_PATH'], (title + '.pdf')))

    # Close plot.
    plt.close()
项目:GASP-python    作者:henniggroup    | 项目源码 | 文件源码
def get_system_size_plot(self):
        """
        Returns a plot of the system size versus the number of energy
        calculations, as a matplotlib plot object.

        Data rows are read from ``self.lines`` starting at index 4; column 1
        holds the composition and column 4 the number of calculations.
        """

        # set the font to Times, rendered with Latex
        plt.rc('font', **{'family': 'serif', 'serif': ['Times']})
        plt.rc('text', usetex=True)

        # parse the compositions and numbers of energy calculations
        compositions = []
        num_calcs = []
        for raw_line in self.lines[4:]:
            line = raw_line.split()
            compositions.append(line[1])
            num_calcs.append(int(line[4]))

        # get the numbers of atoms from the compositions
        nums_atoms = [Composition(composition).num_atoms
                      for composition in compositions]

        # make the plot
        plt.plot(num_calcs, nums_atoms, 'D', markersize=5,
                 markeredgecolor='blue', markerfacecolor='blue')
        plt.xlabel(r'Number of energy calculations', fontsize=22)
        plt.ylabel(r'Number of atoms in the cell', fontsize=22)
        plt.tick_params(which='both', width=1, labelsize=18)
        plt.tick_params(which='major', length=8)
        plt.tick_params(which='minor', length=4)
        # `left` / `bottom` replace the `xmin` / `ymin` keywords, which newer
        # matplotlib releases no longer accept.
        plt.xlim(left=0)
        plt.ylim(bottom=0)
        plt.tight_layout()
        return plt
项目:python-mrcz    作者:em-MRCZ    | 项目源码 | 文件源码
def plotFSC( self ):
        """Plot the Fourier shell correlation curves stored in ``self.star``.

        A horizontal line at 0.143 is always drawn. Each of the four FSC
        columns is drawn only if it can be read and plotted; a column that
        fails is skipped. Finally the resolution and B-factor from
        ``self.star['data_general']`` are printed.
        """
        plt.rc('lines', linewidth=2.0, markersize=12.0 )
        plt.figure()

        fsc = self.star['data_fsc']
        resolution = fsc['Resolution']
        plt.plot( resolution, 0.143*np.ones_like(resolution),
                 '-', color='firebrick', label="Resolution criteria" )

        # (column, format string, colour or None, legend label)
        optional_curves = (
            ('FourierShellCorrelationUnmaskedMaps', 'k.-', None, "Unmasked FSC"),
            ('FourierShellCorrelationMaskedMaps', '.-', 'royalblue', "Masked FSC"),
            ('FourierShellCorrelationCorrected', '.-', 'forestgreen', "Corrected FSC"),
            ('CorrectedFourierShellCorrelationPhaseRandomizedMaskedMaps', '.-',
             'goldenrod', "Random-phase corrected FSC"),
        )
        for column, fmt, color, label in optional_curves:
            style = {'label': label}
            if color is not None:
                style['color'] = color
            try:
                plt.plot( resolution, fsc[column], fmt, **style )
            except Exception:
                # Best effort: the container type of self.star is not visible
                # here, so any lookup/plot failure skips just this curve.
                # Unlike a bare except, KeyboardInterrupt/SystemExit propagate.
                continue

        # Raw string: "\A" is not a valid escape sequence in a normal literal.
        plt.xlabel( r"Resolution ($\AA^{-1}$)" )
        plt.ylabel( "Fourier Shell Correlation" )
        plt.legend( loc='upper right', fontsize=16 )
        plt.xlim( np.min(resolution), np.max(resolution) )
        print( "Final resolution (unmasked): %.2f A"%self.star['data_general']['FinalResolution']  )
        print( "B-factor applied: %.1f"%self.star['data_general']['BfactorUsedForSharpening'] )
项目:PyBGMM    作者:junlulocky    | 项目源码 | 文件源码
def plot_variance_cuve():
        """Plot the gdir variance against ``b`` next to the symmetric Dirichlet variance.

        Uses K = 5 and a = alpha = 0.3 with 1000 values of ``b`` in [0, 20],
        saves PNG and PDF copies under ``<this dir>/res_gdir/res_variance``
        and shows the figure.
        """
        ################# plot for variance ##########################
        K = 5
        a = 0.3
        alpha = a
        all_num = 1000
        b = np.linspace(0, 20, num=all_num)

        # Constant reference curve: same value repeated for every b.
        symmetric_value = compute_symmetric_dir_variance(K, alpha)
        symmetric_dir_var = [symmetric_value] * all_num

        gdir_var = []
        for local_b in b:
            gdir_var.append(compute_gdir_variance(K, a, local_b))

        save_path = os.path.dirname(__file__) + '/res_gdir/res_variance'

        plt.figure(1)

        for tick_group in ('xtick', 'ytick'):
            plt.rc(tick_group, labelsize=20)
        plt.tick_params(axis='both', which='major', labelsize=20)
        plt.plot(b, gdir_var)
        plt.plot(b, symmetric_dir_var)
        # Vertical marker at b == a, from 0 up to 1/K.
        plt.plot((a, a), (0, 1./K), 'k-')

        png_path = save_path + '/gdir_K{}_a{}.png'.format(K, a)
        pdf_path = save_path + '/gdir_K{}_a{}.pdf'.format(K, a)
        plt.savefig(png_path)
        plt.savefig(pdf_path)

        plt.show()
项目:OTC3D    作者:tiffanyts    | 项目源码 | 文件源码
def contour(self,title='',cbartitle = '',model=[], zmax = None, zmin = None, filename = None, resolution = 1, unit_str = '', bar = True):
        """Return a figure with a contour plot of the 2D spatial data.

        title      : figure title.
        cbartitle  : label of the colour bar (only used when ``bar`` is true).
        model      : geometry whose vertices are overlaid as a white patch;
                     a TypeError while extracting them is ignored. The list
                     default is never mutated.
        zmax, zmin : colour scale limits passed to ``pcolormesh``.
        filename   : when not None, the figure is saved to this path.
        resolution : multiplier for the number of grid samples per axis
                     (<1 to decrease, >1 to increase detail).
        unit_str   : unused.
        bar        : whether to draw a colour bar.

        The figure is always returned.
        """

        font = {'weight' : 'medium',
                'size'   : 22}

        # np.linspace needs an integer sample count, but ``resolution`` may be
        # fractional, so the product is truncated to int.
        nx = int(len(set(self.data.x))*resolution)
        ny = int(len(set(self.data.y))*resolution)
        xi = np.linspace(min(self.data.x), max(self.data.x), nx)
        yi = np.linspace(min(self.data.y), max(self.data.y), ny)

        # NOTE(review): `ml.griddata` looks like matplotlib.mlab.griddata,
        # which newer matplotlib releases no longer provide - confirm.
        zi = ml.griddata(self.data.x, self.data.y, self.data.v.interpolate(), xi, yi,interp='linear')

        fig = plt.figure()
        plt.rc('font', **font)
        plt.title(title)
        plt.contour(xi, yi, zi, 15, linewidths = 0, cmap=plt.cm.bone)
        plt.pcolormesh(xi, yi, zi, cmap = plt.get_cmap('rainbow'),vmax = zmax, vmin = zmin)
        if bar:
            cbar = plt.colorbar()
            cbar.ax.set_ylabel(cbartitle)

        try:
            vertices = [(vertex.X(), vertex.Y()) for vertex in pyliburo.py3dmodel.fetch.vertex_list_2_point_list(pyliburo.py3dmodel.fetch.topos_frm_compound(model)["vertex"])]
            shape = patches.PathPatch(Path(vertices), facecolor='white', lw=0)
            plt.gca().add_patch(shape)
        except TypeError:
            # No usable model geometry: draw the data only.
            pass

        # Save before show(): after the window opened by show() is closed the
        # saved image may be empty.
        if filename is not None:
            fig.savefig(filename)
        plt.show()

        return fig

#    def plot_along_line(self,X,Y, tick_list):
#        V = self.data.v
#        plt.plot(heights, SVFs_can, label='Canyon')
项目:spyking-circus-ort    作者:spyking-circus    | 项目源码 | 文件源码
def plot(self):
        """Enable LaTeX serif text, then draw the spatial, temporal and waveform plots."""
        rc_settings = (
            ('text', {'usetex': True}),
            ('font', {'family': 'serif'}),
        )
        for group, options in rc_settings:
            plt.rc(group, **options)
        self.plot_spatial_configuration()
        # Temporal configuration over the window [0.0, 1.0].
        self.plot_temporal_configuration(t_start=0.0, t_end=1.0)
        self.plot_waveforms()
项目:swn-gen    作者:jimmayjr    | 项目源码 | 文件源码
def stats():
    """Run ``statsgen`` once per grouper method in parallel and plot the results.

    Each worker puts one item on a shared queue; item[0] is used as the
    grouper index (subplot position and label) and item[1] is shown as an
    image. The 3x3 grid is saved to ``groupingStats.png``.
    """
    grouperLabels = ['Random',
                     'Min Dist Stars',
                     'Max Dist Stars',
                     '1/4 Min Dist Stars',
                     '1/3 Min Dist Stars',
                     '1/2 Min Dist Stars',
                     'Link Most Isolated Group',
                     'Link Smallest Group',
                     'Link Largest Group']
    numGroupers = len(grouperLabels)
    # Queue for returning counts
    q = mp.Queue()
    # Create processes
    pList = list()
    for gType in range(numGroupers):
        p = mp.Process(target=statsgen,args=(q,gType))
        pList.append(p)
        p.start()
    # Drain the queue BEFORE joining: a child that has put data on a queue
    # does not terminate until that data is consumed, so joining first can
    # deadlock.
    countsList = [q.get() for _ in pList]
    # Join processes
    for gType, p in enumerate(pList):
        print('Grouper Method ' + str(gType))
        p.join()

    # Plot statistics
    font = {'size'   : 8}
    plt.rc('font', **font)
    plt.figure(figsize=(8,10))
    # Results arrive in completion order; each carries its own grouper index.
    for counts in countsList:
        grouperIndex = counts[0]
        plt.subplot(3,3,grouperIndex+1)
        plt.title(str(grouperIndex) + ' - ' + grouperLabels[grouperIndex],fontsize=8)
        plt.imshow(counts[1])
    plt.savefig('groupingStats.png')
项目:hco-experiments    作者:zooniverse    | 项目源码 | 文件源码
def measure_FoM(X, y, classifier, plot=True):
    """Compute the FoM of ``classifier`` on ``(X, y)`` and optionally plot it.

    The FoM is ``1 - tpr`` at the last ROC point whose false positive rate
    is <= 0.01; ``threshold`` is the ROC threshold at that same point.

    :param X: samples passed to ``classifier.predict_proba``.
    :param y: true labels passed to ``roc_curve``.
    :param classifier: object with ``predict_proba``; column 1 is the score.
    :param plot: when true, draw FPR against missed detection rate.
    :return: ``(FoM, threshold)``.
    """
    pred = classifier.predict_proba(X)[:,1]
    fpr, tpr, thresholds = roc_curve(y, pred)

    # Last ROC point still within the 1% false-positive budget.
    idx = np.where(fpr<=0.01)[0][-1]
    FoM = 1-tpr[idx]
    print("[+] FoM: %.4f" % (FoM))
    threshold = thresholds[idx]
    print("[+] threshold: %.4f" % (threshold))
    print()

    if plot:
        font = {"size": 18}
        plt.rc("font", **font)
        plt.rc("legend", fontsize=14)

        plt.xlabel("Missed Detection Rate (MDR)")
        plt.ylabel("False Positive Rate (FPR)")
        plt.yticks([0, 0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 1.0])
        plt.ylim((0,1.05))

        # Black underlay plus a thinner coloured line gives an outlined curve.
        plt.plot(1-tpr, fpr, "k-", lw=5)
        plt.plot(1-tpr, fpr, color="#FF0066", lw=4)

        # Dashed guide lines marking the operating point (FoM, 0.01).
        guide_x = np.arange(0,FoM+1e-3,1e-3)
        plt.plot(guide_x, 0.01*np.ones(guide_x.shape), "k--", lw=3)

        # linspace fixes the number of points; a float-step arange can yield
        # 11 or 12 values and then mismatch a hard-coded length of 11.
        guide_y = np.linspace(0, 0.01, 11)
        plt.plot(FoM*np.ones(guide_y.shape), guide_y, "k--", lw=3)

        plt.xticks([0, 0.05, 0.10, 0.25, FoM], rotation=70)

        locs, labels = plt.xticks()
        plt.xticks(locs, ["%.3f" % x for x in locs])
        plt.show()
    return FoM, threshold
项目:hco-experiments    作者:zooniverse    | 项目源码 | 文件源码
def measure_FoM(X, y, classifier, plot=True):
    """Compute the FoM of ``classifier`` on ``(X, y)`` and optionally plot it.

    The FoM is ``1 - tpr`` at the last ROC point whose false positive rate
    is <= 0.01; ``threshold`` is the ROC threshold at that same point.

    :param X: samples passed to ``classifier.predict_proba``.
    :param y: true labels passed to ``roc_curve``.
    :param classifier: object with ``predict_proba``; column 1 is the score.
    :param plot: when true, draw FPR against missed detection rate.
    :return: ``(FoM, threshold)``.
    """
    pred = classifier.predict_proba(X)[:,1]
    fpr, tpr, thresholds = roc_curve(y, pred)

    # Last ROC point still within the 1% false-positive budget.
    idx = np.where(fpr<=0.01)[0][-1]
    FoM = 1-tpr[idx]
    print("[+] FoM: %.4f" % (FoM))
    threshold = thresholds[idx]
    print("[+] threshold: %.4f" % (threshold))
    print()

    if plot:
        font = {"size": 18}
        plt.rc("font", **font)
        plt.rc("legend", fontsize=14)

        plt.xlabel("Missed Detection Rate (MDR)")
        plt.ylabel("False Positive Rate (FPR)")
        plt.yticks([0, 0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 1.0])
        plt.ylim((0,1.05))

        # Black underlay plus a thinner coloured line gives an outlined curve.
        plt.plot(1-tpr, fpr, "k-", lw=5)
        plt.plot(1-tpr, fpr, color="#FF0066", lw=4)

        # Dashed guide lines marking the operating point (FoM, 0.01).
        guide_x = np.arange(0,FoM+1e-3,1e-3)
        plt.plot(guide_x, 0.01*np.ones(guide_x.shape), "k--", lw=3)

        # linspace fixes the number of points; a float-step arange can yield
        # 11 or 12 values and then mismatch a hard-coded length of 11.
        guide_y = np.linspace(0, 0.01, 11)
        plt.plot(FoM*np.ones(guide_y.shape), guide_y, "k--", lw=3)

        plt.xticks([0, 0.05, 0.10, 0.25, FoM], rotation=70)

        locs, labels = plt.xticks()
        plt.xticks(locs, ["%.3f" % x for x in locs])
        plt.show()
    return FoM, threshold
项目:hco-experiments    作者:zooniverse    | 项目源码 | 文件源码
def measure_FoM(X, y, classifier, plot=True):
    """Compute the FoM of ``classifier`` on ``(X, y)`` and optionally plot it.

    The FoM is ``1 - tpr`` at the last ROC point whose false positive rate
    is <= 0.01; ``threshold`` is the ROC threshold at that same point.

    :param X: samples passed to ``classifier.predict_proba``.
    :param y: true labels passed to ``roc_curve``.
    :param classifier: object with ``predict_proba``; column 1 is the score.
    :param plot: when true, draw FPR against missed detection rate.
    :return: ``(FoM, threshold)``.
    """
    pred = classifier.predict_proba(X)[:,1]
    fpr, tpr, thresholds = roc_curve(y, pred)

    # Last ROC point still within the 1% false-positive budget.
    idx = np.where(fpr<=0.01)[0][-1]
    FoM = 1-tpr[idx]
    print("[+] FoM: %.4f" % (FoM))
    threshold = thresholds[idx]
    print("[+] threshold: %.4f" % (threshold))
    print()

    if plot:
        font = {"size": 18}
        plt.rc("font", **font)
        plt.rc("legend", fontsize=14)

        plt.xlabel("Missed Detection Rate (MDR)")
        plt.ylabel("False Positive Rate (FPR)")
        plt.yticks([0, 0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 1.0])
        plt.ylim((0,1.05))

        # Black underlay plus a thinner coloured line gives an outlined curve.
        plt.plot(1-tpr, fpr, "k-", lw=5)
        plt.plot(1-tpr, fpr, color="#FF0066", lw=4)

        # Dashed guide lines marking the operating point (FoM, 0.01).
        guide_x = np.arange(0,FoM+1e-3,1e-3)
        plt.plot(guide_x, 0.01*np.ones(guide_x.shape), "k--", lw=3)

        # linspace fixes the number of points; a float-step arange can yield
        # 11 or 12 values and then mismatch a hard-coded length of 11.
        guide_y = np.linspace(0, 0.01, 11)
        plt.plot(FoM*np.ones(guide_y.shape), guide_y, "k--", lw=3)

        plt.xticks([0, 0.05, 0.10, 0.25, FoM], rotation=70)

        locs, labels = plt.xticks()
        plt.xticks(locs, ["%.3f" % x for x in locs])
        plt.show()
    return FoM, threshold
项目:hco-experiments    作者:zooniverse    | 项目源码 | 文件源码
def measure_FoM(X, y, classifier, plot=True):
    """Compute the FoM of ``classifier`` on ``(X, y)`` and optionally plot it.

    The FoM is ``1 - tpr`` at the last ROC point whose false positive rate
    is <= 0.01; ``threshold`` is the ROC threshold at that same point.

    :param X: samples passed to ``classifier.predict_proba``.
    :param y: true labels passed to ``roc_curve``.
    :param classifier: object with ``predict_proba``; column 1 is the score.
    :param plot: when true, draw FPR against missed detection rate.
    :return: ``(FoM, threshold)``.
    """
    pred = classifier.predict_proba(X)[:,1]
    fpr, tpr, thresholds = roc_curve(y, pred)

    # Last ROC point still within the 1% false-positive budget.
    idx = np.where(fpr<=0.01)[0][-1]
    FoM = 1-tpr[idx]
    # Single-argument print(...) behaves the same on Python 2 and 3.
    print("[+] FoM: %.4f" % (FoM))
    threshold = thresholds[idx]
    print("[+] threshold: %.4f" % (threshold))
    print("")

    if plot:
        font = {"size": 18}
        plt.rc("font", **font)
        plt.rc("legend", fontsize=14)

        plt.xlabel("Missed Detection Rate (MDR)")
        plt.ylabel("False Positive Rate (FPR)")
        plt.yticks([0, 0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 1.0])
        plt.ylim((0,1.05))

        # Black underlay plus a thinner coloured line gives an outlined curve.
        plt.plot(1-tpr, fpr, "k-", lw=5)
        plt.plot(1-tpr, fpr, color="#FF0066", lw=4)

        # Dashed guide lines marking the operating point (FoM, 0.01).
        guide_x = np.arange(0,FoM+1e-3,1e-3)
        plt.plot(guide_x, 0.01*np.ones(guide_x.shape), "k--", lw=3)

        # linspace fixes the number of points; a float-step arange can yield
        # 11 or 12 values and then mismatch a hard-coded length of 11.
        guide_y = np.linspace(0, 0.01, 11)
        plt.plot(FoM*np.ones(guide_y.shape), guide_y, "k--", lw=3)

        plt.xticks([0, 0.05, 0.10, 0.25, FoM], rotation=70)

        locs, labels = plt.xticks()
        # A real list: on Python 3 map() returns a lazy iterator.
        plt.xticks(locs, ["%.3f" % x for x in locs])
        plt.show()
    return FoM, threshold
项目:hco-experiments    作者:zooniverse    | 项目源码 | 文件源码
def measure_FoM(X, y, classifier, plot=True):
    """Compute the FoM of ``classifier`` on ``(X, y)`` and optionally plot it.

    The FoM is ``1 - tpr`` at the last ROC point whose false positive rate
    is <= 0.01; ``threshold`` is the ROC threshold at that same point.

    :param X: samples passed to ``classifier.predict_proba``.
    :param y: true labels passed to ``roc_curve``.
    :param classifier: object with ``predict_proba``; column 1 is the score.
    :param plot: when true, draw FPR against missed detection rate.
    :return: ``(FoM, threshold)``.
    """
    pred = classifier.predict_proba(X)[:,1]
    fpr, tpr, thresholds = roc_curve(y, pred)

    # Last ROC point still within the 1% false-positive budget.
    idx = np.where(fpr<=0.01)[0][-1]
    FoM = 1-tpr[idx]
    # Single-argument print(...) behaves the same on Python 2 and 3.
    print("[+] FoM: %.4f" % (FoM))
    threshold = thresholds[idx]
    print("[+] threshold: %.4f" % (threshold))
    print("")

    if plot:
        font = {"size": 18}
        plt.rc("font", **font)
        plt.rc("legend", fontsize=14)

        plt.xlabel("Missed Detection Rate (MDR)")
        plt.ylabel("False Positive Rate (FPR)")
        plt.yticks([0, 0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 1.0])
        plt.ylim((0,1.05))

        # Black underlay plus a thinner coloured line gives an outlined curve.
        plt.plot(1-tpr, fpr, "k-", lw=5)
        plt.plot(1-tpr, fpr, color="#FF0066", lw=4)

        # Dashed guide lines marking the operating point (FoM, 0.01).
        guide_x = np.arange(0,FoM+1e-3,1e-3)
        plt.plot(guide_x, 0.01*np.ones(guide_x.shape), "k--", lw=3)

        # linspace fixes the number of points; a float-step arange can yield
        # 11 or 12 values and then mismatch a hard-coded length of 11.
        guide_y = np.linspace(0, 0.01, 11)
        plt.plot(FoM*np.ones(guide_y.shape), guide_y, "k--", lw=3)

        plt.xticks([0, 0.05, 0.10, 0.25, FoM], rotation=70)

        locs, labels = plt.xticks()
        # A real list: on Python 3 map() returns a lazy iterator.
        plt.xticks(locs, ["%.3f" % x for x in locs])
        plt.show()
    return FoM, threshold
项目:hco-experiments    作者:zooniverse    | 项目源码 | 文件源码
def measure_FoM(X, y, classifier, plot=True):
    """Compute the figure of merit (FoM) of a classifier at 1% false positives.

    The FoM is the missed detection rate (1 - TPR) taken at the last ROC
    point whose false positive rate is still <= 0.01. The FoM and the
    matching decision threshold are printed and returned.

    Parameters
    ----------
    X : samples handed to ``classifier.predict_proba``.
    y : true labels handed to ``roc_curve``.
    classifier : object exposing ``predict_proba``; column 1 of its output
        is used as the score.
    plot : when true, draw the MDR/FPR curve with dashed guide lines at
        FPR = 0.01 and MDR = FoM and show it.

    Returns
    -------
    (FoM, threshold)
    """
    pred = classifier.predict_proba(X)[:, 1]
    fpr, tpr, thresholds = roc_curve(y, pred)

    # Index of the last ROC point that still satisfies FPR <= 1%;
    # looked up once and shared by the FoM and the threshold.
    last_ok = np.where(fpr <= 0.01)[0][-1]

    FoM = 1 - tpr[last_ok]
    # Parenthesised single-argument prints behave the same on Python 2 and 3.
    print("[+] FoM: %.4f" % (FoM))
    threshold = thresholds[last_ok]
    print("[+] threshold: %.4f" % (threshold))
    print("")

    if plot:
        font = {"size": 18}
        plt.rc("font", **font)
        plt.rc("legend", fontsize=14)

        plt.xlabel("Missed Detection Rate (MDR)")
        plt.ylabel("False Positive Rate (FPR)")
        plt.yticks([0, 0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 1.0])
        plt.ylim((0, 1.05))

        # Black line underneath a slightly thinner coloured one gives an outline.
        plt.plot(1 - tpr, fpr, "k-", lw=5)
        plt.plot(1 - tpr, fpr, color="#FF0066", lw=4)

        # Horizontal guide at FPR = 0.01, from MDR = 0 up to the FoM.
        guide_x = np.arange(0, FoM + 1e-3, 1e-3)
        plt.plot(guide_x, 0.01 * np.ones(np.shape(guide_x)), "k--", lw=3)

        # Vertical guide at MDR = FoM. Its length is derived from the arange
        # result instead of a hard-coded 11, which a float step cannot guarantee.
        guide_y = np.arange(0, 0.01 + 1e-3, 1e-3)
        plt.plot(FoM * np.ones(np.shape(guide_y)), guide_y, "k--", lw=3)

        plt.xticks([0, 0.05, 0.10, 0.25, FoM], rotation=70)

        locs, labels = plt.xticks()
        # A real list: on Python 3 map() would hand matplotlib an iterator.
        plt.xticks(locs, ["%.3f" % x for x in locs])
        plt.show()
    return FoM, threshold
项目:SoftSAR    作者:eduardosufan    | 项目源码 | 文件源码
def plot_trajectory_2D(traj):
    """
    Show the position and velocity components of a trajectory in two panels.

    Parameters
    ----------
    traj: object (AirplaneTrajectory instance).
      Its flight_x/flight_y/flight_z attributes are drawn in the upper
      panel and flight_vx/flight_vy/flight_vz in the lower one.

    Returns
    -------
    -.
    """
    # Bold 18pt text for everything drawn from here on.
    plt.rc('font', **{'weight': 'bold', 'size': 18})

    # One (component name, line format) pair per curve, shared by both panels.
    component_formats = (('x', 'r--'), ('y', 'g--'), ('z', 'b--'))

    plt.figure()

    # Upper panel: position.
    plt.subplot(211)
    for component, line_format in component_formats:
        plt.plot(getattr(traj, 'flight_' + component), line_format,
                 label=component)
    plt.legend()
    plt.ylabel('Position')
    plt.title('Position of airplane')

    # Lower panel: velocity.
    plt.subplot(212)
    for component, line_format in component_formats:
        plt.plot(getattr(traj, 'flight_v' + component), line_format,
                 label='v' + component)
    plt.legend()
    plt.xlabel('Samples in space')
    plt.ylabel('Velocity')
    plt.title('Velocity of airplane')

    plt.show()
项目:Python-Data-Analytics-and-Visualization    作者:PacktPublishing    | 项目源码 | 文件源码
def plotSingleTickerWithVolume(ticker, startdate, enddate):
    """Draw adjusted close prices and dollar volume of one ticker.

    Prices go on the module-level axes ``ax``; the dollar volume is drawn
    on a twin y-axis created from it. Historical quotes are fetched with
    ``finance.fetch_historical_yahoo`` for the given date range.

    Parameters: ``ticker`` is also used as the legend label of the price
    line; ``startdate`` and ``enddate`` are passed straight to the fetch.
    """

    global ax

    fh = finance.fetch_historical_yahoo(ticker, startdate, enddate)

    # a numpy record array with fields:
    #     date, open, high, low, close, volume, adj_close
    try:
        r = mlab.csv2rec(fh)
    finally:
        # Release the handle even when parsing the CSV fails.
        fh.close()
    r.sort()

    plt.rc('axes', grid=True)
    plt.rc('grid', color='0.78', linestyle='-', linewidth=0.5)

    axt = ax.twinx()
    prices = r.adj_close

    fcolor = 'darkgoldenrod'

    ax.plot(r.date, prices, color=r'#1066ee', lw=2, label=ticker)
    ax.fill_between(r.date, prices, 0, prices, facecolor='#BBD7E5')
    # Single positional argument: only the first y-limit is given.
    ax.set_ylim(0.5*prices.max())

    ax.legend(loc='upper right', shadow=True, fancybox=True)

    volume = (r.close*r.volume)/1e6  # dollar volume in millions
    vmax = volume.max()

    axt.fill_between(r.date, volume, 0, label='Volume',
                 facecolor=fcolor, edgecolor=fcolor)

    # 5x headroom keeps the volume area in the lower part of the twin axis.
    axt.set_ylim(0, 5*vmax)
    axt.set_yticks([])

    for axis in ax, axt:
        for label in axis.get_xticklabels():
            label.set_rotation(30)
            label.set_horizontalalignment('right')

        axis.fmt_xdata = mdates.DateFormatter('%Y-%m-%d')
项目:Python-Data-Analytics-and-Visualization    作者:PacktPublishing    | 项目源码 | 文件源码
def plotSingleTickerWithVolume(ticker, startdate, enddate):
    """Draw adjusted close prices and dollar volume of one ticker.

    Prices go on the module-level axes ``ax``; the dollar volume is drawn
    on a twin y-axis created from it. Historical quotes are fetched with
    ``finance.fetch_historical_yahoo`` for the given date range.

    Parameters: ``ticker`` is also used as the legend label of the price
    line; ``startdate`` and ``enddate`` are passed straight to the fetch.
    """

    global ax

    fh = finance.fetch_historical_yahoo(ticker, startdate, enddate)

    # a numpy record array with fields:
    #     date, open, high, low, close, volume, adj_close
    try:
        r = mlab.csv2rec(fh)
    finally:
        # Release the handle even when parsing the CSV fails.
        fh.close()
    r.sort()

    plt.rc('axes', grid=True)
    plt.rc('grid', color='0.78', linestyle='-', linewidth=0.5)

    axt = ax.twinx()
    prices = r.adj_close

    fcolor = 'darkgoldenrod'

    ax.plot(r.date, prices, color=r'#1066ee', lw=2, label=ticker)
    ax.fill_between(r.date, prices, 0, prices, facecolor='#BBD7E5')
    # Single positional argument: only the first y-limit is given.
    ax.set_ylim(0.5*prices.max())

    ax.legend(loc='upper right', shadow=True, fancybox=True)

    volume = (r.close*r.volume)/1e6  # dollar volume in millions
    vmax = volume.max()

    axt.fill_between(r.date, volume, 0, label='Volume',
                 facecolor=fcolor, edgecolor=fcolor)

    # 5x headroom keeps the volume area in the lower part of the twin axis.
    axt.set_ylim(0, 5*vmax)
    axt.set_yticks([])

    for axis in ax, axt:
        for label in axis.get_xticklabels():
            label.set_rotation(30)
            label.set_horizontalalignment('right')

        axis.fmt_xdata = mdates.DateFormatter('%Y-%m-%d')
项目:easyesn    作者:kalekiu    | 项目源码 | 文件源码
def pred(predictionHorizon):
    """Fit an ESN on the module-level series ``y`` and plot its forecast.

    The network is trained on the first 2000 samples to predict the value
    ``predictionHorizon`` steps ahead, evaluated on the samples up to
    index 4000, and the test signal is plotted against the prediction.

    Returns the mean squared error on the test segment.
    """
    print("predicting x(t+{0})".format(predictionHorizon))

    # Hyper-parameters below were tuned for predictionHorizon = 48.
    train_series = y[:2000]
    test_series = y[2000-predictionHorizon:4000]

    # Settings taken from the grid search.
    esn = ESN(
        n_input=1,
        n_output=1,
        n_reservoir=1000,
        noise_level=0.0001,
        spectral_radius=1.35,
        leak_rate=0.7,
        random_seed=42,
        sparseness=0.2,
        solver="lsqr",
        regression_parameters=[1e-8],
    )
    train_acc = esn.fit(
        inputData=train_series[:-predictionHorizon],
        outputData=train_series[predictionHorizon:],
        transient_quota=0.2,
    )
    print("training acc: {0:4f}\r\n".format(train_acc))

    forecast = esn.predict(test_series[:-predictionHorizon])
    target = test_series[predictionHorizon:]

    # Error metrics on the test segment.
    residual = forecast - target
    mse = np.mean(residual[:]**2)
    rmse = np.sqrt(mse)
    nrmse = rmse/np.var(test_series)
    print("testing mse: {0}".format(mse))
    print("testing rmse: {0:4f}".format(rmse))
    print("testing nrmse: {0:4f}".format(nrmse))

    import matplotlib
    # LaTeX-rendered serif text for the figure.
    plt.rc('font', family='serif', serif=['Computer Modern'], size=13)
    plt.rc('text', usetex=True)
    plt.rc('text.latex', preamble="\\usepackage{mathtools}")

    plt.figure(figsize=(8,5))
    plt.plot(target, 'r', linestyle=":")
    plt.plot(forecast, 'b', linestyle="--")
    plt.ylim([0.3, 1.6])
    legend_entries = [
        'Signal $x(t)$',
        'Vorhersage $x\'(t) \\approx x(t+{0})$'.format(predictionHorizon),
    ]
    plt.legend(legend_entries,
               fancybox=True, shadow=True, ncol=2, loc="upper center")
    plt.xlabel("Zeit t")
    plt.ylabel("Signal")

    plt.show()

    return mse
项目:KATE    作者:hugochan    | 项目源码 | 文件源码
def visualize_pca_2d(doc_codes, doc_labels, classes_to_visual, save_file):
    """
        Scatter the documents on two axes of a 3-component PCA projection,
        using one marker shape per class, then save the figure as EPS and show it.
        @param doc_codes: dict mapping document -> code vector, or the code vectors themselves.
        @param doc_labels: dict mapping document -> label, or the labels themselves.
        @param classes_to_visual: classes to draw; duplicates are removed.
        @param save_file: target handed to plt.savefig.
    """
    marker_cycle = ["o", "v", "8", "s", "p", "*", "h", "H", "+", "x", "D"]
    plt.rc('legend', fontsize=28)
    classes_to_visual = list(set(classes_to_visual))

    both_are_dicts = isinstance(doc_codes, dict) and isinstance(doc_labels, dict)
    if both_are_dicts:
        # Keep only documents whose label is among the requested classes.
        kept_pairs = [(code, doc_labels[doc])
                      for doc, code in doc_codes.items()
                      if doc_labels[doc] in classes_to_visual]
        codes, labels = zip(*kept_pairs)
    else:
        codes, labels = doc_codes, doc_labels

    projected = PCA(n_components=3).fit_transform(np.r_[list(codes)])
    plt.figure(figsize=(10, 10), facecolor='white')

    # Columns of the projection used for the horizontal / vertical axis.
    x_pc, y_pc = 1, 2

    for position, cls in enumerate(classes_to_visual):
        members = np.array(labels) == cls
        # Wrap around the marker list when there are more classes than markers.
        shape = marker_cycle[position % len(marker_cycle)]
        plt.plot(projected[members, x_pc], projected[members, y_pc],
                 linestyle='None', alpha=1, marker=shape,
                 markersize=10, label=cls)

    plt.legend(loc='upper right', shadow=True)
    plt.savefig(save_file, format='eps', dpi=2000)
    plt.show()
项目:KATE    作者:hugochan    | 项目源码 | 文件源码
def visualize_pca_3d(doc_codes, doc_labels, classes_to_visual, save_file, maker_size=None, opaque=None):
    """
        Scatter the documents in 3D on the three components of a PCA projection,
        one marker/colour combination per class, then save and show the figure.
        @param doc_codes: dict mapping document -> code vector, or the code vectors themselves.
        @param doc_labels: dict mapping document -> label, or the labels themselves.
        @param classes_to_visual: indexable collection of classes to draw.
        @param save_file: target handed to plt.savefig.
        @param maker_size: optional per-class marker sizes (20 when omitted).
        @param opaque: optional per-class alpha values (1 when omitted).
    """
    marker_cycle = ["D", "p", "*", "s", "d", "8", "^", "H", "v", ">", "<", "h", "|"]
    color_cycle = ['r', 'b', 'g', 'c', 'm', 'y', 'k']
    plt.rc('legend', fontsize=20)
    C = len(classes_to_visual)

    both_are_dicts = isinstance(doc_codes, dict) and isinstance(doc_labels, dict)
    if both_are_dicts:
        # Keep only documents whose label is among the requested classes.
        kept_pairs = [(code, doc_labels[doc])
                      for doc, code in doc_codes.items()
                      if doc_labels[doc] in classes_to_visual]
        codes, labels = zip(*kept_pairs)
    else:
        codes, labels = doc_codes, doc_labels

    projected = PCA(n_components=3).fit_transform(np.r_[list(codes)])
    fig = plt.figure(figsize=(10, 10), facecolor='white')
    ax = fig.add_subplot(111, projection='3d')
    x_pc, y_pc, z_pc = 0, 1, 2

    # legend() cannot use what a 3D scatter returns, so a 2D line with the
    # same colour and marker is created per class and used as legend handle.
    legend_handles = []
    for i in range(C):
        cls = classes_to_visual[i]
        members = np.array(labels) == cls
        # Wrap around when there are more classes than markers / colours.
        shape = marker_cycle[i % len(marker_cycle)]
        tint = color_cycle[i % len(color_cycle)]
        ax.scatter(projected[members, x_pc], projected[members, y_pc], projected[members, z_pc],
                   c=tint,
                   alpha=opaque[i] if opaque else 1,
                   s=maker_size[i] if maker_size else 20,
                   marker=shape, label=cls)
        legend_handles.append(
            mpl.lines.Line2D([0], [0], linestyle="none", c=tint, marker=shape, label=cls))
    ax.legend(legend_handles, classes_to_visual, numpoints=1)

    ax.set_xlabel('%sst component' % (x_pc + 1), fontsize=14)
    ax.set_ylabel('%snd component' % (y_pc + 1), fontsize=14)
    ax.set_zlabel('%srd component' % (z_pc + 1), fontsize=14)
    plt.savefig(save_file)
    plt.show()
项目:NuGridPy    作者:NuGrid    | 项目源码 | 文件源码
def set_plot_hrd(self,fig,symbs_1=[],linestyle=[],markevery=500,end_model=[],single_plot=True,labelmassonly=False):

        '''
            Plots a Hertzsprung-Russell diagram (log_L versus log_Teff) for
            every run in self.run_historydata on figure number `fig`.

            fig           - figure number the combined plot is drawn on
            symbs_1       - marker symbols, one per run; self.symbs is used when empty
            linestyle     - line styles, one per run; 200*['-'] is used when empty
            markevery     - passed through to plot()
            end_model     - array, controls how far in models a run is plotted;
                            it is indexed for every run, so it needs one entry
                            per run. An entry of -1 keeps the default
                            t1_model=-1 (described by the author as "till end").
            single_plot   - when False, each run also appears to be drawn on
                            its own figure i+1
            labelmassonly - when True the legend label is shortened to
                            label.split('Z')[0][:-2]

            NOTE(review): the indentation of this block is inconsistent as it
            appears here (it looks like mixed tabs and spaces were expanded);
            restore the original whitespace before running it. The print
            statements are Python 2 syntax.
        '''
        m=self.run_historydata
            i=0
        if len(symbs_1)>0:
            symbs=symbs_1
        else:
            symbs=self.symbs
        if len(linestyle)==0:
            linestyle=200*['-']
            for case in m:
            t1_model=-1
            print end_model[i]
            if not end_model[i] == -1:
                t1_model=end_model[i]
                print self.run_label[i],t1_model
            t0_model=case.get("model_number")[0]
                # both series are cut after (t1_model - t0_model) entries
                logTeff=case.get('log_Teff')[:(t1_model-t0_model)]
                logL=case.get('log_L')[:(t1_model-t0_model)]

            label=self.run_label[i]
            if labelmassonly == True:
                label=label.split('Z')[0][:-2]
            print 'label',label
            if single_plot==False:
                            figure(i+1)
                            plot(logTeff,logL,marker=symbs[i],label=label,linestyle=linestyle[i],markevery=markevery)
                case.xlimrev()
                ax = plt.gca()
                plt.rcParams.update({'font.size': 16})
                plt.rc('xtick', labelsize=16)
                plt.rc('ytick', labelsize=16)
                legend(loc=4)
                xlabel('log Teff',fontsize=18)
                ylabel('log L',fontsize=18)
            #plt.gca().invert_xaxis()   
            # combined plot: every run is added to figure `fig`
            figure(fig)
            plot(logTeff,logL,marker=symbs[i],label=label,linestyle=linestyle[i],markevery=markevery)
            #case.xlimrev()
            ax = plt.gca()
            plt.rcParams.update({'font.size': 16})
                    plt.rc('xtick', labelsize=16)
                    plt.rc('ytick', labelsize=16)
                    legend(loc=4)
                    xlabel('log Teff',fontsize=18)
                    ylabel('log L',fontsize=18)
            #case.xlimrev()
            #plt.gca().invert_xaxis()
            i+=1    
        # NOTE(review): the x axis is inverted on figure 0, while the combined
        # plot above is drawn on figure(fig) - confirm fig is expected to be 0.
        figure(0)
        plt.gca().invert_xaxis()
项目:NuGridPy    作者:NuGrid    | 项目源码 | 文件源码
def set_plot_tcrhoc(self,symbs_1=[],linestyle=[],markevery=500,end_model=[]): 
        '''
            Plots log_center_T against log_center_Rho for every run in
            self.run_historydata, once on figure i+1 and once on the
            shared figure 0.

            symbs_1   - marker symbols, one per run; self.symbs is used when empty
            linestyle - line styles, one per run; 200*['-'] is used when empty
            markevery - passed through to plot()
            end_model - array, controls how far in models a run is plotted;
                        it is indexed for every run, so it needs one entry
                        per run. An entry of -1 keeps the default
                        t1_model=-1 (described by the author as "till end").

            NOTE(review): the indentation of this block is inconsistent as it
            appears here (it looks like mixed tabs and spaces were expanded);
            restore the original whitespace before running it.
        '''
        m=self.run_historydata
            i=0
        if len(symbs_1)>0:
            symbs=symbs_1
        else:
            symbs=self.symbs
        if len(linestyle)==0:
            linestyle=200*['-']
            for case in m:
            t1_model=-1
            if end_model[i] != -1:
                t1_model=end_model[i]
            t0_model=case.get("model_number")[0]
            # both series are cut after (t1_model - t0_model) entries
            rho=case.get('log_center_Rho')[:(t1_model-t0_model)]

            T=case.get('log_center_T')[:(t1_model-t0_model)]

            # NOTE(review): h1 is read here but not used in the rest of the block
            h1=case.get('H-1')

                        figure(i+1)
                        #plot(logTeff,logL,marker=symbs[i],label=self.run_label[i],linestyle=linestyle[i],markevery=markevery)
            pl.plot(rho,T,marker=symbs[i],label=self.run_label[i],linestyle=linestyle[i],markevery=markevery)
            # NOTE(review): bare attribute access, has no effect
            case.get            
            plt.gca().invert_xaxis()
            ax = plt.gca()
            plt.rcParams.update({'font.size': 16})
            plt.rc('xtick', labelsize=16)
            plt.rc('ytick', labelsize=16)
            legend(loc=4)
                plt.xlabel('log $\\rho_{\\rm c}$',fontsize=18)
                plt.ylabel('log $T_{\\rm c}$',fontsize=18)

            #plt.gca().invert_xaxis()   
            # shared figure collecting all runs
            figure(0)
            pl.plot(rho,T,marker=symbs[i],label=self.run_label[i],linestyle=linestyle[i],markevery=markevery)
            # NOTE(review): invert_xaxis() is called twice on figure 0 in this
            # section (here and after the labels) - confirm that is intended.
            plt.gca().invert_xaxis()
            ax = plt.gca()
            plt.rcParams.update({'font.size': 16})
                    plt.rc('xtick', labelsize=16)
                    plt.rc('ytick', labelsize=16)
                    legend(loc=4)
                plt.xlabel('log $\\rho_{\\rm c}$',fontsize=18)
                plt.ylabel('log $T_{\\rm c}$',fontsize=18)
            plt.gca().invert_xaxis()
            #plt.gca().invert_xaxis()
            i+=1
项目:NuGridPy    作者:NuGrid    | 项目源码 | 文件源码
def set_plot_mdot(self,fig,xaxis="model",masslabelonly=False,marker=[],markevery=500):

        '''
            Plots the mass loss (log_abs_mdot) of every run in
            self.run_historydata vs time, model or mass on figure `fig`.

            fig           - figure number to draw on
            xaxis         - "model", "time" or "mass":
                            "time"  plots against star_age, shifted so that the
                                    model returned by self.set_find_first_TP()
                                    for the run sits at age 0
                            "model" plots against model_number
                            "mass"  plots against star_mass
            masslabelonly - for xaxis="mass": when True the legend label is
                            shortened to run_label.split('Z')[0][:-2]
            marker        - markers, one per run; None is used when empty
            markevery     - passed through to plot()

            NOTE(review): the indentation of this block is inconsistent as it
            appears here (it looks like mixed tabs and spaces were expanded),
            so the nesting of `i += 1`, `legend(loc=2)` and the final
            `ylabel(...)` cannot be read off reliably; restore the original
            whitespace before running it. `print t0_time` is Python 2 syntax.
        '''

                plt.rcParams.update({'font.size': 18})
                plt.rc('xtick', labelsize=18)
                plt.rc('ytick', labelsize=18)


        if xaxis=="time":
                        t0_model=self.set_find_first_TP()

        elif xaxis =="model" or xaxis=="mass":
            # start every run at its first entry
            t0_model=len(self.run_historydata)*[0]
            m=self.run_historydata
        figure(fig)
        i=0
            for case in m:
            if len(marker)>0:
                marker1=marker[i]
            else:
                marker1=None    
            if xaxis=="time":

                            t0_time=case.get('star_age')[t0_model[i]]
                print t0_time
                star_mass=case.get('star_mass')[t0_model[i]]
                            star_age=case.get('star_age')[t0_model[i]:]  -t0_time
                mdot=case.get('log_abs_mdot')[t0_model[i]:]
                plt.plot(star_age,mdot,self.symbs[i],label=self.run_label[i],marker=marker1,markevery=markevery)
            elif xaxis=="model":
                mdot=case.get('log_abs_mdot')
                model=case.get('model_number')              
                    # NOTE(review): this branch labels the curve with self.symbs[i],
                    # the other branches use self.run_label[i] - confirm.
                    plt.plot(model,mdot,self.symbs[i],label=self.symbs[i],marker=marker1,markevery=markevery)
            elif xaxis=="mass":
                star_mass=case.get('star_mass')
                mdot=case.get('log_abs_mdot')
                if masslabelonly==True:
                    plt.plot(star_mass,mdot,self.symbs[i],label=self.run_label[i].split('Z')[0][:-2],marker=marker1,markevery=markevery)
                else:
                    plt.plot(star_mass,mdot,self.symbs[i],label=self.run_label[i],marker=marker1,markevery=markevery)
                #case.plot('star_mass','log_abs_mdot',legend=self.run_label[i],shape=self.symbs[i])
                i += 1
            legend(loc=2)
        if xaxis=="time":
            plt.xlabel('star age')
        elif xaxis=="model":
            plt.xlabel('model number')
        elif xaxis=="mass":
            xlabel('M/M$_{\odot}$',fontsize=18)
            ylabel('log($|\dot{M}|$)',fontsize=18)
            #ylim(-7,-3.5)
项目:NuGridPy    作者:NuGrid    | 项目源码 | 文件源码
def set_plot_kip_special(self,startfirstTP=True,xtime=True,label=[],color=[]):

        '''
            Kippenhahn-style plot which draws only the h1_boundary_mass and
            he4_boundary_mass curves of every run in self.run_historydata.

            startfirstTP - when True each run starts at the model returned by
                           self.set_find_first_TP(); otherwise at its first entry
            xtime        - when True the x axis is star_age relative to the
                           first plotted entry; otherwise model_number
            label, color - label and color can be chosen per run;
                           if label is given then color must be set too!
                           e.g. color=["r","b","g","k"]

            NOTE(review): the indentation of this block is inconsistent as it
            appears here (it looks like mixed tabs and spaces were expanded);
            restore the original whitespace before running it.
        '''
        plt.rcParams.update({'font.size': 24})
        plt.rc('xtick', labelsize=24)
        plt.rc('ytick', labelsize=24)

                m=self.run_historydata
                i=0
                if startfirstTP==True:
                        t0_model=self.set_find_first_TP()

                else:
                    t0_model=len(self.run_historydata)*[0]

        for case in m:
            h1_boundary_mass  = case.get('h1_boundary_mass')[t0_model[i]:]
            he4_boundary_mass = case.get('he4_boundary_mass')[t0_model[i]:]
            star_mass         = case.get('star_mass')[t0_model[i]:]
            # NOTE(review): mx1_bot is only referenced by the commented-out plot below
            mx1_bot           = case.get('mx1_bot')[t0_model[i]:]*star_mass
            model = case.get("model_number")[t0_model[i]:]
            age=case.get("star_age")[t0_model[i]:]
            # shift ages so the first plotted entry is at 0
            t0_age=age[0]
            age=age-t0_age
            if xtime==True:
                model=age
            if len(label)>0:
                plt.plot(model,h1_boundary_mass,color=color[i],label="$^1$H bndry, "+label[i])
                plt.plot(model,he4_boundary_mass,"--",color=color[i],label="$^4$He bndry, "+ label[i])
                #pltplot(model,star_mass,color=color[i],label="Total mass")
            else:
                plt.plot(model,h1_boundary_mass,label="h1_boundary_mass")
                plt.plot(model,he4_boundary_mass,"--",label="he4_boundary_mass")
            #plt.plot(model,mx1_bot,label="convective boundary")
            #title(self.run_label[i])
            i += 1
        if xtime==True:
            plt.xlabel('stellar age',size=28)
            if startfirstTP==True:
                plt.xlabel('t - t$_0$ $\mathrm{[yr]}$',size=28)
        else:   
            plt.xlabel('model number',size=28)
        plt.ylabel("M/M$_{\odot}$",size=28)     
        plt.legend()
项目:BirdCLEF2017    作者:kahst    | 项目源码 | 文件源码
def showConfusionMatrix(epoch):
    """Draw the module-level confusion matrix for *epoch* and save it as PNG.

    The image is written to ``confmatrix/<RUN_NAME>_<epoch>.png``; the
    directory is created when missing. When NORMALIZE_CONFMATRIX is set,
    the global matrix is replaced by its row-percentage version.
    """
    global cmatrix
    global cmcnt

    # Reuse figure 0 and wipe whatever was drawn before.
    plt.figure(0, figsize=(35, 35), dpi=72)
    plt.clf()

    # Metrics shown in the title.
    pr, re, f1 = calculateMetrics()

    # Optional conversion of counts to per-row percentages.
    if NORMALIZE_CONFMATRIX:
        row_totals = cmatrix.sum(axis=1)[:, np.newaxis]
        cmatrix = np.around(cmatrix.astype('float') / row_totals * 100.0, decimals=1)

    # Matrix image, limited to the first CONFMATRIX_MAX_CLASSES classes.
    visible_part = cmatrix[:CONFMATRIX_MAX_CLASSES, :CONFMATRIX_MAX_CLASSES]
    plt.imshow(visible_part, interpolation='nearest', cmap=plt.cm.Blues)

    heading = 'Confusion Matrix\n' + RUN_NAME + ' - Epoch ' + str(epoch)
    sample_info = '\nTrain Samples: ' + str(len(TRAIN)) + ' Validation Samples: ' + str(len(VAL))
    metric_info = '\nmP: ' + str(np.mean(pr)) + ' mF1: ' + str( np.mean(f1))
    plt.title(heading + sample_info + metric_info, fontsize=22)

    # Class names along both axes.
    tick_marks = np.arange(min(CONFMATRIX_MAX_CLASSES, NUM_CLASSES))
    plt.xticks(tick_marks, CLASSES[:CONFMATRIX_MAX_CLASSES], rotation=90)
    plt.yticks(tick_marks, CLASSES[:CONFMATRIX_MAX_CLASSES])

    # Cell values; white text on cells darker than half the maximum.
    thresh = cmatrix.max() / 2.
    n_rows = min(CONFMATRIX_MAX_CLASSES, cmatrix.shape[0])
    n_cols = min(CONFMATRIX_MAX_CLASSES, cmatrix.shape[1])
    for row in range(n_rows):
        for col in range(n_cols):
            if cmatrix[row, col] > thresh:
                text_color = "white"
            else:
                text_color = "black"
            plt.text(col, row, cmatrix[row, col],
                     horizontalalignment="center", verticalalignment="center",
                     color=text_color, fontsize=8)

    # Axis captions.
    plt.tight_layout()
    plt.ylabel('Target label', fontsize=16)
    plt.xlabel('Predicted label', fontsize=16)

    # Default font size.
    plt.rc('font', size=12)

    # Write the image, creating the output directory on first use.
    if not os.path.exists('confmatrix'):
        os.makedirs('confmatrix')
    target_file = 'confmatrix/' + RUN_NAME + '_' + str(epoch) + '.png'
    plt.savefig(target_file)
项目:GPS    作者:golsun    | 项目源码 | 文件源码
def plot_GPedge_mf(soln, GP_dir, opt, raw, path_save, rename, i_plot=None, title=None):
    """Plot the mole fraction of every member species of a GP and save it.

    Each species in ``GP_dir['member']`` is drawn as one labelled marker at
    x = its mole fraction at sample ``i_plot`` and y = minus its position in
    the member list.

    Parameters
    ----------
    soln : object whose ``species_names`` list gives the species column index.
    GP_dir : mapping with a ``'member'`` sequence of species names.
    opt : mapping read for ``'xscale'`` and, when ``i_plot`` is None, for
        ``'sample_loc'`` and ``'sample_by'``.
    raw : mapping read for ``'mole_fraction'``, ``'temperature'`` and ``'axis0'``.
    path_save : target handed to ``plt.savefig``.
    rename : passed to ``rename_species`` to build the marker labels.
    i_plot : sample index; looked up with ``find_i_plot`` when None.
    title : figure title; built from temperature and axis0 when None.

    Returns
    -------
    True
    """
    plt.rc('font', **{'family':'Times New Roman'})

    if i_plot is None:
        sample_loc = float(opt['sample_loc'][0])
        sample_by = opt['sample_by']
        i_plot = find_i_plot(sample_loc, raw, sample_by)

    # Parenthesised single-argument prints behave the same on Python 2 and 3.
    if opt['xscale'] == 'log':
        plot = plt.semilogx
        print('using log as xscale = '+str(opt['xscale']))
    else:
        plot = plt.plot
        print('using linear as xscale = '+str(opt['xscale']))

    for i, sp in enumerate(GP_dir['member']):
        id_sp = soln.species_names.index(sp)
        mf = raw['mole_fraction'][i_plot, id_sp]

        # One marker per species, stacked downwards in list order.
        plot(mf, -i, color='k', marker='o')
        plt.text(mf, -i, rename_species(sp, rename), horizontalalignment='right')

    T = raw['temperature']
    x = raw['axis0']
    tau_ign = find_tau_ign_raw(raw)

    if title is None:
        title = 'T = '+str(T[i_plot])+', axis0 = '+str(x[i_plot])
        if tau_ign is not None:
            title +=', norm_x = '+str(1.0*x[i_plot]/tau_ign)
        title+='\n'
    plt.title(title)

    plt.savefig(path_save)
    return True
项目:GASP-python    作者:henniggroup    | 项目源码 | 文件源码
def get_progress_plot(self):
        """
        Returns a plot of the best value versus the number of energy
        calculations, as a matplotlib plot object.

        The data is read from self.lines: the words after 'endpoints:' on
        the first line give the number of endpoints (which selects the
        y-axis label), and from the fifth line on, field 4 is the number of
        energy calculations and field 5 the best value. Rows whose best
        value is the string 'None' are left out of the plot.
        """

        # set the font to Times, rendered with Latex
        plt.rc('font', **{'family': 'serif', 'serif': ['Times']})
        plt.rc('text', usetex=True)

        # parse the number of composition space endpoints: every word
        # after 'endpoints:' on the first line counts as one endpoint
        endpoints = []
        for word in reversed(self.lines[0].split()):
            if word == 'endpoints:':
                break
            endpoints.append(word)
        num_endpoints = len(endpoints)

        if num_endpoints == 1:
            y_label = r'Best value (eV/atom)'
        elif num_endpoints == 2:
            y_label = r'Area of convex hull'
        else:
            y_label = r'Volume of convex hull'

        # parse the best values and numbers of energy calculations.
        # Rows without a best value are skipped while reading: the previous
        # two-pass removal looked indices up with list.index(), which always
        # returns the first 'None', and then deleted by indices that had
        # already shifted.
        best_values = []
        num_calcs = []
        for data_line in self.lines[4:]:
            fields = data_line.split()
            calc_count = int(fields[4])
            if fields[5] == 'None':
                continue
            num_calcs.append(calc_count)
            # numeric values so matplotlib draws a numeric y axis
            best_values.append(float(fields[5]))

        # make the plot
        plt.plot(num_calcs, best_values, color='blue', linewidth=2)
        plt.xlabel(r'Number of energy calculations', fontsize=22)
        plt.ylabel(y_label, fontsize=22)
        plt.tick_params(which='both', width=1, labelsize=18)
        plt.tick_params(which='major', length=8)
        plt.tick_params(which='minor', length=4)
        plt.xlim(xmin=0)
        plt.tight_layout()
        return plt
项目:GASP-python    作者:henniggroup    | 项目源码 | 文件源码
def get_phase_diagram_plot(self):
        """
        Builds a compound phase diagram from the structures listed in
        self.lines and returns its plot, as a matplotlib plot object.

        Quits the interpreter when fewer than 2 endpoints are listed.
        """

        # Times font, rendered with Latex
        plt.rc('font', family='serif', serif=['Times'])
        plt.rc('text', usetex=True)

        # endpoints: every word after 'endpoints:' on the first line,
        # collected back to front
        endpoints = []
        for word in reversed(self.lines[0].split()):
            if word == 'endpoints:':
                break
            endpoints.append(Composition(word))

        if len(endpoints) < 2:
            print('There must be at least 2 endpoint compositions to make a '
                  'phase diagram.')
            quit()

        # composition (field 1) and total energy (field 2) of each structure,
        # starting at the fifth line
        compositions = []
        total_energies = []
        for row in range(4, len(self.lines)):
            fields = self.lines[row].split()
            compositions.append(Composition(fields[1]))
            total_energies.append(float(fields[2]))

        # one PDEntry per structure
        pdentries = [PDEntry(composition, energy)
                     for composition, energy in zip(compositions, total_energies)]

        # phase diagram over the endpoint compositions, then its plotter
        compound_pd = CompoundPhaseDiagram(pdentries, endpoints)
        plotter = PDPlotter(compound_pd, show_unstable=100)
        return plotter.get_plot(label_unstable=False)
项目:appBBB    作者:rl-institut    | 项目源码 | 文件源码
def stack_plot(energysystem, reg, bus, date_from, date_to):
    """
    Creates a stack plot of the specified bus.

    Parameters
    ----------
    energysystem : passed to ``tpd.DataFramePlot`` as ``energy_system``.
    reg : region; used in the bus uid and to look up the color dictionary.
    bus : 'elec' or 'dh'; selects the color dictionary and the bus uid.
    date_from, date_to : limits passed to the slicing and plotting calls.

    Returns
    -------
    The matplotlib figure the plot was drawn on.

    Raises
    ------
    ValueError
        If ``bus`` is neither 'elec' nor 'dh'.
    """
    # initialize plot
    myplot = tpd.DataFramePlot(energy_system=energysystem)

    # get dictionary with color of each entity in plot
    if bus == 'elec':
        cdict = color_dict(reg)
    elif bus == 'dh':
        cdict = color_dict_dh(reg)
    else:
        # Without this branch cdict stayed unbound and the function failed
        # further down with an UnboundLocalError.
        raise ValueError(
            "unsupported bus %r: expected 'elec' or 'dh'" % (bus,))

    # uid of the bus, shared by the slicing and the plotting call
    bus_uid = "('bus', '" + reg + "', '" + bus + "')"

    # slice dataframe to prepare for plot function
    myplot.slice_unstacked(
        bus_uid=bus_uid,
        type="input",
        date_from=date_from,
        date_to=date_to)
    myplot.color_from_dict(cdict)

    # set plot parameters
    fig = plt.figure(figsize=(40, 14))
    plt.rc('legend', **{'fontsize': 18})
    plt.rcParams.update({'font.size': 18})
    plt.style.use('grayscale')

    # plot bus
    handles, labels = myplot.io_plot(
        bus_uid=bus_uid,
        cdict=cdict,
        line_kwa={'linewidth': 4},
        ax=fig.add_subplot(1, 1, 1),
        date_from=date_from,
        date_to=date_to,
        )
    myplot.ax.set_ylabel('Power in MW')
    myplot.ax.set_xlabel('Date')
    myplot.ax.set_title(bus+" bus")
    myplot.set_datetime_ticks(tick_distance=24, date_format='%d-%m-%Y')
    myplot.outside_legend(handles=handles, labels=labels)

    plt.show()
    return (fig)
项目:powerplantmatching    作者:FRESNA    | 项目源码 | 文件源码
def comparison_1dim(by='Country', include_WEPP=True, include_VRE=False,
                    year=2015, how='hbar', figsize=(7,5)):
    """
    Compares installed capacities of the matched datasets and statistics,
    grouped by ``by``, either as bar chart or as pair plot.

    Parameters
    ----------
    by : string, defines how to group data
        Allowed values: 'Country' or 'Fueltype'
    include_WEPP, include_VRE, year :
        passed on to ``gather_comparison_data``; ``include_WEPP`` also
        decides whether the WEPP based datasets are part of the comparison
    how : string
        'hbar' draws a horizontal bar chart with capacity on x-axis and
        ``by`` on y-axis, 'scatter' draws a log-log pair plot
    figsize : tuple, (width, height)

    Returns
    -------
    (fig, ax) for 'hbar', (fig, axes) for 'scatter'; nothing is returned
    for any other value of ``how``.
    """
    red_w_wepp, red_wo_wepp, wepp, statistics = gather_comparison_data(
        include_WEPP=include_WEPP, include_VRE=include_VRE, year=year)

    # datasets entering the comparison, together with their display names
    if include_WEPP:
        datasets = [red_w_wepp, red_wo_wepp, wepp, statistics]
        names = ['Matched dataset w/ WEPP', 'Matched dataset w/o WEPP',
                 'WEPP only', 'Statistics OPSD']
    else:
        datasets = [red_wo_wepp, statistics]
        names = ['Matched dataset w/o WEPP', 'Statistics OPSD']
    stats = lookup(datasets, keys=names, by=by)/1000

    if how == 'hbar':
        with sns.axes_style('darkgrid'):
            plt.rc('font', size=24)
            fig, ax = plt.subplots(figsize=figsize)
            stats.plot.barh(ax=ax, stacked=False, colormap='jet')
            ax.set_xlabel('Installed Capacity [GW]')
            ax.yaxis.label.set_visible(False)
            # grid goes behind the bars, drawn as white dotted lines
            ax.set_axisbelow(True)
            ax.grid(color='white', linestyle='dotted')
            ax.legend(loc='best')
            ax.invert_yaxis()
        return fig, ax

    if how == 'scatter':
        # seaborn needs the grouping as a regular string column
        stats.loc[:, by] = stats.index.astype(str)
        pairplot_options = dict(diag_kind='kde', hue=by, palette='Set2',
                                size=figsize[1], aspect=figsize[0]/figsize[1])
        if len(stats.columns)-1 < 3:
            # fewer than three datasets: plot the first column against the second only
            pairplot_options['x_vars'] = stats.columns[0]
            pairplot_options['y_vars'] = stats.columns[1]
        g = sns.pairplot(stats, **pairplot_options)

        # same log-log range on every panel
        for row in range(0, len(g.axes)):
            for col in range(0, len(g.axes[0])):
                g.axes[row, col].set(xscale='log', yscale='log',
                                     xlim=(1,200), ylim=(1,200))
        return g.fig, g.axes
项目:hco-experiments    作者:zooniverse    | 项目源码 | 文件源码
def hypothesisDist(y, pred, threshold=0.5):
    """Plot the histograms of the predictions for real and bogus samples.

    ``pred`` is split by the true labels in ``y`` (1 = real, 0 = bogus),
    both groups are histogrammed on bins of width 0.04, and a dashed
    vertical line marks ``threshold`` as the decision boundary.
    """
    # raw predictions, separated by true class
    bogus_scores = pred[np.where(y == 0)[0]]
    real_scores = pred[np.where(y == 1)[0]]

    plt.rc("font", size=26)
    plt.rc("legend", fontsize=22)
    plt.rc('font', family='serif')

    bins = list(np.arange(0,1.04,0.04))

    real_counts, bins, patches = plt.hist(
        real_scores, bins=bins, alpha=1,
        label="real", color="#FF0066", edgecolor="none")
    print(real_counts)

    garbage_counts, bins, patches = plt.hist(
        bogus_scores, bins=bins, alpha=1,
        label="bogus", color="#66FF33", edgecolor="none")
    print(garbage_counts)

    # Bins where the real bar is not taller than the bogus bar were just
    # covered by it, so the real bar is drawn again on top for clarity.
    try:
        covered_bins = list(np.where(np.array(real_counts) <= np.array(garbage_counts))[0])
        for bin_index in covered_bins:
            bin_edges = [bins[bin_index], bins[bin_index+1]]
            plt.hist(real_scores, bins=bin_edges, alpha=1, color="#FF0066", edgecolor="none")
    except IndexError:
        pass

    # height of the tallest bar of either histogram
    tallest = int(np.max(np.array([np.max(real_counts), np.max(garbage_counts)])))
    print(tallest)
    boundary_heights = np.array(list(range(0,tallest,100)))

    # write the count of the first bogus bin next to it
    if garbage_counts[0] != 0:
        plt.text(0.01, 0.1*garbage_counts[0], str(int(garbage_counts[0])), rotation="vertical", size=22)

    plt.plot(threshold*np.ones(np.shape(boundary_heights)), boundary_heights,
             "k--", label="decision boundary=%.3f"%(threshold), linewidth=2.0)

    # small margin below zero, upper limit as chosen by matplotlib
    y_min = -0.02*int(plt.axis()[-1])
    y_max = plt.axis()[-1]
    plt.xlim(-0.015,1.015)
    plt.ylim(y_min,y_max)
    plt.xlabel("Hypothesis")
    plt.ylabel("Frequency")
    plt.legend(loc="upper center").get_frame().set_alpha(0.5)
    plt.show()
项目:fui-kk    作者:fui    | 项目源码 | 文件源码
def main():
    """Parse the command line, collect course scores and plot every course.

    Steps, in order:

    1. Read the ``-o``, ``-t``, ``-d``, ``-r`` and ``-m`` options.
    2. Build the course dictionary with ``extract_courses(find_scores(...))``
       and write the upper-cased ``str()`` of it to ``course-data.js`` in
       the current working directory.
    3. If ``-d`` was given, keep only the courses reported by
       ``find_courses`` for that path, plus the ``'average_score'`` entry.
    4. Call ``plot_course`` for every remaining course, either sequentially
       or (with ``-m``) through a process pool.
    5. If ``-r`` was given, call ``rebuild_tex`` with that report and the
       output destination.
    """
    import sys
    import argparse

    # Arguments:
    parser = argparse.ArgumentParser()
    parser.add_argument('-o', action='store',
                        dest='output', default='.',
                        help='Sets OUTPUT destination (\'.\' is default)')
    parser.add_argument('-t', action='store',
                        dest='score_tree_path', default='.',
                        help='Traverse SCORE_TREE_PATH for score_overview.txt files')
    parser.add_argument('-d', action='store',
                        dest='filter_path', default=None,
                        help='Searches FILTER_PATH for courses to plot')
    parser.add_argument('-r', action='store',
                        dest='report', default=None,
                        help='rebuilds a REPORT with plots')
    parser.add_argument('-m', action='store_true',
                        dest='multiprocessing', default=None,
                        help='enable multiprocessing')

    # Extract arguments from command line arguments.
    result = parser.parse_args(sys.argv[1:])
    output = result.output
    score_tree = result.score_tree_path
    filter_path = result.filter_path
    rebuild_report = result.report

    # Traverse a directory tree and find all score files. These are given to
    # extract_courses, which returns a dictionary of courses.
    courses = extract_courses(find_scores(score_tree))
    with open('course-data.js', 'w') as f:
        f.write(str(courses).upper())

    if filter_path:
        # Look the wanted courses up once instead of once per course.
        # NOTE(review): assumes find_courses returns a reusable container
        # (list/set), not a one-shot iterator -- confirm against its definition.
        wanted_courses = find_courses(filter_path)
        courses = {course: val for course, val in courses.items()
                   if course in wanted_courses
                   or course == 'average_score'}

    # Make it more LaTeX like (serif)
    plt.figure(figsize=(10, 5))
    plt.rc('font', family='serif')

    if result.multiprocessing:
        # Oversubscribe the CPUs (four workers per core), as the original did.
        pool = multiprocessing.Pool(multiprocessing.cpu_count() * 4)
        try:
            # Map plot_course over the course keys using worker processes.
            pool.map(partial(plot_course, courses=courses, output=output),
                     courses)
        finally:
            # Always release the worker processes, even if plotting failed.
            pool.close()
            pool.join()
    else:
        for course in courses:
            plot_course(course, courses, output)

    if rebuild_report:
        rebuild_tex(rebuild_report, output)
项目:AcousticEventDetection    作者:kahst    | 项目源码 | 文件源码
def showConfusionMatrix(epoch):

    #new figure
    plt.figure(0, figsize=(35, 35), dpi=72)
    plt.clf()

    #get additional metrics
    pr, re, f1 = calculateMetrics()

    #normalize?
    if NORMALIZE_CONFMATRIX:
        global cmatrix
        cmatrix = np.around(cmatrix.astype('float') / cmatrix.sum(axis=1)[:, np.newaxis] * 100.0, decimals=1)

    #show matrix
    plt.imshow(cmatrix[:CONFMATRIX_MAX_CLASSES, :CONFMATRIX_MAX_CLASSES], interpolation='nearest', cmap=plt.cm.Blues)
    plt.title('Confusion Matrix\n' +
              RUN_NAME + ' - Epoch ' + str(epoch) +
              '\nTrain Samples: ' + str(len(TRAIN)) + ' Validation Samples: ' + str(len(VAL)) +              
              '\nmP: ' + str(np.mean(pr)) + ' mF1: ' + str( np.mean(f1)), fontsize=32)

    #tick marks
    tick_marks = np.arange(min(CONFMATRIX_MAX_CLASSES, NUM_CLASSES))
    plt.xticks(tick_marks, CLASSES[:CONFMATRIX_MAX_CLASSES], rotation=90)
    plt.yticks(tick_marks, CLASSES[:CONFMATRIX_MAX_CLASSES])

    #labels
    thresh = cmatrix.max() / 2.
    for i, j in itertools.product(range(min(CONFMATRIX_MAX_CLASSES, cmatrix.shape[0])), range(min(CONFMATRIX_MAX_CLASSES, cmatrix.shape[1]))):
        plt.text(j, i, cmatrix[i, j], 
                 horizontalalignment="center", verticalalignment="center",
                 color="white" if cmatrix[i, j] > thresh else "black", fontsize=32)

    #axes labels
    plt.tight_layout()
    plt.ylabel('Target label', fontsize=32)
    plt.xlabel('Predicted label', fontsize=32)

    #fontsize
    plt.rc('font', size=32)

    #save plot
    global cmcnt
    if not os.path.exists('confmatrix'):
        os.makedirs('confmatrix')
    plt.savefig('confmatrix/' + RUN_NAME + '_' + str(epoch) + '.png')