Python seaborn 模块，pairplot() 实例源码

我们从 Python 开源项目中提取了以下 10 个代码示例，用于说明如何使用 seaborn.pairplot()。

项目:whereareyou    作者:futurice    | 项目源码 | 文件源码
def get_training_image():
    """Render a seaborn pair plot of the flattened training data as PNG.

    Every column of the training data except ``mac`` and ``location`` is
    treated as a feature, renamed with ``POWER_SLAVE_PREFIX``, and plotted
    against every other feature, coloured by ``location``.

    The PNG bytes are printed to stdout (as before) and are also returned so
    callers can use them directly.

    Returns:
        bytes: the encoded PNG image (a ``str`` on Python 2).
    """
    import io

    import matplotlib.pyplot as plt
    import seaborn as sns

    df = pd.DataFrame.from_dict(get_flattened_training_data())
    features = [f for f in df.columns if f not in ('mac', 'location')]
    prefixed = [POWER_SLAVE_PREFIX + f for f in features]
    df = df.rename(columns=dict(zip(features, prefixed)))

    sns_plot = sns.pairplot(df, hue="location", vars=prefixed)
    # savefig writes PNG bytes; io.BytesIO accepts them on Python 2 and 3
    # (the old ``StringIO`` module only exists on Python 2). Only this one
    # image is written: the buffer must not receive a second, empty figure.
    png_output = io.BytesIO()
    try:
        sns_plot.savefig(png_output, format='png')
    finally:
        # pairplot opens a new pyplot figure on every call; close it so
        # repeated calls do not accumulate open figures.
        plt.close(sns_plot.fig)

    png_bytes = png_output.getvalue()
    print(png_bytes)
    return png_bytes
项目:python-machine-learning-book    作者:jeremyn    | 项目源码 | 文件源码
def visualize_housing_data(df):
    """Show exploratory plots for a subset of the housing columns.

    Two figures are displayed one after the other with ``plt.show()``:

    1. a seaborn pair plot of LSTAT, INDUS, NOX, RM and MEDV;
    2. an annotated heatmap of their correlation matrix (``np.corrcoef``).

    Both ``sns.set`` calls change global plotting settings.
    """
    columns = ['LSTAT', 'INDUS', 'NOX', 'RM', 'MEDV']

    sns.set(style='whitegrid', context='notebook')
    sns.pairplot(df[columns], size=2.5)
    plt.show()

    # np.corrcoef treats each row as a variable, hence the transpose.
    corr = np.corrcoef(df[columns].values.T)
    sns.set(font_scale=1.5)
    heatmap_options = dict(
        cbar=True,
        annot=True,
        square=True,
        fmt='.2f',
        annot_kws={'size': 15},
        yticklabels=columns,
        xticklabels=columns,
    )
    sns.heatmap(corr, **heatmap_options)
    plt.show()
项目:bio_corex    作者:gregversteeg    | 项目源码 | 文件源码
def plot_pairplots(data, labels, alpha, mis, column_label, topk=5, prefix='', focus=''):
    """Save pair plots of the top-ranked columns for every latent factor.

    For each factor ``j`` (row of ``mis``), the columns with ``alpha[j] > 0``
    and ``mis[j] > 0`` are ranked by ``alpha[j] * mis[j]`` (largest first)
    and the best ``topk`` are kept. If ``focus`` is an entry of
    ``column_label``, that column is always included (prepended). When at
    least two columns remain, two PDFs are written:

    * ``<prefix>/pairplots_regress/group_num=<j>.pdf`` -- regression pair
      plot with KDE diagonals;
    * ``<prefix>/pairplots/group_num=<j>.pdf`` -- scatter pair plot coloured
      by ``labels[:, j]``.

    Plotting is best effort: a failure for one plot is reported with a
    warning and the loop continues.

    Args:
        data: 2-D array; columns are selected with ``data[:, indices]``.
        labels: 2-D array; column ``j`` is the hue for factor ``j``.
        alpha: 2-D array indexed like ``mis`` (``alpha[j, indices]``).
        mis: 2-D array whose rows are the factors (its shape is unpacked).
        column_label: list of column names (``.index`` is called on it).
        topk: maximum number of ranked columns per factor.
        prefix: output directory; ``''`` writes relative to the current
            directory (previously this produced paths under ``/``).
        focus: optional column name that is always plotted.
    """
    import functools

    plt.rcParams.update({'font.size': 32})
    m, _ = mis.shape
    for j in range(m):
        inds = np.where(np.logical_and(alpha[j] > 0, mis[j] > 0.))[0]
        inds = inds[np.argsort(-alpha[j, inds] * mis[j, inds])][:topk]
        if focus in column_label:
            ifocus = column_label.index(focus)
            if ifocus not in inds:
                inds = np.insert(inds, 0, ifocus)
        if len(inds) < 2:
            continue  # a pair plot needs at least two columns

        columns = [column_label[i] for i in inds]
        subdata = pd.DataFrame(data=data[:, inds], columns=columns)

        _save_pairplot(
            functools.partial(sns.pairplot, subdata, kind="reg",
                              diag_kind="kde", size=5, dropna=True),
            os.path.join(prefix, 'pairplots_regress', 'group_num={}.pdf'.format(j)),
            j)

        subdata['Latent factor'] = labels[:, j]
        _save_pairplot(
            functools.partial(sns.pairplot, subdata, kind="scatter", dropna=True,
                              vars=subdata.columns.drop('Latent factor'),
                              hue="Latent factor", diag_kind="kde", size=5),
            os.path.join(prefix, 'pairplots', 'group_num={}.pdf'.format(j)),
            j)


def _save_pairplot(make_grid, filename, j):
    """Draw a pair plot with ``make_grid()``, title it and save it to ``filename``.

    Best effort: any ``Exception`` is reported as a warning instead of being
    raised (KeyboardInterrupt still propagates), and all pyplot figures are
    closed afterwards so no figure leaks on failure.
    """
    import warnings

    try:
        make_grid()
        dirname = os.path.dirname(filename)
        if dirname and not os.path.exists(dirname):
            os.makedirs(dirname)
        plt.suptitle("Latent factor {}".format(j), y=1.01)
        plt.savefig(filename, bbox_inches='tight')
    except Exception as exc:  # one failed plot must not stop the whole loop
        warnings.warn('Could not save pair plot {!r}: {}'.format(filename, exc))
    finally:
        plt.close('all')
项目:Data_Analysis    作者:crown-prince    | 项目源码 | 文件源码
def stock():
    """Download five stock histories and plot their daily returns.

    Fetches history from 2015-01-01 to 2016-04-16 through
    ``tsh.get_hist_data`` for each ticker code. It builds a table of closing
    prices (inner-joined on date, oldest first) and prints the daily
    percentage changes with their mean and standard deviation. It then shows
    joint plots of selected pairs, a pair plot of four tickers, and a scatter
    plot of standard deviation ("Risk") against mean ("Expected Return").

    Each downloaded frame is also stored as a module global under its short
    name (e.g. ``zsyh``), exactly as before, so existing readers keep working.
    """
    stock_codes = {"zsyh": "600036", "jsyh": "601939", "szzs": "000001",
                   "pfyh": "600000", "msyh": "600061"}
    history = {}
    for name, code in stock_codes.items():
        history[name] = tsh.get_hist_data(code, start="2015-01-01", end="2016-04-16")
        # Kept for backward compatibility with code that reads these globals.
        globals()[name] = history[name]
    names = list(stock_codes)

    # Closing prices, one column per ticker, restricted to common dates.
    df_close = pd.concat([history[n]["close"] for n in names], axis=1, join='inner')
    df_close.columns = names
    df_close.sort_index(ascending=True, inplace=True)  # oldest date first
    pc_ret = df_close.pct_change()  # daily percentage change
    print(pc_ret)
    make_end_line()
    print(pc_ret.mean())
    make_end_line()

    # seaborn >= 0.12 only accepts x/y/data as keywords, and plt.show() takes
    # no figure argument (passing one is a TypeError in current matplotlib),
    # so draw each grid first and then show it.
    for x, y, kind in (("zsyh", "jsyh", "hex"),
                       ("zsyh", "jsyh", "scatter"),
                       ("zsyh", "szzs", "scatter")):
        sns.jointplot(x=x, y=y, data=pc_ret, kind=kind)
        plt.show()
    sns.pairplot(pc_ret[["jsyh", "zsyh", "pfyh", "msyh"]].dropna())
    plt.show()

    print(pc_ret.std())
    make_end_line()
    rets = pc_ret.dropna()
    print(rets.mean())
    make_end_line()

    area = np.pi * 20  # scatter marker area (points^2)
    plt.scatter(rets.mean(), rets.std(), s=area)
    plt.xlabel("Expected Return")
    plt.ylabel("Risk")
    for label, mean, std in zip(rets.columns, rets.mean(), rets.std()):
        plt.annotate(
            label,
            xy=(mean, std), xytext=(50, 50),
            textcoords="offset points", ha="right", va="bottom",
            arrowprops=dict(arrowstyle="-", connectionstyle="arc3,rad=-0.3"))
    plt.show()
项目:forward    作者:yajun0601    | 项目源码 | 文件源码
def loadDataSet(filename):
    """Load the regression data set from sheet index 1 of an Excel workbook.

    The sheet is read without a header row, skipping its first row, and
    missing values are filled with 0. The columns are then transformed:

    * Column 2 is the target. It is divided by 1e8 and then passed through
      ``standard``.
    * Columns 3, 4, 6, 14 and 17 are passed through ``standard``.
    * Columns 15 and 16 are mapped with ``map_rate`` / ``map_sub_rate``
      before ``standard``.
    * Columns 9 and 11 are mapped with ``map_01``.
    * The encodings returned by ``transcoding`` for columns 10 (province)
      and 13 (enterprise) are appended to the features.

    Returns:
        tuple: ``(data, target)`` as NumPy matrices; ``target`` is a column
        matrix.
    """
    import numpy as np

    # pandas removed the ``sheetname`` keyword; ``sheet_name`` with a scalar
    # index returns that sheet's DataFrame directly.
    df = pd.read_excel(filename, sheet_name=1, header=None, skiprows=1)
    df = df.fillna(0)
    df[2] = df[2] / 100000000  # rescale the target by 1e8
    df[2] = standard(df[2])

    df[3] = standard(df[3])
    df[4] = standard(df[4])
    df[6] = standard(df[6])
    df[9] = df[9].apply(map_01)
    df[11] = df[11].apply(map_01)
    df[14] = standard(df[14])
    df[15] = standard(df[15].apply(map_rate))
    df[16] = standard(df[16].apply(map_sub_rate))
    df[17] = standard(df[17])
    prov_coding, _ = transcoding(df[10])  # province
    enter_coding, _ = transcoding(df[13])  # enterprise

    target = df[2]
    data = df[[3, 4, 6, 9, 11, 14, 15, 16, 17]]
    data = pd.concat([data, prov_coding], axis=1)
    data = pd.concat([data, enter_coding], axis=1)

    # seaborn is only needed for the optional visual check below.
    import seaborn as sns
    # sns.pairplot(df, x_vars=[3,4,6,9,11,14,15,16,17], y_vars=2, size=7, aspect=0.8, kind='reg')
    # np.asmatrix is the same function as the ``mat`` alias, which NumPy 2.0 removed.
    return np.asmatrix(data), np.asmatrix(target).T
项目:forward    作者:yajun0601    | 项目源码 | 文件源码
def loadDataSet(filename):
    """Load the regression data set from sheet index 1 of an Excel workbook.

    The sheet is read without a header row, skipping its first row. Every
    row containing a missing value is dropped. Column 2 (the target) is
    divided by 1e8 but not standardized, and the rows are sorted by it.

    The features are prepared as follows:

    * Columns 3, 4, 6, 14 and 17 are passed through ``standard``.
    * Columns 15 and 16 are mapped with ``map_rate`` / ``map_sub_rate``
      before ``standard``.
    * Columns 9 and 11 are mapped with ``map_01``.

    The value counts of column 15 are printed before and after its mapping.
    A seaborn regression pair plot of the features against the target is
    drawn (but neither shown nor saved here).

    Returns:
        tuple: ``(data, target)`` as NumPy matrices; ``target`` is a column
        matrix.
    """
    sheet = 1
    # pandas removed the ``sheetname`` keyword; ``sheet_name`` with a scalar
    # index returns that sheet's DataFrame directly.
    df = pd.read_excel(filename, sheet_name=sheet, header=None, skiprows=1)
    # Drop every row with any missing value. This alone yields the same rows
    # as the former thresh-based pre-filter (which raises TypeError on
    # pandas >= 2.0 when combined with ``how``), and leaves nothing to fill.
    df = df.dropna(how='any')
    df[2] = df[2] / 100000000  # rescale the target by 1e8
    # Sort by the target; drop=True keeps the old index out of the columns.
    df = df.sort_values(2).reset_index(drop=True)

    df[3] = standard(df[3])
    df[4] = standard(df[4])
    # The encodings below are currently unused (see the commented concat).
    rate_type, rate_dict = transcoding(df[5])
    df[6] = standard(df[6])
    platform, platform_dict = transcoding(df[8])
    df[9] = df[9].apply(map_01)
    df[11] = df[11].apply(map_01)
    nature, nature_dict = transcoding(df[12])
    df[14] = standard(df[14])
    print(df.groupby(15).size())
    df[15] = standard(df[15].apply(map_rate))
    print(df.groupby(15).size())
    df[16] = standard(df[16].apply(map_sub_rate))
    df[17] = standard(df[17])

    target = df[2]
    data = df[[3, 4, 6, 9, 11, 14, 15, 16, 17]]
    # data = pd.concat([data, rate_type, platform, nature], axis=1)

    import seaborn as sns
    sns.pairplot(df, x_vars=[3, 17, 4, 14, 15, 16, 6, 9, 11], y_vars=2, size=5, aspect=0.8, kind='reg')
    # np.asmatrix is the same function as np.mat, which NumPy 2.0 removed.
    return np.asmatrix(data), np.asmatrix(target).T
项目:kvae    作者:simonkamronn    | 项目源码 | 文件源码
def plot_auxiliary(all_vars, filename, table_size=4):
    """Plot one or more batches of trajectories and save the figure.

    Args:
        all_vars: list of arrays shaped ``(batch_size, sequence_length, dim)``.
            2-D arrays ``(sequence_length, dim)`` are treated as a batch of
            one. The caller's list is not modified.
        filename: output image path.
        table_size: for ``dim == 2``, the side of the square subplot grid;
            each subplot shows one batch element.

    For ``dim == 2``, each subplot overlays the trajectory of that batch
    element from every array, with its start point marked in red. The figure
    is saved as PNG, and subplots beyond the smallest batch size stay empty.
    For any other ``dim``, all time steps are flattened into a seaborn pair
    plot coloured by array index, which is saved to ``filename``.
    """
    # Build a new list instead of overwriting the caller's entries.
    all_vars = [np.expand_dims(a, 0) if a.ndim == 2 else a for a in all_vars]

    dim = all_vars[0].shape[-1]
    if dim == 2:
        # squeeze=False keeps ``ax`` 2-D even when table_size == 1.
        _, ax = plt.subplots(table_size, table_size, sharex='col', sharey='row',
                             figsize=[12, 12], squeeze=False)
        # Stop at the shortest batch instead of indexing past it (a 2-D input
        # is a batch of one, which used to raise IndexError).
        n_items = min(a.shape[0] for a in all_vars)
        for idx in range(min(n_items, table_size * table_size)):
            x, y = divmod(idx, table_size)  # row-major position in the grid
            for a in all_vars:
                ax[x, y].plot(a[idx, :, 0], a[idx, :, 1], linestyle='-', marker='o', markersize=3)
                # Starting point of the trajectory.
                ax[x, y].plot(a[idx, 0, 0], a[idx, 0, 1], 'r.', ms=12)
        plt.savefig(filename, format='png', bbox_inches='tight', dpi=80)
    else:
        df_list = []
        for i, a in enumerate(all_vars):
            df = pd.DataFrame(a.reshape(-1, dim))
            df['class'] = i
            df_list.append(df)

        df_all = pd.concat(df_list)
        sns_plot = sns.pairplot(df_all, hue="class", vars=list(range(dim)))
        sns_plot.savefig(filename)
    # A single close: a second one would close an unrelated current figure.
    plt.close()
项目:Default-Credit-Card-Prediction    作者:AlexPnt    | 项目源码 | 文件源码
def visualize_hist_pairplot(X, y, selected_feature1, selected_feature2, features, diag_kind):
    """
    Visualize the pairwise relationship of two features, coloured by class.

    Appends ``y`` as an extra column after ``X`` and draws a seaborn pair
    plot of the two selected features, with one colour per class. The plot
    is saved to ``img/<feature1>_<feature2>_hist_pairplot.png`` via
    ``save_fig``.

    Keyword arguments:
    X -- The feature vectors (one row per sample)
    y -- The target vector; expected to hold only the labels 0 and 1, since
         those are the only keys of the palette
    selected_feature1 -- First feature (a name in ``features``)
    selected_feature2 -- Second feature (a name in ``features``)
    features -- Column names for the columns of X followed by y; the target
                column must be named "Y" because it is used as the hue
    diag_kind -- Type of plot on the diagonal, e.g. 'hist' (histogram) or
                 'kde' (density estimate)
    """
    # Feature matrix with the target appended as the last column.
    joint_data = np.column_stack((X, y))
    df = pd.DataFrame(data=joint_data, columns=features)

    palette = sea.hls_palette()
    splot = sea.pairplot(df, hue="Y", palette={0: palette[2], 1: palette[0]},
                         vars=[selected_feature1, selected_feature2], diag_kind=diag_kind)
    splot.fig.suptitle('Pairwise relationship: ' + selected_feature1 + " vs " + selected_feature2)
    splot.set(xticklabels=[])  # hide x tick labels on every subplot

    output_dir = "img"
    save_fig(output_dir, '{}/{}_{}_hist_pairplot.png'.format(output_dir, selected_feature1, selected_feature2))
项目:AlphaPy    作者:ScottFreeLLC    | 项目源码 | 文件源码
def plot_scatter(df, features, target, tag='eda', directory=None):
    r"""Plot a scatterplot matrix, also known as a pair plot.

    Parameters
    ----------
    df : pandas.DataFrame
        The dataframe containing the features.
    features: list of str
        The features to compare in the scatterplot. The list is not
        modified.
    target : str
        The target variable for contrast (used as the hue). It is plotted
        once even if it also appears in ``features``.
    tag : str
        Unique identifier for the plot.
    directory : str, optional
        The full specification of the plot location.

    Returns
    -------
    None : None.

    References
    ----------

    https://seaborn.pydata.org/examples/scatterplot_matrix.html

    """

    logger.info("Generating Scatter Plot")

    # Get the feature subset. Build a new list: appending to ``features``
    # would modify the caller's list and add the target again on every call.

    columns = [f for f in features if f != target] + [target]
    df = df[columns]

    # Generate the pair plot

    sns.set()
    sns_plot = sns.pairplot(df, hue=target)

    # Save the plot
    write_plot('seaborn', sns_plot, 'scatter_plot', tag, directory)


#
# Function plot_facet_grid
#
项目:powerplantmatching    作者:FRESNA    | 项目源码 | 文件源码
def comparison_1dim(by='Country', include_WEPP=True, include_VRE=False,
                    year=2015, how='hbar', figsize=(7,5)):
    """
    Compare capacities per ``by`` group across datasets.

    Parameters
    ----------
    by : string, defines how to group data
        Allowed values: 'Country' or 'Fueltype'
    include_WEPP : bool
        If True, also compare the matched dataset with WEPP and WEPP alone.
    include_VRE : bool
        Passed through to ``gather_comparison_data``.
    year : int
        Passed through to ``gather_comparison_data``.
    how : {'hbar', 'scatter'}
        'hbar': horizontal bar chart with capacity [GW] on the x-axis and
        ``by`` on the y-axis.
        'scatter': seaborn pair plot of the datasets against each other on
        log-log axes limited to 1..200.
    figsize : tuple
        (width, height). For 'scatter', the height is the facet size and
        width/height is the aspect ratio.

    Returns
    -------
    (fig, ax) for 'hbar'; (fig, array of axes) for 'scatter'.

    Raises
    ------
    ValueError
        If ``how`` is neither 'hbar' nor 'scatter' (checked before any data
        is gathered).
    """
    if how not in ('hbar', 'scatter'):
        raise ValueError("how must be 'hbar' or 'scatter', got {!r}".format(how))

    red_w_wepp, red_wo_wepp, wepp, statistics = gather_comparison_data(include_WEPP=include_WEPP,
                                                                       include_VRE=include_VRE,
                                                                       year=year)
    if include_WEPP:
        stats = lookup([red_w_wepp, red_wo_wepp, wepp, statistics],
                       keys=['Matched dataset w/ WEPP', 'Matched dataset w/o WEPP',
                             'WEPP only', 'Statistics OPSD'], by=by)/1000
    else:
        stats = lookup([red_wo_wepp, statistics],
                       keys=['Matched dataset w/o WEPP', 'Statistics OPSD'],
                       by=by)/1000

    if how == 'hbar':
        with sns.axes_style('darkgrid'):
            font = {'size': 24}
            plt.rc('font', **font)  # global rc change, persists after return
            fig, ax = plt.subplots(figsize=figsize)
            stats.plot.barh(ax=ax, stacked=False, colormap='jet')
            ax.set_xlabel('Installed Capacity [GW]')
            ax.yaxis.label.set_visible(False)
            ax.set_axisbelow(True)                       # puts the grid behind the bars
            ax.grid(color='white', linestyle='dotted')   # adds white dotted grid
            ax.legend(loc='best')
            ax.invert_yaxis()
        return fig, ax

    # how == 'scatter'
    stats.loc[:, by] = stats.index.astype(str)  # hue column for seaborn
    # float() keeps the aspect ratio fractional on Python 2 as well
    # (7/5 would otherwise truncate to 1).
    aspect = float(figsize[0]) / figsize[1]
    pair_kws = dict(diag_kind='kde', hue=by, palette='Set2',
                    size=figsize[1], aspect=aspect)
    if len(stats.columns) - 1 < 3:
        # Fewer than three datasets: plot only the first against the second.
        pair_kws.update(x_vars=stats.columns[0], y_vars=stats.columns[1])
    g = sns.pairplot(stats, **pair_kws)
    for ax in g.axes.flat:
        ax.set(xscale='log', yscale='log', xlim=(1, 200), ylim=(1, 200))
    return g.fig, g.axes