Python mpl_toolkits.mplot3d 模块,Axes3D() 实例源码

我们从Python开源项目中,提取了以下47个代码示例,用于说明如何使用mpl_toolkits.mplot3d.Axes3D()

项目:grove    作者:rigetticomputing    | 项目源码 | 文件源码
def test_visualization():
    """Exercise ``ut.state_histogram`` and ``ut.plot_pauli_transfer_matrix``.

    The histogram is drawn once without and once with a caller-supplied 3D
    axes. The transfer-matrix plot is drawn on a ``Mock`` axes while
    ``matplotlib.pyplot.colorbar`` is patched out.
    """
    op_utils = grove.tomography.operator_utils
    ax = Axes3D(figure())

    # Without axis.
    ut.state_histogram(op_utils.GS, title="test")
    # With axis: the title must land on the axes that was handed in.
    ut.state_histogram(op_utils.GS, ax, "test")
    assert ax.get_title() == "test"

    qx_super = qt.to_super(op_utils.QX)
    ptX = op_utils.PAULI_BASIS.transfer_matrix(qx_super).toarray()
    mock_ax = Mock()
    with patch("matplotlib.pyplot.colorbar"):
        ut.plot_pauli_transfer_matrix(ptX, mock_ax, op_utils.PAULI_BASIS.labels, "bla")
    # The plotting helper is expected to have drawn and labelled the axes.
    for recorded in (mock_ax.imshow, mock_ax.set_xlabel, mock_ax.set_ylabel):
        assert recorded.called
项目:BinarySTLfileReader    作者:sukhbinder    | 项目源码 | 文件源码
def ShowCooordsTopology(coords,topo):
    '''Plot the polygons described by a coordinate array and a topology.

    Every entry of ``topo`` is used to index rows of ``coords``; the first
    three columns of the selected rows form one polygon. The axis limits are
    the per-column minimum and maximum of ``coords``.
    '''
    ax = a3d.Axes3D(plt.figure())

    upper = coords.max(axis=0)
    lower = coords.min(axis=0)
    xm, ym, zm = upper
    xmi, ymi, zmi = lower

    # Same face colour for every polygon, so convert it only once.
    face_colour = colors.rgb2hex([0.9,0.6,0.])
    for nodes in topo:
        patch = a3d.art3d.Poly3DCollection([coords[nodes,:3]])
        patch.set_color(face_colour)
        patch.set_edgecolor('k')
        ax.add_collection3d(patch)

    for set_limits, low, high in ((ax.set_xlim3d, xmi, xm),
                                  (ax.set_ylim3d, ymi, ym),
                                  (ax.set_zlim3d, zmi, zm)):
        set_limits([low, high])

    plt.show()
项目:BinarySTLfileReader    作者:sukhbinder    | 项目源码 | 文件源码
def ShowSTLFile(v1,v2,v3):
    '''Plot an STL surface given its triangle vertices.

    ``v1``, ``v2`` and ``v3`` hold the first, second and third vertex of each
    triangle, one triangle per row. The axis limits span all three vertex
    arrays; the original took the maxima from ``v1`` only and the minima from
    ``v2`` only, which could leave parts of the surface outside the axes.
    '''
    ax = a3d.Axes3D(plt.figure())

    # Bounding box over every vertex of every triangle.
    all_vertices = np.vstack((v1, v2, v3))
    xm, ym, zm = all_vertices.max(axis=0)
    xmi, ymi, zmi = all_vertices.min(axis=0)

    face_colour = colors.rgb2hex([0.9,0.6,0.])
    for a, b, c in zip(v1, v2, v3):
        tri = a3d.art3d.Poly3DCollection([np.vstack((a, b, c))])
        tri.set_color(face_colour)
        tri.set_edgecolor('k')
        ax.add_collection3d(tri)

    ax.set_xlim3d([xmi,xm])
    ax.set_ylim3d([ymi,ym])
    ax.set_zlim3d([zmi,zm])

    plt.show()
项目:data_utilities    作者:fmv1992    | 项目源码 | 文件源码
def test_plot_3d(self):
        """Check the axes of every figure in ``self.figures_3d``.

        The first axes of each figure is handed, paired with ``Axes3D``, to
        ``self.assert_X_from_iterables`` together with
        ``self.assertIsInstance``; afterwards every figure must hold exactly
        one axes, otherwise ``ValueError`` is raised.
        """
        # TODO: error prone since colorbars can be added.
        first_axes = (fig.get_axes()[0] for fig in self.figures_3d)
        self.assert_X_from_iterables(
            self.assertIsInstance, first_axes, itertools.repeat(Axes3D))
        # Test that there is just one axes per figure.
        for fig in self.figures_3d:
            n_axes = len(fig.get_axes())
            if n_axes == 1:
                continue
            # TODO: colorbar may add a second axes.
            raise ValueError(
               "Axes has the wrong number of elements: {0} but "
               "should be 1.".format(n_axes))
项目:robotics1project    作者:pchorak    | 项目源码 | 文件源码
def display(self,angles):
        """
        Draw wireframe triangles of the Dobot mesh and of ``self.obstacles``.

        The mesh for ``angles`` comes from ``DobotModel.get_mesh``. Everything
        is drawn in blue on a 3D axes created in the current figure, which is
        shown and then returned.
        """
        arm = DobotModel.get_mesh(angles)

        fig = plt.gcf()
        ax = Axes3D(fig)

        # Row order 0-1-2-0 closes each triangle outline.
        closed = [0,1,2,0]

        def draw_outline(tri):
            ax.plot(tri[closed,0],tri[closed,1],tri[closed,2],'b')

        for arm_tri in arm:
            draw_outline(arm_tri)
        for obstacle_tri in self.obstacles:
            draw_outline(obstacle_tri)

        r_max = DobotModel.l1 + DobotModel.l2 + DobotModel.d

        plt.xlim([-np.ceil(r_max/np.sqrt(2)),r_max])
        plt.ylim([-r_max,r_max])
        ax.set_zlim(-150, 250)
        ax.view_init(elev=30.0, azim=60.0)
        plt.show()
        return fig
项目:ReinforcementLearning    作者:persistforever    | 项目源码 | 文件源码
def _plot_value_function(self, value_functions, n_iter):
        """Show the value function as a surface over dealer card and player sum.

        Each entry of ``self.states`` is a ``'<dealer>#<player>'`` string. A
        dealer token of ``'A'`` maps to column 0, any other token to its
        integer value minus one. Only player sums from 12 to 21 are plotted.
        """
        value_matrix = numpy.zeros((10, 10), dtype='float')
        for stateid, state in enumerate(self.states):
            dealer_token, player_token = state.split('#')
            if dealer_token == 'A':
                dealer_col = 0
            else:
                dealer_col = int(dealer_token)-1
            player_sum = int(player_token)
            if 12 <= player_sum < 22:
                value_matrix[player_sum-12, dealer_col] = value_functions[stateid]
        dealer_grid, player_grid = numpy.meshgrid(range(10), range(12,22))
        ax = Axes3D(plt.figure())
        ax.plot_surface(dealer_grid, player_grid, value_matrix,
                        rstride=1, cstride=1, cmap='coolwarm')
        ax.set_title('value function in iteration %i' % n_iter)
        ax.set_xlabel('dealer showing')
        ax.set_ylabel('player sum')
        ax.set_zlabel('value function')
        plt.show()
项目:ReinforcementLearning    作者:persistforever    | 项目源码 | 文件源码
def _plot_value_function(self, value_functions, n_iter):
        """Show the value function as a surface over dealer card and player sum.

        Each entry of ``self.states`` is a ``'<dealer>#<player>'`` string. A
        dealer token of ``'A'`` maps to column 0, any other token to its
        integer value minus one. Only player sums from 12 to 21 are plotted.
        """
        value_matrix = numpy.zeros((10, 10), dtype='float')
        for stateid, state in enumerate(self.states):
            dealer_token, player_token = state.split('#')
            if dealer_token == 'A':
                dealer_col = 0
            else:
                dealer_col = int(dealer_token)-1
            player_sum = int(player_token)
            if 12 <= player_sum < 22:
                value_matrix[player_sum-12, dealer_col] = value_functions[stateid]
        dealer_grid, player_grid = numpy.meshgrid(range(10), range(12,22))
        ax = Axes3D(plt.figure())
        ax.plot_surface(dealer_grid, player_grid, value_matrix,
                        rstride=1, cstride=1, cmap='coolwarm')
        ax.set_title('value function in iteration %i' % n_iter)
        ax.set_xlabel('dealer showing')
        ax.set_ylabel('player sum')
        ax.set_zlabel('value function')
        plt.show()
项目:und_Sophie_2016    作者:SophieTh    | 项目源码 | 文件源码
def plot_ring(self, Nb_pts=201):
        """Plot ``self.intensity`` over ``self.X`` / ``self.Y`` ring by ring.

        The first sample is drawn as a single marker labelled ring 0. The
        remaining samples are drawn as consecutive lines of ``Nb_pts`` points
        each, for as long as ``ring_number * Nb_pts < len(self.X)``.

        :param Nb_pts: number of samples per ring (default 201)
        :raises Exception: if ``self.X`` or ``self.Y`` is ``None``
        """
        import matplotlib.pyplot as plt
        # Imported for its side effect: older matplotlib releases register
        # the '3d' projection only once this module has been imported.
        from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
        if self.X is None or self.Y is None:
            raise Exception(" X and Y must be grid or a list for plotting")
        fig = plt.figure()
        # fig.gca(projection=...) no longer accepts keyword arguments in
        # current matplotlib; add_subplot yields the same axes on a new figure.
        ax = fig.add_subplot(111, projection='3d')
        X = np.array([self.X[0]])
        Y = np.array([self.Y[0]])
        intensity = np.array([self.intensity[0]])
        ax.plot(X, Y, intensity, '^', label='ring number 0')
        ring_number = 1
        while ring_number * Nb_pts < len(self.X):
            # Ring k occupies samples (k-1)*Nb_pts+1 .. k*Nb_pts (inclusive).
            ring = slice((ring_number - 1) * Nb_pts + 1, ring_number * Nb_pts + 1)
            ax.plot(self.X[ring], self.Y[ring], self.intensity[ring],
                    label='ring number %d' % ring_number)
            ring_number += 1
        ax.set_xlabel("X")
        ax.set_ylabel('Y')
        # Axis label typo ("itensity") fixed.
        ax.set_zlabel("intensity")
        ax.legend()
        plt.show()
项目:AIclass    作者:mttk    | 项目源码 | 文件源码
def plot_3d(X, y_actual, y_predicted=None):
    """Scatter actual (green) and, optionally, predicted (blue) values in 3D.

    The first two columns of ``X`` are the horizontal coordinates.

    Two defects of the original are fixed here:

    * the legend was built only when ``y_predicted`` was ``None``, the one
      case in which ``scatter_predicted`` does not exist (NameError);
    * the two titles were attached to the wrong cases.

    :param X: 2-D array; columns 0 and 1 are plotted on the x and y axes
    :param y_actual: values plotted on the z axis for the actual samples
    :param y_predicted: optional predicted values for the same points
    """
    has_prediction = y_predicted is not None
    fig = plt.figure()

    if has_prediction:
        plt.title("Predicted vs actual function values")
    else:
        plt.title("Approximated function samples")

    ax = Axes3D(fig)

    ax.view_init(elev=30, azim=70)

    scatter_actual = ax.scatter(X[:,0], X[:,1], y_actual, c='g', depthshade=False)
    if has_prediction:
        scatter_predicted = ax.scatter(X[:,0], X[:,1], y_predicted, c='b', depthshade=False)
        # A legend only makes sense when both point sets are present.
        plt.legend((scatter_actual, scatter_predicted),
                ('Actual values', 'Predicted values'),
                scatterpoints = 1)

    plt.grid()
    plt.show()
项目:AIclass    作者:mttk    | 项目源码 | 文件源码
def plot_surface_3d(X, y_actual, NN):
    """Draw the surface predicted by ``NN`` together with the training samples.

    The sorted first and second columns of ``X`` span the grid, and
    ``NN.output`` is evaluated once per grid node with a two-element array.
    """
    fig = plt.figure()
    plt.title("Predicted function with marked training samples")
    ax = Axes3D(fig)

    n_samples = X.shape[0]

    ax.view_init(elev=30, azim=70)
    ax.scatter(X[:,0], X[:,1], y_actual, c='g', depthshade=False)

    grid_x0, grid_x1 = np.meshgrid(sorted(X[:,0]), sorted(X[:,1]))
    predicted_surface = np.zeros((n_samples, n_samples))

    # Row-major walk over the grid, one network evaluation per node.
    for i, j in np.ndindex(n_samples, n_samples):
        node = np.array([grid_x0[i,j], grid_x1[i,j]])
        predicted_surface[i,j] = NN.output(node)

    ax.plot_surface(grid_x0, grid_x1, predicted_surface, rstride=2, cstride=2,
                    linewidth=0, cmap=cm.coolwarm, alpha=0.5)

    plt.grid()
    plt.show()
项目:AIclass    作者:mttk    | 项目源码 | 文件源码
def plot_3d(X, y_actual, y_predicted=None):
    """Scatter actual (green) and, optionally, predicted (blue) values in 3D.

    The first two columns of ``X`` are the horizontal coordinates.

    Two defects of the original are fixed here:

    * the legend was built only when ``y_predicted`` was ``None``, the one
      case in which ``scatter_predicted`` does not exist (NameError);
    * the two titles were attached to the wrong cases.

    :param X: 2-D array; columns 0 and 1 are plotted on the x and y axes
    :param y_actual: values plotted on the z axis for the actual samples
    :param y_predicted: optional predicted values for the same points
    """
    has_prediction = y_predicted is not None
    fig = plt.figure()

    if has_prediction:
        plt.title("Predicted vs actual function values")
    else:
        plt.title("Approximated function samples")

    ax = Axes3D(fig)

    ax.view_init(elev=30, azim=70)

    scatter_actual = ax.scatter(X[:,0], X[:,1], y_actual, c='g', depthshade=False)
    if has_prediction:
        scatter_predicted = ax.scatter(X[:,0], X[:,1], y_predicted, c='b', depthshade=False)
        # A legend only makes sense when both point sets are present.
        plt.legend((scatter_actual, scatter_predicted),
                ('Actual values', 'Predicted values'),
                scatterpoints = 1)

    plt.grid()
    plt.show()
项目:AIclass    作者:mttk    | 项目源码 | 文件源码
def plot_surface_3d(X, y_actual, NN):
    """Draw the surface predicted by ``NN`` together with the training samples.

    The sorted first and second columns of ``X`` span the grid, and
    ``NN.output`` is evaluated once per grid node with a two-element array.
    """
    fig = plt.figure()
    plt.title("Predicted function with marked training samples")
    ax = Axes3D(fig)

    n_samples = X.shape[0]

    ax.view_init(elev=30, azim=70)
    ax.scatter(X[:,0], X[:,1], y_actual, c='g', depthshade=False)

    grid_x0, grid_x1 = np.meshgrid(sorted(X[:,0]), sorted(X[:,1]))
    predicted_surface = np.zeros((n_samples, n_samples))

    # Row-major walk over the grid, one network evaluation per node.
    for i, j in np.ndindex(n_samples, n_samples):
        node = np.array([grid_x0[i,j], grid_x1[i,j]])
        predicted_surface[i,j] = NN.output(node)

    ax.plot_surface(grid_x0, grid_x1, predicted_surface, rstride=2, cstride=2,
                    linewidth=0, cmap=cm.coolwarm, alpha=0.5)

    plt.grid()
    plt.show()
项目:ML-note    作者:JasonK93    | 项目源码 | 文件源码
def plot_LDA(converted_X,y):
    '''
    Scatter the first three columns of the transformed data, one style per label.

    :param converted_X: train data after transfer
    :param y: train_value; rows are selected where it equals 0, 1 or 2
    :return:  None
    '''
    from mpl_toolkits.mplot3d import Axes3D
    fig=plt.figure()
    ax=Axes3D(fig)
    # (label value, colour, marker) for each of the three plotted labels.
    styles=((0,'r','o'),(1,'g','*'),(2,'b','s'))
    for target,color,marker in styles:
        mask=(y==target).ravel()
        subset=converted_X[mask,:]
        ax.scatter(subset[:,0], subset[:,1], subset[:,2],color=color,marker=marker,
            label="Label {0}".format(target))
    ax.legend(loc="best")
    fig.suptitle("Iris After LDA")
    plt.show()
项目:EchoBurst    作者:TyJK    | 项目源码 | 文件源码
def plotModel3D(vectorFile, numClusters):
    """Cluster the document vectors of a saved Doc2Vec model and plot them in 3D.

    The vectors are reduced to 10 principal components and clustered with
    k-means. Components 5, 2 and 3 are then scattered, coloured by cluster
    label.

    Changes from the original:

    * ``np.float`` (removed from NumPy) is replaced by the builtin ``float``;
    * the ``w_xaxis`` / ``w_yaxis`` / ``w_zaxis`` aliases (removed from
      Matplotlib) are replaced by ``xaxis`` / ``yaxis`` / ``zaxis``;
    * the first drawing pass is dropped, because it was wiped by
      ``plt.clf()`` before anything was shown.

    :param vectorFile: model file name, appended to ``"Models\\"``
    :param numClusters: number of k-means clusters
    """
    # http://scikit-learn.org/stable/auto_examples/cluster/plot_cluster_iris.html

    model = Doc2Vec.load("Models\\" + vectorFile)
    docVecs = model.docvecs.doctag_syn0
    reduced_data = PCA(n_components=10).fit_transform(docVecs)

    kmeans = KMeans(init='k-means++', n_clusters=numClusters, n_init=10)
    kmeans.fit(reduced_data)
    labels = kmeans.labels_

    fig = plt.figure(1, figsize=(10, 10))
    plt.clf()
    ax = Axes3D(fig, rect=[0, 0, .95, 1], elev=48, azim=134)
    plt.cla()
    ax.scatter(reduced_data[:, 5], reduced_data[:, 2], reduced_data[:, 3],
               c=labels.astype(float))
    # Hide the tick labels on all three axes.
    for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        axis.set_ticklabels([])
    plt.show()
项目:Parallel-SGD    作者:angadgill    | 项目源码 | 文件源码
def plot_figs(fig_num, elev, azim, X_train, clf):
    """Plot training points and the plane predicted by ``clf`` from one view.

    ``y_train`` is not a parameter; it is read from the enclosing scope. The
    plane is drawn through the predictions at the four corners of the square
    [-0.1, 0.15] x [-0.1, 0.15].

    The ``w_xaxis`` / ``w_yaxis`` / ``w_zaxis`` aliases, which Matplotlib has
    removed, are replaced by ``xaxis`` / ``yaxis`` / ``zaxis``.

    :param fig_num: matplotlib figure number to draw into (cleared first)
    :param elev: elevation of the 3D view
    :param azim: azimuth of the 3D view
    :param X_train: 2-D array; columns 0 and 1 are plotted
    :param clf: fitted estimator providing ``predict``
    """
    fig = plt.figure(fig_num, figsize=(4, 3))
    plt.clf()
    ax = Axes3D(fig, elev=elev, azim=azim)

    ax.scatter(X_train[:, 0], X_train[:, 1], y_train, c='k', marker='+')

    corners = np.array([[-.1, -.1, .15, .15],
                        [-.1, .15, -.1, .15]]).T
    ax.plot_surface(np.array([[-.1, -.1], [.15, .15]]),
                    np.array([[-.1, .15], [-.1, .15]]),
                    clf.predict(corners).reshape((2, 2)),
                    alpha=.5)
    ax.set_xlabel('X_1')
    ax.set_ylabel('X_2')
    ax.set_zlabel('Y')
    # Hide the tick labels on all three axes.
    for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        axis.set_ticklabels([])

#Generate the three different figures from different views
项目:CuboctSTL    作者:figlax    | 项目源码 | 文件源码
def preview_mesh(*args):
    """
    Plot the numpy-stl mesh objects passed in ``args``. The preview is scaled
    to the last mesh object entered.

    Changes from the original: ``flatten(-1)`` is replaced by ``flatten()``
    because current NumPy rejects an integer ``order`` argument, and calling
    the function without meshes no longer fails with a NameError.

    :param args: mesh objects to plot  ex.- preview_mesh(mesh1, mesh2, mesh3)
    :return: None
    """
    print ("...preparing preview...")
    # Create a new plot
    figure = pyplot.figure()
    axes = mplot3d.Axes3D(figure)
    for mesh_obj in args:
        axes.add_collection3d(mplot3d.art3d.Poly3DCollection(mesh_obj.vectors))

    # Auto scale to the mesh size. Note it will choose the last mesh.
    # Nothing to scale to when no mesh was given.
    if args:
        scale = args[-1].points.flatten()
        axes.auto_scale_xyz(scale, scale, scale)

    # Show the plot to the screen
    pyplot.show()
项目:CuboctSTL    作者:figlax    | 项目源码 | 文件源码
def preview_mesh(*args):
    """
    Plot the numpy-stl mesh objects passed in ``args``. The preview is scaled
    to the last mesh object entered.

    Changes from the original: ``flatten(-1)`` is replaced by ``flatten()``
    because current NumPy rejects an integer ``order`` argument, and calling
    the function without meshes no longer fails with a NameError.

    :param args: mesh objects to plot  ex.- preview_mesh(mesh1, mesh2, mesh3)
    :return: None
    """
    print ("...preparing preview...")
    # Create a new plot
    figure = pyplot.figure()
    axes = mplot3d.Axes3D(figure)
    for mesh_obj in args:
        axes.add_collection3d(mplot3d.art3d.Poly3DCollection(mesh_obj.vectors))

    # Auto scale to the mesh size. Note it will choose the last mesh.
    # Nothing to scale to when no mesh was given.
    if args:
        scale = args[-1].points.flatten()
        axes.auto_scale_xyz(scale, scale, scale)

    # Show the plot to the screen
    pyplot.show()
项目:Default-Credit-Card-Prediction    作者:AlexPnt    | 项目源码 | 文件源码
def visualize_pca3D(X,y):
    """
    Scatter the first three principal components, one colour per class.

    Rows where ``y == 0`` are labelled "Paid" and rows where ``y == 1`` are
    labelled "Default".

    Keyword arguments:
    X -- The feature vectors
    y -- The target vector
    """
    principal_components = PCA(n_components = 3).fit_transform(X)

    ax = Axes3D(pylab.figure())

    palette = sea.color_palette()
    # (class value, legend label, palette colour)
    classes = ((0, "Paid", palette[1]), (1, "Default", palette[2]))
    for value, name, colour in classes:
        rows = y==value
        ax.scatter(principal_components[rows, 0],
                   principal_components[rows, 1],
                   principal_components[rows, 2],
                   label=name, alpha=0.5, edgecolor='#262626',
                   c=colour, linewidth=0.15)

    ax.legend()
    plt.show()
项目:Python_Study    作者:thsheep    | 项目源码 | 文件源码
def threeD():
    """Show the surface z = sin(sqrt(x**2 + y**2)) on a grid over [-4, 4)."""
    ticks = np.arange(-4, 4, 0.25)
    X, Y = np.meshgrid(ticks, ticks)
    radius = np.sqrt(X ** 2 + Y ** 2)
    Z = np.sin(radius)

    ax = Axes3D(figure())
    ax.plot_surface(X, Y, Z, rstride=1, cstride=1, cmap='hot')
    show()
项目:phoebe2    作者:phoebe-project    | 项目源码 | 文件源码
def apply_limits(ax, pad=0.1):
    """
    Apply the limits stored on the axes, widened by a padding.

    ``ax._phoebe_xlim``, ``ax._phoebe_ylim`` and ``ax._phoebe_zlim`` are all
    read, even for a non-3D axes, where the z limits are then unused.

    :parameter ax: axes carrying the three ``_phoebe_*lim`` attributes
    :parameter float pad: ratio of the range to apply as a padding (default: 0.1)
    :return: the ``ax`` that was passed in
    """

    def _padded(limits):
        # Start from a slice of the stored limits and overwrite both ends,
        # each offset by ``pad`` times the stored range.
        widened = limits[:]
        widened[0] = limits[0] - pad*(limits[1]-limits[0])
        widened[1] = limits[1] + pad*(limits[1]-limits[0])
        return widened

    xlim = ax._phoebe_xlim
    ylim = ax._phoebe_ylim
    zlim = ax._phoebe_zlim

    xlim_pad = _padded(xlim)
    ylim_pad = _padded(ylim)
    zlim_pad = _padded(zlim)

    if not isinstance(ax, Axes3D):
        ax.set_xlim(xlim_pad)
        ax.set_ylim(ylim_pad)
        return ax

    ax.set_xlim3d(xlim_pad)
    ax.set_ylim3d(ylim_pad)
    ax.set_zlim3d(zlim_pad)
    return ax
项目:udacity-detecting-vehicles    作者:wonjunee    | 项目源码 | 文件源码
def plot3d(pixels, colors_rgb,
        axis_labels=list("RGB"), axis_limits=[(0, 255), (0, 255), (0, 255)]):
    """Scatter the three channels of ``pixels`` against each other in 3D.

    ``pixels`` is indexed as ``pixels[:, :, k]`` for k = 0, 1, 2.
    ``colors_rgb`` is reshaped to rows of three values and used as the point
    colours. Returns the ``Axes3D`` object for further manipulation.
    """
    ax = Axes3D(plt.figure(figsize=(8, 8)))

    # Axis ranges
    ax.set_xlim(*axis_limits[0])
    ax.set_ylim(*axis_limits[1])
    ax.set_zlim(*axis_limits[2])

    # Tick and label appearance
    ax.tick_params(axis='both', which='major', labelsize=14, pad=8)
    label_style = dict(fontsize=16, labelpad=16)
    ax.set_xlabel(axis_labels[0], **label_style)
    ax.set_ylabel(axis_labels[1], **label_style)
    ax.set_zlabel(axis_labels[2], **label_style)

    # One point per pixel, coloured from colors_rgb
    first = pixels[:, :, 0].ravel()
    second = pixels[:, :, 1].ravel()
    third = pixels[:, :, 2].ravel()
    point_colours = colors_rgb.reshape((-1, 3))
    ax.scatter(first, second, third, c=point_colours, edgecolors='none')

    return ax


# Read a color image
项目:Clustering    作者:Ram81    | 项目源码 | 文件源码
def plot3D(data, output_labels_3d, centroids):
    '''
        Scatter every data point in 3D, styled by its cluster label, and mark
        the centroids with large red crosses.

        Labels 0 to 11 have a marker size and colour assigned; points with
        any other label are not drawn.
    '''
    # (label value, marker size, colour), tried in this order.
    label_styles = (
        (0, 20, 'k'), (1, 20, 'r'), (2, 20, 'b'), (3, 20, 'c'),
        (4, 20, 'g'), (5, 20, 'y'), (6, 20, 'm'), (7, 25, 'y'),
        (8, 25, 'b'), (9, 25, 'k'), (10, 25, 'm'), (11, 25, 'g'),
    )

    fig = plt.figure(3)
    ax = Axes3D(fig)

    for i, label in enumerate(output_labels_3d):
        for value, size, colour in label_styles:
            if label == value:
                ax.scatter(data[i, 0], data[i, 1], data[i, 2], s = size, c = colour)
                break

    ax.scatter(centroids[:, 0], centroids[:, 1], centroids[:, 2], s = 150, c = 'r', marker = 'x', linewidth = 5)

    plt.show()

    return
项目:ReinforcementLearning    作者:persistforever    | 项目源码 | 文件源码
def _plot_value_function(self, value_function, n_iter):
        """Save the value function as a surface over the two car counts.

        Each entry of ``self.states`` is a ``'<a>#<b>'`` string of two
        integers used as row and column of the value matrix. The figure is
        written to ``experiments/value<n_iter>`` and not shown.
        """
        side = self.max_car+1
        value_matrix = numpy.zeros((side, side), dtype='float')
        for stateid, state in enumerate(self.states):
            counts = [int(t) for t in state.split('#')]
            value_matrix[counts[0], counts[1]] = value_function[stateid]
        # 'ij' indexing: cars_a varies along rows, cars_b along columns.
        cars_a, cars_b = numpy.meshgrid(range(side), range(side), indexing='ij')
        fig = plt.figure()
        ax = Axes3D(fig)
        ax.plot_surface(cars_a, cars_b, value_matrix, rstride=1, cstride=1, cmap='coolwarm')
        ax.set_title('value function in iteration %i' % n_iter)
        ax.set_xlabel('#cars at A')
        ax.set_ylabel('#cars at B')
        ax.set_zlabel('value function')
        fig.savefig('experiments/value%i' % n_iter)
项目:tf-3dgan    作者:meetshah1995    | 项目源码 | 文件源码
def plotFromVF(vertices, faces):
    """Build a numpy-stl mesh from vertices and faces and show it in 3D.

    ``flatten(-1)`` is replaced by ``flatten()`` because current NumPy
    rejects an integer ``order`` argument.

    :param vertices: 2-D array of vertex coordinates, one vertex per row
    :param faces: array whose rows hold the three vertex indices of a face
    """
    input_vec = mesh.Mesh(np.zeros(faces.shape[0], dtype=mesh.Mesh.dtype))
    # Copy the three corner coordinates of every face into the mesh.
    for i, f in enumerate(faces):
        for j in range(3):
            input_vec.vectors[i][j] = vertices[f[j],:]
    figure = plt.figure()
    axes = mplot3d.Axes3D(figure)
    axes.add_collection3d(mplot3d.art3d.Poly3DCollection(input_vec.vectors))
    # Use the same value range on all three axes.
    scale = input_vec.points.flatten()
    axes.auto_scale_xyz(scale, scale, scale)
    plt.show()
项目:tf-3dgan    作者:meetshah1995    | 项目源码 | 文件源码
def plotFromVertices(vertices):
    """Scatter the first three columns of ``vertices`` in a 3D axes and show it."""
    axes = mplot3d.Axes3D(plt.figure())
    columns = vertices.T
    axes.scatter(columns[0,:], columns[1,:], columns[2,:])
    plt.show()
项目:huaat_ml_dl    作者:ieee820    | 项目源码 | 文件源码
def draw3D(X, Y, Z, angle):
    """Draw the surface ``Z`` over ``X`` / ``Y`` from the given viewing angle.

    The original ended with the bare expression ``plt.imshow``, which only
    looks up the function and does nothing. ``plt.show()`` is assumed to be
    what was meant.
    NOTE(review): confirm no caller relies on the figure staying unshown.

    :param angle: pair passed to ``view_init`` as (elevation, azimuth)
    """
    fig = plt.figure(figsize=(15,7))
    ax = Axes3D(fig)
    ax.view_init(angle[0], angle[1])
    ax.plot_surface(X,Y,Z,rstride=1, cstride=1, cmap='rainbow')
    plt.show()
项目:und_Sophie_2016    作者:SophieTh    | 项目源码 | 文件源码
def plot(self,title="",label=""):
        """Plot ``self.intensity`` over ``self.X`` and ``self.Y`` in 3D.

        A surface is drawn when ``self.X`` is two-dimensional, otherwise a
        line with a legend. The axis labels use radians when
        ``self.distance`` is ``None`` and metres otherwise.

        Changes from the original: ``self.distance`` is compared with
        ``is None`` instead of ``== None``, and ``fig.gca(projection='3d')``,
        which current matplotlib rejects, becomes ``fig.add_subplot``.

        :param title: figure title
        :param label: legend label of the line (non-grid case only)
        :raises Exception: if X or Y is ``None`` or their shapes differ
        """
        import matplotlib.pyplot as plt
        from mpl_toolkits.mplot3d import Axes3D

        if self.distance is None:
            zlabel = "Flux (phot/s/0.1%bw/mrad2)"
            xlabel = 'X [rad]'
            ylabel = 'Y [rad]'
        else:
            zlabel = "Flux (phot/s/0.1%bw/mm2)"
            xlabel = 'X [m]'
            ylabel = 'Y [m]'

        if self.X is None or self.Y is None:
            raise Exception(" X and Y must be array for plotting")
        if self.X.shape != self.Y.shape:
            raise Exception(" X and Y must have the same shape")
        fig = plt.figure()
        if len(self.X.shape) == 2:
            ax = Axes3D(fig)
            ax.plot_surface(self.X, self.Y, self.intensity, rstride=1, cstride=1,cmap='hot_r')
        else:
            # On a fresh figure this creates the same axes gca() used to.
            ax = fig.add_subplot(111, projection='3d')
            ax.plot(self.X, self.Y, self.intensity, label=label)
            ax.legend()
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.set_zlabel(zlabel)

        plt.title(title)
        plt.show()
项目:hyper-engine    作者:maxim5    | 项目源码 | 文件源码
def plot_2d(self, f, a, b, grid_size=200):
    """Plot ``f``, the utility mean and the mean +/- std over a 2-D box.

    The box runs from ``a`` to ``b`` and is sampled on a ``grid_size`` by
    ``grid_size`` grid. ``f`` is called with an array holding one row per
    coordinate. ``self.utility.mean_and_std`` is called with one row per
    grid point. The points in ``self.points`` are added as red markers.
    """
    axis_x = np.linspace(a[0], b[0], num=grid_size).reshape((-1, 1))
    axis_y = np.linspace(a[1], b[1], num=grid_size).reshape((-1, 1))
    x, y = np.meshgrid(axis_x, axis_y)
    grid_shape = x.shape

    coords = np.stack([x.flatten(), y.flatten()])
    z = f(coords).reshape(grid_shape)

    mu, sigma = self.utility.mean_and_std(np.swapaxes(coords, 0, 1))
    mu = mu.reshape(grid_shape)
    sigma = sigma.reshape(grid_shape)

    sampled = np.asarray(self.points)
    sampled_z = f(np.swapaxes(sampled, 0, 1))

    ax = Axes3D(plt.figure())
    ax.plot_surface(x, y, z, color='black', label='f', alpha=0.7,
                    linewidth=0, antialiased=False)
    ax.plot_surface(x, y, mu, color='red', label='mu', alpha=0.5)
    ax.plot_surface(x, y, mu + sigma, color='blue', label='mu+sigma', alpha=0.3)
    ax.plot_surface(x, y, mu - sigma, color='blue', alpha=0.3)
    ax.scatter(sampled[:, 0], sampled[:, 1], sampled_z, color='red', marker='o', s=100)
    plt.show()
项目:computationalphysics_N2013301020050    作者:ShixingWang    | 项目源码 | 文件源码
def plot(self):
        """Draw ``self.V`` as a surface over an ``self.L`` by ``self.L`` grid on [-1, 1]."""
        ticks_x=np.linspace(-1,1,self.L)
        ticks_y=np.linspace(-1,1,self.L)
        X,Y=np.meshgrid(ticks_x,ticks_y)
        ax=Axes3D(plt.figure())
        ax.plot_surface(X, Y, self.V, rstride=5, cstride=5, cmap='hot')
项目:pySA    作者:kjzhang9    | 项目源码 | 文件源码
def run_draw():
    """Run the annealing search for a maximum of ``func2`` and plot the outcome.

    The surface of ``func2`` over ``xyRange`` is drawn together with the point
    taken from the solver output. The figure is saved to ``SA_min0.png`` and
    then shown. The elapsed solver time is printed.
    """
    xyRange = [[-2, 2], [-2, 2]]
    targ = SimAnneal(target_text='max')
    # Start value for the maximum case (use sys.maxsize for a minimum).
    init = -sys.maxsize

    t_start = time()
    calculate = OptSolution(Markov_chain=1000, result=init, val_nd=[0,0])
    output = calculate.soulution(SA_newV=targ.newVar, SA_preV=targ.preVar, SA_juge=targ.juge,
                                juge_text='max',ValueRange=xyRange, func=func2)
    t_end = time()
    print('Running %.4f seconds' %(t_end-t_start))

    # Surface of the objective function
    (x_low, x_high), (y_low, y_high) = xyRange
    xv, yv = np.meshgrid(np.linspace(x_low, x_high, 200),
                         np.linspace(y_low, y_high, 200))
    zv = func2([xv, yv])

    ax = Axes3D(plt.figure())
    ax.plot_surface(xv, yv, zv, rstride=1, cstride=1, cmap='GnBu', alpha=1)
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')

    # Mark the point reported by the solver.
    best_xy, best_z = output[0], output[1]
    ax.scatter(best_xy[0], best_xy[1], best_z, c='r', marker='o')

    plt.savefig('SA_min0.png')
    plt.show()
项目:gpucb    作者:tushuhei    | 项目源码 | 文件源码
def plot(self):
    """Save a 3D plot of the estimate, the environment and the samples.

    ``self.mu`` (green wireframe) and ``self.environment.sample`` (blue
    wireframe) are drawn over ``self.meshgrid``. The points in ``self.X`` are
    drawn with their values ``self.T`` as red markers. The file name carries
    the number of points in ``self.X``.
    """
    grid_a, grid_b = self.meshgrid[0], self.meshgrid[1]
    ax = Axes3D(plt.figure())
    ax.plot_wireframe(grid_a, grid_b, self.mu.reshape(grid_a.shape),
        alpha=0.5, color='g')
    ax.plot_wireframe(grid_a, grid_b, self.environment.sample(self.meshgrid),
        alpha=0.5, color='b')
    first_coords = [point[0] for point in self.X]
    second_coords = [point[1] for point in self.X]
    ax.scatter(first_coords, second_coords, self.T, c='r', marker='o', alpha=1.0)
    plt.savefig('fig_%02d.png' % len(self.X))
项目:SoftSAR    作者:eduardosufan    | 项目源码 | 文件源码
def plot_trajectory_3D(traj):
    """
    Plot airplane trajectory and velocity in 3D, each in its own figure.

    ``figv.gca(projection='3d')``, which current matplotlib rejects, is
    replaced by ``figv.add_subplot(111, projection='3d')``. On a new figure
    both create the same axes.

    Parameters
    ----------
    traj: object (AirplaneTrajectory instance).
      Provides ``flight_x/y/z`` and ``flight_vx/vy/vz``.

    Returns
    -------
    None.
    """

    fig = plt.figure()
    fp = Axes3D(fig)

    fp.plot(traj.flight_x, traj.flight_y, traj.flight_z, label='Airplane trajectory')
    fp.legend()
    fp.set_xlabel('x position')
    fp.set_ylabel('y position')
    fp.set_zlabel('z position')

    figv = plt.figure()
    fv = figv.add_subplot(111, projection='3d')
    fv.plot(traj.flight_vx, traj.flight_vy, traj.flight_vz, label='Airplane velocity')
    fv.set_xscale('linear')
    fv.set_yscale('linear')
    fv.set_zscale('linear')

    fv.legend()
    fv.set_xlabel('x velocity')
    fv.set_ylabel('y velocity')
    fv.set_zlabel('z velocity')
    plt.show()
项目:msd-genrecl    作者:Phonicavi    | 项目源码 | 文件源码
def PCA_n_plot():
    """Project all features onto three principal components and scatter them.

    The features come from ``get_all_feas()``. The Python 2 ``print``
    statements are written as calls, which print the same text on Python 2
    and are valid syntax on Python 3.
    """
    all_feas = get_all_feas()

    print('Start PCA ...')
    pca = PCA(n_components = 3,whiten = False)
    new_feas = pca.fit_transform(np.array(all_feas))
    Fig = plt.figure()
    ax = Axes3D(Fig)
    print('Start Ploting')
    ax.scatter(new_feas[:,0],new_feas[:,1],new_feas[:,2])
    plt.show()
项目:bnpy    作者:bnpy    | 项目源码 | 文件源码
def plotSequenceForRotatingState3D(degPerStep, Sigma, stationaryDim=0, T=1000):
    """Simulate a noisy rotating 3D sequence and plot it as dots.

    Starting from (1, 1, 1), each step draws from a normal distribution whose
    mean is the previous point multiplied by the matrix from
    ``makeA_3DRotationMatrix`` and whose covariance is ``Sigma * eye(3)``.

    Changes from the original: ``xrange`` becomes ``range``, which works on
    Python 2 and 3, and the points are drawn through the created axes rather
    than through pylab's implicit current axes.

    :param degPerStep: passed to ``makeA_3DRotationMatrix``
    :param Sigma: scalar or (3, 3) array multiplied element-wise with eye(3)
    :param stationaryDim: passed to ``makeA_3DRotationMatrix``
    :param T: number of time steps to simulate
    """
    A = makeA_3DRotationMatrix(degPerStep, stationaryDim)
    Sigma = Sigma * np.eye(3)
    assert Sigma.shape == (3, 3)

    X = np.zeros((T, 3))
    X[0, :] = [1, 1, 1]
    for t in range(1, T):
        X[t] = np.random.multivariate_normal(np.dot(A, X[t - 1]), Sigma)

    ax = Axes3D(pylab.figure())
    ax.plot(X[:, 0], X[:, 1], X[:, 2], '.')
    ax.axis('equal')
项目:bnpy    作者:bnpy    | 项目源码 | 文件源码
def showEachSetOfStatesIn3D():
    ''' Make a 3D plot in separate figure for each of the 3 states in a "set"

        These three states just vary the speed of rotation and scale of noise,
        from slow and large to fast and smaller.

        Reads ``degPerSteps`` and ``sigma2s`` from the enclosing scope.
        ``xrange`` is replaced by ``range`` so the function also runs on
        Python 3.
    '''
    from matplotlib import pylab
    from mpl_toolkits.mplot3d import Axes3D
    for ii in range(len(degPerSteps)):
        plotSequenceForRotatingState3D(-1 * degPerSteps[ii], sigma2s[ii], 2)
项目:decoding_challenge_cortana_2016_3rd    作者:kingjr    | 项目源码 | 文件源码
def _plot_sensors(pos, colors, ch_names, title, show_names, show):
    """Helper function for plotting sensors.

    Three-column positions are scattered in a 3D axes. Anything else is
    scattered in a 2D axes after ``_check_outlines`` / ``_draw_outlines``
    have added the head outline. With ``show_names`` the channel names are
    written next to the points; otherwise a pick-event handler is connected.

    ``fig.gca(projection='3d')``, which current matplotlib rejects, is
    replaced by an explicit check that the ``Axes3D`` is attached to the
    figure.

    :param pos: array with one sensor per row and two or three columns
    :param colors: colours passed to ``scatter``
    :param ch_names: channel names, indexed like the rows of ``pos``
    :param title: figure title
    :param show_names: whether to annotate each point with its name
    :param show: forwarded to ``plt_show``
    :return: the created figure
    """
    import matplotlib.pyplot as plt
    from mpl_toolkits.mplot3d import Axes3D
    from .topomap import _check_outlines, _draw_outlines
    fig = plt.figure()
    is_3d = pos.shape[1] == 3

    if is_3d:
        ax = Axes3D(fig)
        # Older matplotlib attaches the axes on construction; newer releases
        # need it added explicitly.
        if ax not in fig.axes:
            fig.add_axes(ax)
        ax.text(0, 0, 0, '', zorder=1)
        ax.scatter(pos[:, 0], pos[:, 1], pos[:, 2], picker=True, c=colors)
        ax.azim = 90
        ax.elev = 0
    else:
        ax = fig.add_subplot(111)
        ax.text(0, 0, '', zorder=1)
        ax.set_xticks([])
        ax.set_yticks([])
        fig.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=None,
                            hspace=None)
        pos, outlines = _check_outlines(pos, 'head')
        _draw_outlines(ax, outlines)
        ax.scatter(pos[:, 0], pos[:, 1], picker=True, c=colors)

    if show_names:
        for this_pos, name in zip(pos, ch_names):
            if is_3d:
                ax.text(this_pos[0], this_pos[1], this_pos[2], name)
            else:
                ax.text(this_pos[0], this_pos[1], name)
    else:
        picker = partial(_onpick_sensor, fig=fig, ax=ax, pos=pos,
                         ch_names=ch_names)
        fig.canvas.mpl_connect('pick_event', picker)
    fig.suptitle(title)
    plt_show(show)
    return fig
项目:pwtools    作者:elcorto    | 项目源码 | 文件源码
def fig_ax3d(**kwds):
    """Create a figure with a single 3D axes.

    The bare ``except:`` is narrowed to ``except Exception`` so that
    ``KeyboardInterrupt`` and ``SystemExit`` are no longer swallowed.

    :param kwds: keyword arguments forwarded to ``plt.figure``
    :return: tuple ``(fig, ax)``
    """
    fig = plt.figure(**kwds)
    try:
        ax = fig.add_subplot(111, projection='3d')
    except Exception:
        # mpl < 1.0.0
        ax = Axes3D(fig)
    return fig, ax
项目:Parallel-SGD    作者:angadgill    | 项目源码 | 文件源码
def plot_figs(fig_num, elev, azim):
    """Plot the point cloud and a plane spanned by its PCA components.

    ``a``, ``b``, ``c`` and ``density`` are read from the enclosing scope.
    Every tenth point is scattered, coloured by ``density``.

    Changes from the original:

    * the ``w_xaxis`` / ``w_yaxis`` / ``w_zaxis`` aliases (removed from
      Matplotlib) are replaced by ``xaxis`` / ``yaxis`` / ``zaxis``;
    * the variance-scaled axes were computed and immediately overwritten by
      ``3 * V.T``, so that dead computation is dropped.

    :param fig_num: matplotlib figure number to draw into (cleared first)
    :param elev: elevation of the 3D view
    :param azim: azimuth of the 3D view
    """
    fig = plt.figure(fig_num, figsize=(4, 3))
    plt.clf()
    ax = Axes3D(fig, rect=[0, 0, .95, 1], elev=elev, azim=azim)

    ax.scatter(a[::10], b[::10], c[::10], c=density[::10], marker='+', alpha=.4)
    Y = np.c_[a, b, c]

    pca = PCA(n_components=3)
    pca.fit(Y)
    V = pca.components_

    x_pca_axis, y_pca_axis, z_pca_axis = 3 * V.T

    def _plane(axis):
        # First two entries followed by their negatives in reverse order,
        # arranged as a 2x2 grid for plot_surface.
        return np.r_[axis[:2], - axis[1::-1]].reshape(2, 2)

    ax.plot_surface(_plane(x_pca_axis), _plane(y_pca_axis), _plane(z_pca_axis))
    # Hide the tick labels on all three axes.
    for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        axis.set_ticklabels([])
项目:classify_dream_of_the_red_chamber    作者:MrQianJinSi    | 项目源码 | 文件源码
def scatters_in_3d(samples, is_labelled = False):
  """Scatter the samples in 3D after reducing them to three PCA components.

  Rows 0-79 are drawn in red and rows 80-119 in blue. With ``is_labelled``
  every sample is annotated with its 1-based row number.

  The ``w_xaxis`` / ``w_yaxis`` / ``w_zaxis`` aliases, which Matplotlib has
  removed, are replaced by ``xaxis`` / ``yaxis`` / ``zaxis``.

  :param samples: feature matrix, one sample per row
  :param is_labelled: whether to write the row number next to each point
  """
  # Reduce the samples to three dimensions with PCA.
  pca = PCA(n_components=3)
  reduced_data = pca.fit_transform(samples)

  fig = plt.figure()
  ax = Axes3D(fig, rect=[0, 0, .95, 1], elev=9, azim=-170)
  for c, (start, stop) in [('r', (0, 80)), ('b', (80, 120))]:
    block = reduced_data[start:stop]
    ax.scatter(block[:, 0], block[:, 1], block[:, 2], c=c)

  # Hide the tick labels on all three axes.
  for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
    axis.set_ticklabels([])

  if is_labelled:
    for ix in range(len(samples)):
      ax.text(reduced_data[ix, 0], reduced_data[ix, 1],reduced_data[ix, 2],
          str(ix+1), verticalalignment='center', fontsize=10)

  plt.show()

# NOTE: the original non-ASCII comment was lost to an encoding error; only "kNN" survived.
项目:livespin    作者:biocompibens    | 项目源码 | 文件源码
def hist2d(self, title, newfig = True):
        """Draw ``self.image`` as a 3D wireframe and return the axes.

        The original passed the falsy ``newfig`` itself to ``Axes3D`` as the
        figure, which cannot work. A falsy ``newfig`` now reuses the current
        pylab figure.

        :param title: title set through ``pylab.title``
        :param newfig: truthy (default) creates a new figure; falsy draws
            into the current pylab figure
        :return: the ``Axes3D`` the wireframe was drawn on
        """
        fig = pylab.figure() if newfig else pylab.gcf()
        ax = Axes3D(fig)
        dx, dy = np.indices(self.shape)
        ax.plot_wireframe(dx, dy, self.image, alpha = 0.2)
        pylab.title(title)
        return ax
项目:l1l2py    作者:slipguru    | 项目源码 | 文件源码
def kcv_errors(errors, range_x, range_y, label_x, label_y):
    r"""Draw an error matrix as a 3D surface.

    The x and y axes carry ``range_x`` and ``range_y``; the z axis is
    labelled ``$error$``. The figure is created but not shown.

    Parameters
    ----------
    errors : (N, D) ndarray
        Error matrix.
    range_x : array_like of N values
        First axis values.
    range_y : array_like of D values
        Second axis values.
    label_x : str
        First axis label.
    label_y : str
        Second axis label.

    Examples
    --------
    >>> errors = numpy.empty((20, 10))
    >>> x = numpy.arange(20)
    >>> y = numpy.arange(10)
    >>> for i in range(20):
    ...     for j in range(10):
    ...         errors[i, j] = (x[i] * y[j])
    ...
    >>> kcv_errors(errors, x, y, 'x', 'y')
    >>> plt.show()

    """
    ax = Axes3D(plt.figure())

    n_x, n_y = len(range_x), len(range_y)
    x_vals, y_vals = np.meshgrid(range_x, range_y)
    # Index grids shaped like the value grids; they pick errors[x, y] for
    # every (x, y) node.
    x_idxs, y_idxs = np.meshgrid(np.arange(n_x), np.arange(n_y))

    for set_label, text in ((ax.set_xlabel, label_x),
                            (ax.set_ylabel, label_y),
                            (ax.set_zlabel, '$error$')):
        set_label(text)

    surface = errors[x_idxs, y_idxs]
    ax.plot_surface(x_vals, y_vals, surface,
                    rstride=1, cstride=1, cmap=cm.jet)
项目:grove    作者:rigetticomputing    | 项目源码 | 文件源码
def state_histogram(rho, ax=None, title="", threshold=0.001):
    """
    Draw a density matrix as a 3d bar chart: bar height is the magnitude of
    each matrix element and bar color encodes its complex phase.

    This code is a modified version of
    `an equivalent function in qutip <http://qutip.org/docs/3.1.0/apidoc/functions.html#qutip.visualization.matrix_histogram_complex>`_
    which is released under the (New) BSD license.

    :param qutip.Qobj rho: The density matrix.
    :param Axes3D ax: The axes object; a new figure and axes are created when omitted.
    :param str title: The axes title.
    :param float threshold: (Optional) minimum magnitude of matrix elements. Bars whose
    magnitude does not exceed it are made fully transparent.
    :return: The axis
    :rtype: mpl_toolkits.mplot3d.Axes3D
    """
    amplitudes = rho.data.toarray().ravel()
    n_qubits = int(round(np.log2(rho.shape[0])))
    dim = 2 ** n_qubits

    if ax is None:
        ax = Axes3D(plt.figure(figsize=(10, 6)), azim=-35, elev=35)

    # Map each element's phase in [-pi, pi] to an RGBA color.
    phase_cmap = rigetti_4_color_cm
    phase_norm = mpl.colors.Normalize(-np.pi, np.pi)
    bar_colors = phase_cmap(phase_norm(np.angle(amplitudes)))
    heights = abs(amplitudes)
    # Alpha channel: opaque above the threshold, invisible otherwise.
    bar_colors[:, 3] = 1.0 * (heights > threshold)

    grid_x, grid_y = np.meshgrid(range(dim), range(dim))
    bar_x = grid_x.ravel()
    bar_y = grid_y.ravel()
    bar_base = np.zeros_like(bar_x)
    bar_dx = bar_dy = np.ones_like(bar_x) * 0.8

    ax.bar3d(bar_x, bar_y, bar_base, bar_dx, bar_dy, heights, color=bar_colors)

    # Bars are 0.8 wide, so +0.4 puts each tick at the middle of its bar.
    tick_positions = np.arange(dim) + .4
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(basis_labels(n_qubits))
    ax.set_yticks(tick_positions)
    ax.set_yticklabels(basis_labels(n_qubits))
    ax.set_zlim3d([0, 1])

    # Separate colorbar axes acting as the legend for the phase coloring.
    colorbar_axes, _ = mpl.colorbar.make_axes(ax, shrink=.75, pad=.1)
    phase_bar = mpl.colorbar.ColorbarBase(colorbar_axes, cmap=phase_cmap,
                                          norm=phase_norm)
    phase_bar.set_ticks([-np.pi, -np.pi / 2, 0, np.pi / 2, np.pi])
    phase_bar.set_ticklabels(
        (r'$-\pi$', r'$-\pi/2$', r'$0$', r'$\pi/2$', r'$\pi$'))
    phase_bar.set_label('arg')

    ax.view_init(azim=-55, elev=45)
    ax.set_title(title)
    return ax
项目:SwarmPackagePy    作者:SISDevelop    | 项目源码 | 文件源码
def animation3D(agents, function, lb, ub, sr=False):
    """Animate agent positions over the 3D surface of ``function``.

    :param agents: sequence of epochs; each epoch is a sequence of agent
        positions whose first two components are used as x and y. Every
        epoch is taken to hold as many agents as the first one.
    :param function: callable evaluated on ``[x, y]`` pairs for the surface
        and on each agent position for its height; its ``__name__`` is shown
        in the plot title.
    :param lb: lower bound of the plotted square domain.
    :param ub: upper bound of the plotted square domain.
    :param sr: when truthy the animation is also saved to 'result.mp4'.
    """
    # Sample the function on a 45 x 45 grid to draw the background surface.
    axis_points = np.linspace(lb, ub, 45)
    X, Y = np.meshgrid(axis_points, axis_points)
    surface_values = np.array(
        [function([px, py]) for px, py in zip(np.ravel(X), np.ravel(Y))])
    Z = surface_values.reshape(X.shape)

    fig = plt.figure()
    ax = Axes3D(fig)
    surface = ax.plot_surface(X, Y, Z, rstride=1, cstride=1, cmap='jet',
                              linewidth=0, antialiased=False)
    ax.set_xlim(lb, ub)
    ax.set_ylim(lb, ub)

    ax.zaxis.set_major_locator(LinearLocator(10))
    ax.zaxis.set_major_formatter(FormatStrFormatter('%.02f'))

    fig.colorbar(surface, shrink=0.5, aspect=5)

    # Flatten all epochs into one table with a "time" column holding the
    # epoch index of every agent position.
    n_epochs = len(agents)
    swarm_size = len(agents[0])
    epoch_index = np.array(
        [np.ones(swarm_size) * k for k in range(n_epochs)]).flatten()
    positions = [agent for epoch in agents for agent in epoch]
    heights = [function(position) for position in positions]
    coords = np.asarray(positions)
    df = pd.DataFrame({"time": epoch_index, "x": coords[:, 0],
                       "y": coords[:, 1], "z": heights})

    title_prefix = function.__name__ + " " * 45
    title = ax.set_title(title_prefix + 'iteration: 0')

    first_frame = df[df['time'] == 0]
    graph = ax.scatter(first_frame.x, first_frame.y, first_frame.z,
                       color='black')

    def update_graph(num):
        # Move the scatter points to the positions recorded for epoch `num`.
        frame = df[df['time'] == num]
        graph._offsets3d = (frame.x, frame.y, frame.z)
        title.set_text(title_prefix + 'iteration: {}'.format(num))

    # Keep a reference to the animation so it stays alive until plt.show().
    ani = matplotlib.animation.FuncAnimation(fig, update_graph, n_epochs,
                                             interval=50, blit=False)

    if sr:
        ani.save('result.mp4')

    plt.show()
项目:PyGeM    作者:mathLab    | 项目源码 | 文件源码
def plot(self, plot_file=None, save_fig=False):
        """
        Method to plot a file. If `plot_file` is not given it plots `self.shape`.

        :param string plot_file: the filename you want to plot.
        :param bool save_fig: a flag to save the figure in png or not. If True the
            plot is not shown; the image is written next to `plot_file`, with
            its extension replaced by '.png'.

        :return: figure: matplotlib structure for the figure of the chosen geometry
        :rtype: matplotlib.pyplot.figure
        """
        if plot_file is None:
            shape = self.shape
            plot_file = self.infile
        else:
            shape = self.load_shape_from_file(plot_file)

        stl_writer = StlAPI_Writer()
        # Do not switch SetASCIIMode() from False to True.
        stl_writer.SetASCIIMode(False)

        # The shape is round-tripped through a scratch STL file to obtain its
        # triangles. Remove the file even when writing or loading fails.
        aux_file = 'aux_figure.stl'
        try:
            stl_writer.Write(shape, aux_file)
            stl_mesh = mesh.Mesh.from_file(aux_file)
        finally:
            if os.path.exists(aux_file):
                os.remove(aux_file)

        # Create a new plot and add the triangles to it
        figure = pyplot.figure()
        axes = mplot3d.Axes3D(figure)
        axes.add_collection3d(
            mplot3d.art3d.Poly3DCollection(stl_mesh.vectors / 1000))

        # Get the limits of the axis and center the geometry
        max_dim = np.array(
            [np.max(stl_mesh.vectors[:, :, k]) / 1000 for k in range(3)])
        min_dim = np.array(
            [np.min(stl_mesh.vectors[:, :, k]) / 1000 for k in range(3)])

        # Same span on all three axes (1.2 times the largest extent), centered
        # on the middle of the bounding box.
        half_span = .6 * np.max(max_dim - min_dim)
        center = (max_dim + min_dim) / 2
        axes.set_xlim(-half_span + center[0], half_span + center[0])
        axes.set_ylim(-half_span + center[1], half_span + center[1])
        axes.set_zlim(-half_span + center[2], half_span + center[2])

        # Show the plot to the screen
        if not save_fig:
            pyplot.show()
        else:
            # splitext only strips the final extension; splitting on the first
            # '.' mangled paths such as './geo.stl' (which became '.png').
            figure.savefig(os.path.splitext(plot_file)[0] + '.png')

        return figure
项目:deep-clustering    作者:zhr1201    | 项目源码 | 文件源码
def visualize(N_frame):
    """Save PCA scatter plots of the embeddings inferred by a restored model.

    Rebuilds the inference graph, restores the checkpoint
    'train/model.ckpt-68000' and, for each batch delivered by the sample
    reader, projects the embeddings of the VAD-active entries onto 3
    principal components and saves the scatter plot to 'vis/<step>pca.jpg'.

    :param int N_frame: maximum number of batches to process; the loop also
        stops as soon as the sample reader returns None.
    """
    with tf.Graph().as_default():
        # init the sample reader
        data_generator = AudioSampleReader(data_dir)
        # build the graph as the training script
        in_data = tf.placeholder(
            tf.float32, shape=[batch_size, FRAMES_PER_SAMPLE, NEFF])
        VAD_data = tf.placeholder(
            tf.float32, shape=[batch_size, FRAMES_PER_SAMPLE, NEFF])
        Y_data = tf.placeholder(
            tf.float32, shape=[batch_size, FRAMES_PER_SAMPLE, NEFF, 2])
        # init
        BiModel = Model(n_hidden, batch_size, False)
        # infer embedding
        embedding = BiModel.inference(in_data)
        saver = tf.train.Saver(tf.all_variables())
        # The context manager closes the session when the loop finishes or
        # raises; the session was previously never closed.
        with tf.Session() as sess:
            # restore a model
            saver.restore(sess, 'train/model.ckpt-68000')

            for step in range(N_frame):
                data_batch = data_generator.gen_next()
                if data_batch is None:
                    break
                # concatenate the elements in sample dict to generate batch
                # data
                in_data_np = np.concatenate(
                    [np.reshape(item['Sample'], [1, FRAMES_PER_SAMPLE, NEFF])
                     for item in data_batch])
                VAD_data_np = np.concatenate(
                    [np.reshape(item['VAD'], [1, FRAMES_PER_SAMPLE, NEFF])
                     for item in data_batch])
                embedding_np, = sess.run(
                    [embedding],
                    feed_dict={in_data: in_data_np,
                               VAD_data: VAD_data_np
                               })
                # only plot those embeddings whose VADs are active
                # NOTE(review): the VAD mask is read from batch item 0 only
                # while embedding_np is indexed as [i, j, :] -- confirm the
                # embedding layout against Model.inference.
                embedding_ac = [embedding_np[i, j, :]
                                for i, j in itertools.product(
                                    range(FRAMES_PER_SAMPLE), range(NEFF))
                                if VAD_data_np[0, i, j] == 1]

                # Cluster labels are only consumed by the optional colored
                # scatter that is commented out below.
                kmean = KMeans(n_clusters=2, random_state=0).fit(embedding_ac)
                # visualization using 3 PCA
                pca_Data = PCA(n_components=3).fit_transform(embedding_ac)
                fig = plt.figure(1, figsize=(8, 6))
                ax = Axes3D(fig, elev=-150, azim=110)
                # ax.scatter(pca_Data[:, 0], pca_Data[:, 1], pca_Data[:, 2],
                #            c=kmean.labels_, cmap=plt.cm.Paired)
                ax.scatter(pca_Data[:, 0], pca_Data[:, 1], pca_Data[:, 2],
                           cmap=plt.cm.Paired)
                ax.set_title('Embedding visualization using the first 3 PCs')
                ax.set_xlabel('1st pc')
                ax.set_ylabel('2nd pc')
                ax.set_zlabel('3rd pc')
                plt.savefig('vis/' + str(step) + 'pca.jpg')
项目:pwtools    作者:elcorto    | 项目源码 | 文件源码
def plotlines3d(ax3d, x, y, z, *args, **kwargs):
    """Plot x-z curves stacked along y.

    Parameters
    ----------
    ax3d : Axes3D instance
    x : nd array
        1d (x-axis) or 2d (x-axes are the columns)
    y : 1d array
    z : nd array with "y"-values
        1d : the same curve will be plotted len(y) times against x (1d) or
             x[:,i] (2d)
        2d : each column z[:,i] will be plotted against x (1d) or each x[:,i]
             (2d)
    *args, **kwargs : additional args and keywords args passed to ax3d.plot()

    Returns
    -------
    ax3d

    Raises
    ------
    AssertionError
        If `y` is not 1d, if `x` and `z` do not end up with the same 2d
        shape, or if their number of columns differs from len(y).

    Examples
    --------
    >>> x = linspace(0,5,100)
    >>> y = arange(1.0,5) # len(y) = 4
    >>> z = np.repeat(sin(x)[:,None], 4, axis=1)/y # make 2d
    >>> fig,ax = fig_ax3d()
    >>> plotlines3d(ax, x, y, z)
    >>> show()
    """
    assert y.ndim == 1
    # Expand 1d inputs to one column per y value so that the loop below can
    # always work column by column.
    zz = np.repeat(z[:, None], len(y), axis=1) if z.ndim == 1 else z
    xx = np.repeat(x[:, None], zz.shape[1], axis=1) if x.ndim == 1 else x
    assert xx.shape == zz.shape
    assert len(y) == xx.shape[1] == zz.shape[1]
    for j in range(xx.shape[1]):
        # Index the expanded ``zz``: indexing the raw ``z`` with [:, j] raised
        # IndexError whenever a 1d ``z`` was passed.
        ax3d.plot(xx[:, j], np.ones(xx.shape[0]) * y[j], zz[:, j],
                  *args, **kwargs)
    return ax3d
项目:Parallel-SGD    作者:angadgill    | 项目源码 | 文件源码
def main():
    """Fit a GBRT on the California housing data and plot partial dependence.

    Shows two figures: the grid produced by ``plot_partial_dependence`` and a
    custom 3D surface built from the values returned by
    ``partial_dependence`` for a pair of features.
    """
    housing = fetch_california_housing()
    feature_names = housing.feature_names

    # split 80/20 train-test (only the training part is used below)
    X_train, _, y_train, _ = train_test_split(housing.data,
                                              housing.target,
                                              test_size=0.2,
                                              random_state=1)

    print("Training GBRT...", flush=True, end='')
    model = GradientBoostingRegressor(n_estimators=100, max_depth=4,
                                      learning_rate=0.1, loss='huber',
                                      random_state=1)
    model.fit(X_train, y_train)
    print(" done.")

    print('Convenience plot with ``partial_dependence_plots``')

    # Four single-feature plots plus one two-feature interaction plot.
    plotted_features = [0, 5, 1, 2, (5, 1)]
    grid_fig, _ = plot_partial_dependence(model, X_train, plotted_features,
                                          feature_names=feature_names,
                                          n_jobs=3, grid_resolution=50)
    grid_fig.suptitle(
        'Partial dependence of house value on nonlocation features\n'
        'for the California housing dataset')
    plt.subplots_adjust(top=0.9)  # tight_layout causes overlap with suptitle

    print('Custom 3d plot via ``partial_dependence``')
    surface_fig = plt.figure()

    pair = (1, 5)
    pdp, grid_axes = partial_dependence(model, pair,
                                        X=X_train, grid_resolution=50)
    XX, YY = np.meshgrid(grid_axes[0], grid_axes[1])
    # Reshape the flat dependence values onto the value grid and transpose to
    # line up with the meshgrid layout.
    Z = pdp[0].reshape([np.size(values) for values in grid_axes]).T

    ax = Axes3D(surface_fig)
    surface = ax.plot_surface(XX, YY, Z, rstride=1, cstride=1,
                              cmap=plt.cm.BuPu)
    ax.set_xlabel(feature_names[pair[0]])
    ax.set_ylabel(feature_names[pair[1]])
    ax.set_zlabel('Partial dependence')
    #  pretty init view
    ax.view_init(elev=22, azim=122)
    plt.colorbar(surface)
    plt.suptitle('Partial dependence of house value on median age and '
                 'average occupancy')
    plt.subplots_adjust(top=0.9)

    plt.show()


# Needed on Windows because plot_partial_dependence uses multiprocessing