Python matplotlib.pyplot 模块，Axes() 实例源码

我们从Python开源项目中，提取了以下50个代码示例，用于说明如何使用matplotlib.pyplot.Axes()。

项目:structured-output-ae    作者:sbelharbi    | 项目源码 | 文件源码
def debug_plot_over_img(self, img, x, y, bb_d, bb_gt):
        """Plot the landmarks over the image with both bounding boxes.

        Landmarks ``(x, y)`` are drawn as red dots over a grey-scale rendering
        of *img*. ``bb_d`` is outlined in red and ``bb_gt`` in blue; each box is
        indexed as ``(x_min, y_min, x_max, y_max)`` corners. Any previously open
        figures are closed first. Returns the new figure.
        """
        def box_patch(box, edge):
            # Rectangle expects an origin plus width/height, not two corners.
            x0, y0, x1, y1 = box[0], box[1], box[2], box[3]
            return patches.Rectangle((x0, y0), x1 - x0, y1 - y0,
                                     linewidth=1, edgecolor=edge, facecolor='none')

        plt.close("all")
        figure = plt.figure()
        # Axes covering the whole figure, without ticks or frame.
        axes = plt.Axes(figure, [0., 0., 1., 1.])
        axes.set_axis_off()
        axes.imshow(img, aspect="auto", cmap='Greys_r')
        axes.scatter(x, y, s=10, color='r')
        for box, edge in ((bb_d, 'r'), (bb_gt, 'b')):
            axes.add_patch(box_patch(box, edge))
        figure.add_axes(axes)

        return figure
项目:structured-output-ae    作者:sbelharbi    | 项目源码 | 文件源码
def plot_over_img(self, img, x, y, x_pr, y_pr, bb_gt):
        """Plot two landmark sets and a bounding box over a BGR image.

        *img* is converted from BGR to RGB with OpenCV before display.
        ``(x, y)`` are drawn in red and ``(x_pr, y_pr)`` (presumably the
        predictions) in green. ``bb_gt`` — indexed as
        ``(x_min, y_min, x_max, y_max)`` — is outlined in blue. Previously open
        figures are closed first. Returns the new, frameless figure.
        """
        plt.close("all")
        figure = plt.figure(frameon=False)
        axes = plt.Axes(figure, [0., 0., 1., 1.])
        axes.set_axis_off()
        rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        axes.imshow(rgb, aspect="auto")
        for xs, ys, colour in ((x, y, 'r'), (x_pr, y_pr, 'g')):
            axes.scatter(xs, ys, s=10, color=colour)
        x0, y0 = bb_gt[0], bb_gt[1]
        outline = patches.Rectangle((x0, y0), bb_gt[2] - x0, bb_gt[3] - y0,
                                    linewidth=1, edgecolor='b', facecolor='none')
        axes.add_patch(outline)
        figure.add_axes(axes)

        return figure
项目:SFBIStats    作者:royludo    | 项目源码 | 文件源码
def create_wordcloud(corpus, output, stopword_dict):
    """Render the 100 most frequent words of *corpus* as a word cloud image.

    Frequencies come from ``build_lex_dic`` / ``build_freq_list``. The cloud
    is drawn on a 10x8 in frameless figure with a single full-size axes, then
    saved to *output* with a white background. The figure is closed
    afterwards so repeated calls do not accumulate open figures.
    """
    lex_dic = build_lex_dic(corpus, stopword_dict=stopword_dict)
    total_words = get_total_words(lex_dic)
    ordered_freq_list = build_freq_list(lex_dic, total_words)

    fig = plt.figure(figsize=(10, 8), frameon=False)
    ax = plt.Axes(fig, [0., 0., 1., 1.])
    ax.set_axis_off()
    fig.add_axes(ax)
    wordcloud = WordCloud(width=1000, height=800, max_words=100, background_color='white',
                          relative_scaling=0.7, random_state=15, prefer_horizontal=0.5).generate_from_frequencies(
        ordered_freq_list[0:100])
    wordcloud.recolor(random_state=42, color_func=my_color_func)

    ax.imshow(wordcloud)
    fig.savefig(output, facecolor='white')
    # Release the figure: pyplot keeps every figure alive until it is closed.
    plt.close(fig)
项目:physt    作者:janpipek    | 项目源码 | 文件源码
def _add_labels(ax, h, kwargs):
    """Set title and axis labels of *ax* from histogram metadata.

    ``title``, ``xlabel`` and ``ylabel`` are popped from *kwargs*, so they are
    not forwarded further. When absent they default to:
    - ``h.title``;
    - ``h.axis_names[0]``;
    - ``h.axis_names[1]``, but only when there are exactly two axis names.

    Empty or None values are skipped. Finally the owning figure gets
    ``tight_layout()``.

    Parameters
    ----------
    ax : plt.Axes
    h : Histogram1D or Histogram2D
    kwargs: dict
        Mutated in place.
    """
    names = h.axis_names
    default_ylabel = names[1] if len(names) == 2 else None
    assignments = (
        (ax.set_title, kwargs.pop("title", h.title)),
        (ax.set_xlabel, kwargs.pop("xlabel", names[0])),
        (ax.set_ylabel, kwargs.pop("ylabel", default_ylabel)),
    )
    for setter, text in assignments:
        if text:
            setter(text)
    ax.get_figure().tight_layout()
项目:physt    作者:janpipek    | 项目源码 | 文件源码
def _add_values(ax, h1, data, value_format=lambda x: x):
    """Show values next to each bin in a 1D plot.

    Each value is written centred above its bin centre (``va='bottom'``),
    clipped to the axes.

    Parameters
    ----------
    ax : plt.Axes
    h1 : physt.histogram1d.Histogram1D
    data : array_like
        The values to be displayed
    value_format : callable, str or None
        Either a callable turning a value into its label, or a format spec
        such as ``".2f"``. None means the empty spec (plain ``format``).

    # TODO: Add some formatting
    """
    if value_format is None:
        value_format = ""
    formatter = value_format
    if isinstance(value_format, str):
        formatter = ("{0:" + value_format + "}").format
    for center, value in zip(h1.bin_centers, data):
        ax.text(center, value, str(formatter(value)), ha='center', va='bottom', clip_on=True)
项目:physt    作者:janpipek    | 项目源码 | 文件源码
def _add_stats_box(h1, ax):
    """Write total, mean and standard deviation of *h1* into *ax*.

    The text is anchored by its top-left corner at (0.05, 0.95) in axes
    coordinates, i.e. near the upper-left corner of the plot. Mean and
    standard deviation are shown with two decimals.

    Parameters
    ----------
    h1 : physt.histogram1d.Histogram1D
        Histogram with valid statistics information
    ax : plt.Axes
        Axes to draw it into

    Note
    ----
    Very basic implementation.
    """
    summary = "\n".join([
        "Total: {0}".format(h1.total),
        "Mean: {0:.2f}".format(h1.mean()),
        "Std.dev: {0:.2f}".format(h1.std()),
    ])
    ax.text(0.05, 0.95, summary, transform=ax.transAxes,
            verticalalignment='top', horizontalalignment='left')
项目:physt    作者:janpipek    | 项目源码 | 文件源码
def _add_ticks(ax, h1, kwargs):
    """Customize x ticks of a 1D histogram plot.

    The ``ticks`` option is popped from *kwargs*:
    - ``"center"`` places ticks at ``h1.bin_centers``;
    - ``"edge"`` places them at ``h1.bin_left_edges``;
    - a missing, falsy or unrecognised value leaves the ticks unchanged.

    Parameters
    ----------
    ax : plt.Axes
    h1 : physt.histogram1d.Histogram1D
    kwargs : dict
        Mutated in place (``ticks`` is removed).
    """
    ticks = kwargs.pop("ticks", None)
    if not ticks:
        return
    if ticks in ("center", "edge"):
        source = {"center": "bin_centers", "edge": "bin_left_edges"}[ticks]
        ax.set_xticks(getattr(h1, source))
项目:elfi    作者:elfi-dev    | 项目源码 | 文件源码
def plot_marginals(self, selector=None, bins=20, axes=None, **kwargs):
        """Plot marginal distributions for parameters.

        Thin wrapper around ``vis.plot_marginals`` applied to ``self.samples``.

        Parameters
        ----------
        selector : iterable of ints or strings, optional
            Indices or keys to use from samples. Default to all.
        bins : int, optional
            Number of bins in histograms.
        axes : one or an iterable of plt.Axes, optional
        **kwargs
            Forwarded unchanged to ``vis.plot_marginals``.

        Returns
        -------
        axes : np.array of plt.Axes
        """
        samples = self.samples
        return vis.plot_marginals(samples, selector=selector, bins=bins, axes=axes, **kwargs)
项目:elfi    作者:elfi-dev    | 项目源码 | 文件源码
def plot_marginals(self, selector=None, bins=20, axes=None, all=False, **kwargs):
        """Plot marginal distributions for parameters, optionally per population.

        Parameters
        ----------
        selector : iterable of ints or strings, optional
            Indices or keys to use from samples. Default to all.
        bins : int, optional
            Number of bins in histograms.
        axes : one or an iterable of plt.Axes, optional
        all : bool, optional
            Plot the marginals of all populations, each titled
            "Population <i>".
        fontsize : int, optional (keyword)
            Font size of the per-population suptitle (default 13).

        Remaining keyword arguments are forwarded to the underlying
        ``plot_marginals`` calls.

        Returns
        -------
        The base-class result when *all* is false, otherwise None.
        """
        # Popped up front so it is never forwarded as a plotting kwarg.
        fontsize = kwargs.pop('fontsize', 13)
        if not all:
            # Previously called without arguments, silently ignoring
            # selector/bins/axes/kwargs.
            return super(SmcSample, self).plot_marginals(selector=selector, bins=bins,
                                                         axes=axes, **kwargs)

        for i, pop in enumerate(self.populations):
            pop.plot_marginals(selector=selector, bins=bins, axes=axes, **kwargs)
            plt.suptitle("Population {}".format(i), fontsize=fontsize)
项目:elfi    作者:elfi-dev    | 项目源码 | 文件源码
def plot_pairs(self, selector=None, bins=20, axes=None, all=False, **kwargs):
        """Plot pairwise relationships as a matrix with marginals on the diagonal.

        The y-axis of marginal histograms are scaled.

        Parameters
        ----------
        selector : iterable of ints or strings, optional
            Indices or keys to use from samples. Default to all.
        bins : int, optional
            Number of bins in histograms.
        axes : one or an iterable of plt.Axes, optional
        all : bool, optional
            Plot for all populations, each titled "Population <i>".
        fontsize : int, optional (keyword)
            Font size of the per-population suptitle (default 13).

        Remaining keyword arguments are forwarded to the underlying
        ``plot_pairs`` calls.

        Returns
        -------
        The base-class result when *all* is false, otherwise None.
        """
        # Popped up front so it is never forwarded as a plotting kwarg.
        fontsize = kwargs.pop('fontsize', 13)
        if not all:
            # Previously this delegated to plot_marginals (the wrong plot) and
            # dropped the caller's selector/bins/axes/kwargs.
            return super(SmcSample, self).plot_pairs(selector=selector, bins=bins,
                                                     axes=axes, **kwargs)

        for i, pop in enumerate(self.populations):
            pop.plot_pairs(selector=selector, bins=bins, axes=axes, **kwargs)
            plt.suptitle("Population {}".format(i), fontsize=fontsize)
项目:workspace    作者:nojima    | 项目源码 | 文件源码
def visualize_frequent_words(vectors_2d: np.ndarray, dataset: DataSet, k: int, ax: plt.Axes = None) -> None:
    """Scatter the 2-D vectors of the ``k`` most frequent words and label them.

    Word frequencies are counted over ``dataset.data``. ``vectors_2d`` is
    indexed by word id, and columns 0/1 are used as the x/y coordinates.
    When *ax* is None, a new 13x13 in figure is created, laid out and shown.
    If ``k`` is at least the number of distinct words, every word is plotted.
    """
    word_ids, counts = np.unique(dataset.data, return_counts=True)

    if k < len(counts):
        # Indices of the k largest counts (unordered among themselves).
        indices = np.argpartition(-counts, k)[:k]
    else:
        # argpartition rejects kth >= len(counts); every word qualifies anyway.
        indices = np.arange(len(counts))
    frequent_word_ids = word_ids[indices]

    if ax is None:
        fig, ax = plt.subplots(figsize=(13, 13))
    else:
        fig = None

    points = vectors_2d[frequent_word_ids]

    ax.scatter(points[:, 0], points[:, 1], s=2, alpha=0.25)
    for i, word_id in enumerate(frequent_word_ids):
        ax.annotate(dataset.vocabulary.to_word(word_id), (points[i, 0], points[i, 1]))

    if fig is not None:
        fig.tight_layout()
        fig.show()
项目:kite    作者:pyrocko    | 项目源码 | 文件源码
def plot(self, **kwargs):
        """Plot the quadtree leaf means.

        :param axes: Axes instance to plot in, defaults to None
        :type axes: :py:class:`matplotlib.Axes`, optional
        :param figure: Figure instance to plot in, defaults to None
        :type figure: :py:class:`matplotlib.Figure`, optional
        :param **kwargs: forwarded unchanged to ``self._initImagePlot``
            (presumably it reads ``axes``/``figure``/``cmap`` — confirm)
        :type **kwargs: dict
        """
        self._initImagePlot(**kwargs)
        # NOTE(review): data and title are assigned after _initImagePlot. Unless
        # they are properties that refresh the image/axes, the title drawn during
        # initialisation is the previous one — confirm.
        self.data = self._quadtree.leaf_matrix_means
        self.title = 'Quadtree Means'

        self._addInfoText()

        # Block in plt.show() only when _show_plt is set (presumably by
        # setCanvas when no figure/axes were supplied — confirm).
        if self._show_plt:
            plt.show()
项目:chi    作者:rmst    | 项目源码 | 文件源码
def __init__(self, axes: plt.Axes, timesteps=20, limits=None, auto_limit=6, title="", legend=()):
        """Attach a time-series plot to *axes*.

        The axes get *title* and an x range of ``[-timesteps, 0]``. The data
        buffers start empty, with ``mean = 0`` and ``var = 1``. ``reset()`` is
        called last.
        """
        # Options supplied by the caller.
        self.timesteps = timesteps
        self.limits = limits
        self.auto_limit = auto_limit
        self.legend = legend

        # Data buffers and running statistics.
        self.x, self.y, self.lines = [], [], []
        self.mean, self.var = 0, 1

        self.ax = axes
        self.ax.set_title(title)
        self.ax.set_xlim(-self.timesteps, 0)
        self.reset()
项目:osm_wpt    作者:krisanselmo    | 项目源码 | 文件源码
def plot_gpx_route(lon, lat, title):
    """Draw a GPX track as a red ``+-`` line on a dark, axis-free figure.

    The axes fill the whole figure (face colour '0.05'), use aspect ratio 1.2
    and carry *title* in white. Returns the ``matplotlib.pyplot`` module,
    presumably so callers can keep drawing or save through it.
    """
    fig = plt.figure(facecolor='0.05')
    ax = plt.Axes(fig, [0., 0., 1., 1.])
    ax.set_aspect(1.2)
    ax.set_axis_off()
    ax.set_title(title, color='white', fontsize=15)
    fig.add_axes(ax)
    # Draw on the axes explicitly rather than via pyplot's current axes.
    ax.plot(lon, lat, '+-', color='red', lw=1, alpha=1)
    # plt.hold() was removed in Matplotlib 3.0 (it raised AttributeError here).
    # Hold-on is now the permanent default, so later plot calls still add to
    # this axes.
    return plt
项目:adagan    作者:tolstikhin    | 项目源码 | 文件源码
def save_pic(pic, path, exp):
    """Save one image array to *path* as a borderless PNG.

    - 4-D input is treated as a batch, and only its first element is saved.
    - The figure is (width, height) inches and written at dpi=1, giving a
      width x height pixel PNG.
    - If ``exp.symmetrize`` is set, values are rescaled with ``(pic + 1) / 2``
      (i.e. [-1, 1] -> [0, 1]).
    - For ``exp.dataset == 'mnist'``, the first channel is taken, inverted
      (``1 - pic``) and shown with the 'Greys' colormap.
    """
    if len(pic.shape) == 4:
        pic = pic[0]
    height, width = pic.shape[0], pic.shape[1]
    fig = plt.figure(frameon=False, figsize=(width, height))
    ax = plt.Axes(fig, [0., 0., 1., 1.])
    ax.set_axis_off()
    fig.add_axes(ax)
    if exp.symmetrize:
        pic = (pic + 1.) / 2.
    if exp.dataset == 'mnist':
        pic = 1. - pic[:, :, 0]
        ax.imshow(pic, cmap='Greys', interpolation='none')
    else:
        ax.imshow(pic, interpolation='none')
    fig.savefig(path, dpi=1, format='png')
    # Close this specific figure rather than whatever pyplot considers current.
    plt.close(fig)
项目:physt    作者:janpipek    | 项目源码 | 文件源码
def line(h1, ax, errors=False, **kwargs):
    """Line plot of 1D histogram.

    Parameters
    ----------
    h1 : Histogram1D
    ax : plt.Axes
        Axes to draw into.
    errors : bool
        Whether to draw error bars (``fmt`` default "-", ``ecolor`` default
        "black").

    Options popped from kwargs: ``show_stats``, ``show_values``,
    ``density``, ``cumulative``, ``value_format``. Limits, ticks and labels
    are handled by the ``_apply_xy_lims``/``_add_ticks``/``_add_labels``
    helpers. Whatever remains goes to ``ax.plot`` / ``ax.errorbar``.

    Returns
    -------
    plt.Axes
    """
    show_stats, show_values, density, cumulative, value_format = (
        kwargs.pop(key, default) for key, default in (
            ("show_stats", False),
            ("show_values", False),
            ("density", False),
            ("cumulative", False),
            ("value_format", None),
        )
    )

    data = get_data(h1, cumulative=cumulative, density=density)
    _apply_xy_lims(ax, h1, data, kwargs)
    _add_ticks(ax, h1, kwargs)
    _add_labels(ax, h1, kwargs)

    if not errors:
        ax.plot(h1.bin_centers, data, **kwargs)
    else:
        err_data = get_err_data(h1, cumulative=cumulative, density=density)
        fmt = kwargs.pop("fmt", "-")
        ecolor = kwargs.pop("ecolor", "black")
        ax.errorbar(h1.bin_centers, data, yerr=err_data, fmt=fmt, ecolor=ecolor, **kwargs)

    if show_stats:
        _add_stats_box(h1, ax)
    if show_values:
        _add_values(ax, h1, data, value_format=value_format)
    return ax
项目:physt    作者:janpipek    | 项目源码 | 文件源码
def fill(h1, ax, **kwargs):
    """Fill plot of 1D histogram.

    Parameters
    ----------
    h1 : Histogram1D
    ax : plt.Axes
        Axes to draw into.

    Options popped from kwargs:
    - ``show_stats``: draw the statistics box;
    - ``show_values``: write each bin's value above its centre;
    - ``value_format``: format for ``show_values`` (str spec or callable);
    - ``density``, ``cumulative``: passed to ``get_data``.

    Limits, ticks and labels (``title``/``xlabel``/``ylabel``) are handled as
    in ``line``. Whatever remains goes to ``ax.fill_between``.

    Returns
    -------
    plt.Axes
    """
    show_stats = kwargs.pop("show_stats", False)
    show_values = kwargs.pop("show_values", False)
    density = kwargs.pop("density", False)
    cumulative = kwargs.pop("cumulative", False)
    value_format = kwargs.pop("value_format", None)

    data = get_data(h1, cumulative=cumulative, density=density)
    _apply_xy_lims(ax, h1, data, kwargs)
    _add_ticks(ax, h1, kwargs)
    # Consume title/xlabel/ylabel like `line` does. Otherwise they would be
    # forwarded to fill_between as unknown artist properties.
    _add_labels(ax, h1, kwargs)

    ax.fill_between(h1.bin_centers, 0, data, **kwargs)

    if show_stats:
        _add_stats_box(h1, ax)
    if show_values:
        _add_values(ax, h1, data, value_format=value_format)
    return ax
项目:physt    作者:janpipek    | 项目源码 | 文件源码
def bar3d(h2, ax, **kwargs):
    """Plot of 2D histograms as 3D boxes.

    Each bin becomes a box standing on z = 0 at its bin centre, with the bin
    widths as footprint and the bin value as height. Colours come from the
    colormap options when ``cmap`` is given, otherwise from ``color``
    (default "blue"). The z axis is labelled "density" or "frequency".

    Parameters
    ----------
    h2 : Histogram2D
    ax : 3D axes

    Returns
    -------
    plt.Axes
    """
    density = kwargs.pop("density", False)
    data = get_data(h2, cumulative=False, flatten=True, density=density)

    if "cmap" not in kwargs:
        colors = kwargs.pop("color", "blue")
    else:
        cmap = _get_cmap(kwargs)
        _, cmap_data = _get_cmap_data(data, kwargs)
        colors = cmap(cmap_data)

    centers_x, centers_y = h2.get_bin_centers()
    xpos, ypos = centers_x.flatten(), centers_y.flatten()
    zpos = np.zeros_like(ypos)
    widths_x, widths_y = h2.get_bin_widths()
    dx, dy = widths_x.flatten(), widths_y.flatten()

    _add_labels(ax, h2, kwargs)
    ax.bar3d(xpos, ypos, zpos, dx, dy, data, color=colors, **kwargs)
    ax.set_zlabel("density" if density else "frequency")

    return ax
项目:physt    作者:janpipek    | 项目源码 | 文件源码
def polar_map(hist, ax, show_zero=True, **kwargs):
    """Polar map of polar histograms.

    Similar to map, but supports less parameters. Each (r, phi) bin is drawn
    as one polar bar spanning its radial and angular bin edges, coloured
    through the colormap options.

    Parameters
    ----------
    hist : histogram with (r, phi) binning
    ax : polar axes
    show_zero : bool
        Also draw bins whose value is not positive.

    Recognised kwargs:
    - ``density``;
    - ``zorder``;
    - ``grid_color``: bar edge colour, default ``cmap(0.5)``;
    - ``lw``: edge width, default 0.5;
    - the colormap/alpha options consumed by ``_get_cmap``,
      ``_get_cmap_data`` and ``_get_alpha_data``.

    Returns
    -------
    plt.Axes
    """
    data = get_data(hist, cumulative=False, flatten=True,
                    density=kwargs.pop("density", False))

    cmap = _get_cmap(kwargs)
    _, cmap_data = _get_cmap_data(data, kwargs)
    colors = cmap(cmap_data)

    rpos, phipos = (arr.flatten() for arr in hist.get_bin_left_edges())
    dr, dphi = (arr.flatten() for arr in hist.get_bin_widths())
    rmax, _ = (arr.flatten() for arr in hist.get_bin_right_edges())

    bar_args = {}
    if "zorder" in kwargs:
        bar_args["zorder"] = kwargs.pop("zorder")

    alphas = _get_alpha_data(cmap_data, kwargs)
    if np.isscalar(alphas):
        alphas = np.ones_like(data) * alphas

    # Loop invariants, evaluated once instead of per bar.
    edge_color = kwargs.get("grid_color", cmap(0.5))
    line_width = kwargs.get("lw", 0.5)

    for i in range(len(rpos)):
        if data[i] > 0 or show_zero:
            # phipos/rpos are *left* edges, so anchor each bar at its edge. The
            # default align="center" shifted every bar by half a bin in phi.
            ax.bar(phipos[i], dr[i], width=dphi[i], bottom=rpos[i], align="edge",
                   color=colors[i], edgecolor=edge_color, lw=line_width,
                   alpha=alphas[i], **bar_args)

    ax.set_rmax(rmax.max())
    return ax
项目:physt    作者:janpipek    | 项目源码 | 文件源码
def pair_bars(first, second, orientation="vertical", kind="bar", **kwargs):
    """Draw two different histograms mirrored in one figure.

    *first* is drawn negated and *second* as-is on the same axes. The y tick
    labels show absolute values so both halves read as positive.

    Parameters
    ----------
    first: Histogram1D
    second: Histogram1D
    orientation, kind:
        Currently unused (see TODO).
    color1: colour of *first* (default "red")
    color2: colour of *second* (default "blue")
    title: defaults to "<first.name> - <second.name>"
    xlim: defaults to the union of both histograms' bin ranges

    Returns
    -------
    plt.Axes
    """
    # TODO: enable vertical as well as horizontal
    _, ax = _get_axes(kwargs)
    color1 = kwargs.pop("color1", "red")
    color2 = kwargs.pop("color2", "blue")
    title = kwargs.pop("title", "{0} - {1}".format(first.name, second.name))
    # Union of both binning ranges. The left bound previously compared
    # first.bin_left_edges[0] with itself and ignored where *second* starts.
    xlim = kwargs.pop("xlim", (min(first.bin_left_edges[0], second.bin_left_edges[0]),
                               max(first.bin_right_edges[-1], second.bin_right_edges[-1])))

    bar(first * (-1), color=color1, ax=ax, ylim="keep", **kwargs)
    bar(second, color=color2, ax=ax, ylim="keep", **kwargs)
    ax.set_title(title)
    ticks = np.abs(ax.get_yticks())
    if np.allclose(np.rint(ticks), ticks):
        ax.set_yticklabels(ticks.astype(int))
    else:
        ax.set_yticklabels(ticks)
    ax.set_xlim(xlim)
    ax.legend()
    return ax
项目:physt    作者:janpipek    | 项目源码 | 文件源码
def _get_axes(kwargs, use_3d=False, use_polar=False):
    """Prepare the axis to draw into.

    ``ax`` and ``figsize`` are popped from *kwargs*. An existing ``ax`` wins
    over the projection flags, and ``use_3d`` wins over ``use_polar``.

    Parameters
    ----------
    use_3d: bool
        If yes, an axis with 3D projection is created.
    use_polar: bool
        If yes, the plot will have polar coordinates.

    Kwargs
    ------
    ax: Optional[plt.Axes]
        An already existing axis to be used.
    figsize: Optional[tuple]
        Size of the new figure (if no axis is given).

    Returns
    ------
    fig : plt.Figure
    ax : plt.Axes | Axes3D
    """
    figsize = kwargs.pop("figsize", default_figsize)
    if "ax" in kwargs:
        existing = kwargs.pop("ax")
        return existing.get_figure(), existing

    if use_3d:
        projection = '3d'
    elif use_polar:
        projection = 'polar'
    else:
        projection = None

    if projection is None:
        fig, ax = plt.subplots(figsize=figsize)
        return fig, ax
    fig = plt.figure(figsize=figsize)
    return fig, fig.add_subplot(111, projection=projection)
项目:elfi    作者:elfi-dev    | 项目源码 | 文件源码
def _create_axes(axes, shape, **kwargs):
    """Check the axes and create them if necessary.

    Parameters
    ----------
    axes : plt.Axes or arraylike of plt.Axes
        Used as-is when given (figure kwargs are then discarded).
    shape : tuple of int
        (x,) or (x,y): rows, and columns (default 1).
    kwargs
        ``figsize`` (default (16, 4 * rows)), ``sharex``, ``sharey``, ``dpi``
        and ``num`` are consumed for figure creation.

    Returns
    -------
    axes : np.array of plt.Axes
        At least 1-D.
    kwargs : dict
        Input kwargs without items related to creating a figure.
    """
    nrows = shape[0]
    # A 1-tuple shape is documented; indexing shape[1] used to raise IndexError.
    ncols = shape[1] if len(shape) > 1 else 1
    kwargs.setdefault('figsize', (16, 4 * nrows))
    fig_kwargs = {k: kwargs.pop(k)
                  for k in ('figsize', 'sharex', 'sharey', 'dpi', 'num')
                  if k in kwargs}

    if axes is None:
        _, axes = plt.subplots(ncols=ncols, nrows=nrows, **fig_kwargs)
    return np.atleast_1d(axes), kwargs
项目:elfi    作者:elfi-dev    | 项目源码 | 文件源码
def plot_marginals(samples, selector=None, bins=20, axes=None, **kwargs):
    """Plot marginal distributions for parameters.

    Parameters
    ----------
    samples : OrderedDict of np.arrays
    selector : iterable of ints or strings, optional
        Indices or keys to use from samples. Default to all.
    bins : int, optional
        Number of bins in histogram.
    axes : one or an iterable of plt.Axes, optional
    ncols : int, optional (keyword)
        Maximum number of columns in the grid (default 5).
    sharey : bool, optional (keyword)
        Default True.

    Remaining kwargs go to ``_create_axes`` (figure options) and then to
    ``Axes.hist``.

    Returns
    -------
    axes : np.array of plt.Axes
    """
    import math

    samples = _limit_params(samples, selector)
    ncols = kwargs.pop('ncols', 5)
    kwargs['sharey'] = kwargs.get('sharey', True)
    n_params = len(samples)
    # Ceiling division. The former round(n / ncols + 0.5) uses banker's
    # rounding and added an empty row whenever n / ncols was an odd integer
    # (e.g. 5 parameters in 5 columns gave 2 rows).
    nrows = max(1, math.ceil(n_params / ncols))
    shape = (nrows, min(n_params, ncols))
    axes, kwargs = _create_axes(axes, shape, **kwargs)
    axes = axes.ravel()
    for ii, k in enumerate(samples.keys()):
        axes[ii].hist(samples[k], bins=bins, **kwargs)
        axes[ii].set_xlabel(k)

    return axes
项目:elfi    作者:elfi-dev    | 项目源码 | 文件源码
def plot_traces(result, selector=None, axes=None, **kwargs):
    """Trace plot for MCMC samples.

    One row per selected parameter and one column per chain. The black
    vertical lines indicate the used warmup.

    Parameters
    ----------
    result : Result_BOLFI
    selector : iterable of ints or strings, optional
        Indices or keys to use from samples. Default to all.
    axes : one or an iterable of plt.Axes, optional
        Must contain (selected parameters) x (chains) axes.
    kwargs
        Figure options for ``_create_axes``; the rest go to ``Axes.plot``.

    Returns
    -------
    axes : np.array of plt.Axes, shape (n_selected_params, n_chains)
    """
    samples_sel = _limit_params(result.samples, selector)
    shape = (len(samples_sel), result.n_chains)
    kwargs['sharex'] = 'all'
    kwargs['sharey'] = 'row'
    axes, kwargs = _create_axes(axes, shape, **kwargs)
    # plt.subplots squeezes singleton dimensions (one parameter or one chain).
    # Restore the 2-D (parameter, chain) grid that the indexing below needs.
    axes = axes.reshape(shape)

    row = 0
    for param_index, name in enumerate(result.samples):
        if name not in samples_sel:
            continue
        for chain in range(result.n_chains):
            axes[row, chain].plot(result.chains[chain, :, param_index], **kwargs)
            axes[row, chain].axvline(result.warmup, color='black')
        axes[row, 0].set_ylabel(name)
        row += 1

    for chain in range(result.n_chains):
        axes[-1, chain].set_xlabel('Iterations in Chain {}'.format(chain))

    return axes
项目:gaps    作者:nemanja-m    | 项目源码 | 文件源码
def __init__(self, image, title="Initial problem"):
        """Open a figure matching *image*'s aspect ratio and display the image.

        The figure is 8 inches wide, and its height follows rows / columns of
        *image*. The image axes span the full width and the lower 90% of the
        height, presumably leaving room for the title set by ``show_fittest``.
        The image artist is kept in ``self._current_image``.
        """
        rows, cols = image.shape[0], image.shape[1]
        width = 8
        height = width * (rows / float(cols))
        fig = plt.figure(figsize=(width, height), frameon=False)

        ax = plt.Axes(fig, [0., 0., 1., .9])
        ax.set_axis_off()
        fig.add_axes(ax)

        self._current_image = ax.imshow(image, aspect="auto", animated=True)
        self.show_fittest(image, title)
项目:HashCode    作者:sbrodehl    | 项目源码 | 文件源码
def plot_with_coverage(d, fpath=None, show=False):
    """Plot the grid ``d['graph']`` with router coverage overlaid.

    Every cell of ``d['graph']`` equal to ``Cell.ConnectedRouter`` is treated
    as a router. Its ``wireless_access`` mask (radius ``d['radius']``,
    computed on ``d['original']``) is OR-ed into a boolean coverage grid,
    which is drawn in translucent grey over the viridis-coloured graph.
    The figure is sized so each cell spans ``pixel_per_cell`` pixels at
    100 dpi. It is saved to *fpath* if given and shown if *show* is true.
    """
    fig = plt.figure()

    ax = plt.Axes(fig, (0, 0, 1, 1))
    ax.set_axis_off()
    fig.add_axes(ax)
    h = d['height']
    w = d['width']
    dpi = 100
    pixel_per_cell = 3
    fig.set_size_inches(pixel_per_cell * w / dpi, pixel_per_cell * h / dpi)
    ax.imshow(d['graph'], cmap=plt.cm.viridis, extent=(0, 1, 0, 1), aspect='auto', interpolation='none')

    g = d['graph']
    routers = [(x, y)
               for x, row in enumerate(g)
               for y, val in enumerate(row)
               if val == Cell.ConnectedRouter]

    # `np.bool` was a deprecated alias of the builtin and was removed in
    # NumPy 1.24.
    coverage = np.zeros((h, w), dtype=bool)
    R = d['radius']
    for a, b in routers:
        mask = wireless_access(a, b, R, d['original'])
        # Clip the router's [a-R, a+R] x [b-R, b+R] window to the grid ...
        wx_min, wx_max = max(0, a - R), min(coverage.shape[0], a + R + 1)
        wy_min, wy_max = max(0, b - R), min(coverage.shape[1], b + R + 1)
        # ... and take the matching part of the mask (assumed to be that window).
        dx, lx = abs(wx_min - (a - R)), wx_max - wx_min
        dy, ly = abs(wy_min - (b - R)), wy_max - wy_min
        coverage[wx_min:wx_max, wy_min:wy_max] |= mask[dx:dx + lx, dy:dy + ly].astype(bool)

    ax.imshow(coverage, cmap=plt.cm.gray, alpha=0.2, extent=(0, 1, 0, 1), aspect='auto', interpolation='none')

    if fpath is not None:
        fig.savefig(fpath, dpi=dpi)

    if show:
        plt.show()
项目:main    作者:rmkemker    | 项目源码 | 文件源码
def imsave(data, fName, dpi=600):
    """Save figure with no border.

    Each channel is min-max scaled to [0, 255] independently before being
    written as uint8. A 2-D (single channel) image is scaled as a whole.

    Parameters
    ----------
    data : numpy array [rows x columns] or [rows x columns x channels], input image
    fName : string, file path to save image
    dpi : int, dots-per-inch (Default: 600)
    """
    sh = data.shape

    # MinMaxScaler scales column-wise on a 2-D array, so lay out one column
    # per channel. The former ravel() produced a 1-D array, which scikit-learn
    # rejects, and the channel count was hard-coded to 3.
    n_channels = sh[-1] if len(sh) > 2 else 1
    columns = data.reshape(-1, n_channels)

    scaled = MinMaxScaler().fit_transform(columns) * 255.0

    fig = plt.figure(frameon=False)
    ax = plt.Axes(fig, [0., 0., 1., 1.])
    ax.set_axis_off()
    fig.add_axes(ax)
    ax.imshow(np.uint8(scaled.reshape(sh)), aspect='auto')
    fig.savefig(fName, dpi=dpi)
    plt.close(fig)
项目:main    作者:rmkemker    | 项目源码 | 文件源码
def classmap_save(data, fname, cmap=None, num_classes=None, dpi=300):
    """Save a classification (semantic/instance segmentation) map.

    Parameters
    ----------
    data : numpy array of integers [rows x columns], classification image
        Cast to uint8, so labels must lie in 0..255.
    fname : string, file path to save image
    cmap : custom colormap (Default : None -> Builds it from discrete_cmap)
    num_classes : integer, total number of distinct classes
        (Default: None -> max label + 1, i.e. labels 0..max)
    dpi : int, dots-per-inch (Default: 300)
    """
    fig = plt.figure(frameon=False)
    data = np.uint8(data)
    ax = plt.Axes(fig, [0., 0., 1., 1.])
    ax.set_axis_off()
    fig.add_axes(ax)

    if num_classes is None:
        # Labels 0..max make max + 1 classes. Using max alone gave
        # vmax = max - 1, so the two highest labels shared one colour.
        # int() avoids uint8 overflow when max == 255.
        num_classes = int(np.max(data)) + 1

    if cmap is None:
        cmap = discrete_cmap(num_classes)

    ax.imshow(data, aspect='auto', cmap=cmap, vmin=0, vmax=num_classes - 1)
    fig.savefig(fname, dpi=dpi)
    plt.close(fig)
项目:PyDataLondon29-EmbarrassinglyParallelDAWithAWSLambda    作者:SignalMedia    | 项目源码 | 文件源码
def assert_is_valid_plot_return_object(objs):
    """Assert that *objs* is an acceptable plotting return value.

    An ndarray must contain only matplotlib ``Axes``. Anything else must be
    an ``Artist``, tuple or dict. Failures raise ``AssertionError``.
    """
    import matplotlib.pyplot as plt
    if not isinstance(objs, np.ndarray):
        assert isinstance(objs, (plt.Artist, tuple, dict)), \
            ('objs is neither an ndarray of Artist instances nor a '
             'single Artist instance, tuple, or dict, "objs" is a {0!r} '
             ''.format(objs.__class__.__name__))
        return

    for element in objs.flat:
        assert isinstance(element, plt.Axes), ('one of \'objs\' is not a '
                                               'matplotlib Axes instance, '
                                               'type encountered {0!r}'
                                               ''.format(element.__class__.__name__))
项目:TTClust    作者:tubiana    | 项目源码 | 文件源码
def plot_barplot(clusters_list, logname, size):
    """
    DESCRIPTION
    Plot the linear barplot: a one-row image with one cell per trajectory
    frame, coloured by the id of the cluster containing that frame, saved as
    ``<logname>/<logname>-linear.png``. Frames in no cluster keep id 0.
    Args:
        clusters_list (list) : cluster objects exposing ``.id`` and ``.frames``
        logname (str) : output directory name, also used as file-name prefix
        size (int) : number of frames
    Returns:
        colors_list (array) : RGBA colour of each distinct cluster id
            (ascending id order)
    """
    # Map each frame to the id of its cluster.
    clusters_number_ordered = [0] * size
    for cluster in clusters_list:
        for frame in cluster.frames:
            clusters_number_ordered[frame] = cluster.id

    cmap = get_cmap(len(clusters_list))

    # One-row 2-D array for imshow. This replaces np.asmatrix, since
    # np.matrix is discouraged by NumPy.
    data = np.array([clusters_number_ordered])
    fig = plt.figure(figsize=(10, 1))
    # Axes filling the whole figure, without ticks or frame.
    ax = plt.Axes(fig, [0., 0., 1., 1.])
    ax.set_axis_off()
    fig.add_axes(ax)
    im = ax.imshow(data, aspect='auto', interpolation='none', cmap=cmap)
    colors_list = im.cmap(im.norm(np.unique(clusters_number_ordered)))

    fig.savefig("{0}/{0}-linear.png".format(logname), dpi=DPI)
    plt.close(fig)
    return colors_list
项目:workspace    作者:nojima    | 项目源码 | 文件源码
def visualize(ax: plt.Axes, dataset: DataSet, model: AutoEncoder) -> None:
    """Scatter the first two code dimensions of the encoded inputs, one series per target value."""
    codes = model.encode(Variable(dataset.input)).data
    for label in np.unique(dataset.target):
        selected = dataset.target == label
        ax.scatter(codes[selected, 0], codes[selected, 1])
项目:workspace    作者:nojima    | 项目源码 | 文件源码
def visualize_countries(model: Word2Vec, vocabulary: Vocabulary, ax: plt.Axes = None):
    """Plot country and capital vectors in 2-D with an arrow from each country to its capital.

    Vectors come from ``project_to_2d_by_pca``. Countries are blue and
    capitals orange, and each point is labelled with its word. When *ax* is
    None, a new figure is created and shown.
    """
    countries = ['u.s.', 'u.k.', 'italy', 'korea', 'china', 'germany', 'japan', 'france', 'russia', 'egypt']
    capitals = ['washington', 'london', 'rome', 'seoul', 'beijing', 'berlin', 'tokyo', 'paris', 'moscow', 'cairo']

    vectors_2d = project_to_2d_by_pca(model, vocabulary)

    fig = None
    if ax is None:
        fig, ax = plt.subplots()

    def scatter_words(words, color):
        # Draw and label one group of words; return their 2-D coordinates.
        points = vectors_2d[[vocabulary.to_id(word) for word in words]]
        ax.scatter(points[:, 0], points[:, 1], c=color, alpha=0.7)
        for word, point in zip(words, points):
            ax.annotate(word, (point[0], point[1]))
        return points

    country_vectors = scatter_words(countries, 'blue')
    capital_vectors = scatter_words(capitals, 'orange')

    for start, end in zip(country_vectors, capital_vectors):
        delta = end - start
        ax.arrow(start[0], start[1], delta[0], delta[1], alpha=0.5)

    if fig is not None:
        fig.show()
项目:kite    作者:pyrocko    | 项目源码 | 文件源码
def setCanvas(self, **kwargs):
        """Select the figure/axes to plot in and attach a fresh AxesImage.

        Resolution order:
        1. ``axes`` (a matplotlib Axes): use it and its figure.
        2. ``figure`` (a matplotlib Figure): use it and its current axes.
        3. Neither given and no figure yet: create one with ``plt.subplots``.
           Only this case sets ``self._show_plt`` to True.
        Any other combination raises TypeError.

        :param figure: Matplotlib figure to plot in
        :type figure: :py:class:`matplotlib.Figure`
        :param axes: Matplotlib axes to plot in
        :type axes: :py:class:`matplotlib.Axes`
        :raises: TypeError
        """
        axes = kwargs.get('axes', None)
        figure = kwargs.get('figure', None)

        if isinstance(axes, plt.Axes):
            fig, ax, owns_window = axes.get_figure(), axes, False
        elif isinstance(figure, plt.Figure):
            fig, ax, owns_window = figure, figure.gca(), False
        elif axes is None and figure is None and self.fig is None:
            fig, ax = plt.subplots(1, 1)
            owns_window = True
        else:
            raise TypeError('axes has to be of type matplotlib.Axes. '
                            'figure has to be of type matplotlib.Figure')

        self.fig, self.ax = fig, ax
        self._show_plt = owns_window
        self.image = AxesImage(self.ax)
        self.ax.add_artist(self.image)
项目:kite    作者:pyrocko    | 项目源码 | 文件源码
def _initImagePlot(self, **kwargs):
        """ Initiate the plot

        Steps:
        - Set up the canvas via ``setCanvas(**kwargs)``.
        - Apply the colormap from ``cmap`` (default 'RdBu').
        - Limit the axes to the scene frame size, with equal aspect and an
          inverted y axis.
        - Set the title.
        - Register a close handler that forgets ``self.fig``/``self.ax``.

        :param figure: Matplotlib figure to plot in
        :type figure: :py:class:`matplotlib.Figure`
        :param axes: Matplotlib axes to plot in
        :type axes: :py:class:`matplotlib.Axes`
        """
        self.setCanvas(**kwargs)

        self.setColormap(kwargs.get('cmap', 'RdBu'))
        self.colormapAdjust()

        self.ax.set_xlim((0, self._scene.frame.E.size))
        self.ax.set_ylim((0, self._scene.frame.N.size))
        self.ax.set_aspect('equal')
        # Row 0 at the top, as for an image.
        self.ax.invert_yaxis()

        self.ax.set_title(self.title)

        def close_figure(ev):
            # Drop references to the closed window, so a later setCanvas()
            # without axes/figure creates a new figure.
            self.fig = None
            self.ax = None

        try:
            self.fig.canvas.mpl_connect('close_event', close_figure)
        except Exception:
            # Best effort: failing to register the callback must not abort
            # plotting. This was a bare ``except:``, which also swallowed
            # KeyboardInterrupt and SystemExit.
            pass
项目:kite    作者:pyrocko    | 项目源码 | 文件源码
def plot(self, **kwargs):
        """Placeholder in prototype class; concrete subclasses must override it.

        :param figure: Matplotlib figure to plot in
        :type figure: :py:class:`matplotlib.Figure`
        :param axes: Matplotlib axes to plot in
        :type axes: :py:class:`matplotlib.Axes`
        :param **kwargs: plotting options for the concrete implementation
        :type **kwargs: dict
        :raises: NotImplementedError
        """
        # `raise NotImplemented` raised a TypeError, because NotImplemented is
        # a constant, not an exception. The statements after it were
        # unreachable and have been dropped.
        raise NotImplementedError
项目:kite    作者:pyrocko    | 项目源码 | 文件源码
def plot(self, component='displacement', **kwargs):
        """Plot a component of the Scene.

        Components listed for this method: 'cartesian.dE', 'cartesian.dN',
        'cartesian.dU', 'displacement', 'phi', 'theta'. The plot title is
        looked up in ``self.components_available[component]``.

        :param component: Component to plot, defaults to 'displacement'
        :type component: str, optional
        :param axes: Axes instance to plot in, defaults to None
        :type axes: :py:class:`matplotlib.Axes`, optional
        :param figure: Figure instance to plot in, defaults to None
        :type figure: :py:class:`matplotlib.Figure`, optional
        :param **kwargs: forwarded unchanged to ``self._initImagePlot``
        :type **kwargs: dict
        :returns: None (there is no return statement)
        :raises: NOTE(review): an unknown *component* fails in the
            ``self.components_available[component]`` lookup (KeyError if that
            is a plain dict; the old docs said AttributeError) — confirm.
        """
        self._initImagePlot(**kwargs)
        # NOTE(review): component/title are assigned after _initImagePlot has
        # run; presumably they are properties that refresh the plot — confirm.
        self.component = component
        self.title = self.components_available[component]

        # Block in plt.show() only when _show_plt is set (presumably when this
        # object created its own figure — confirm).
        if self._show_plt:
            plt.show()
项目:pactools    作者:pactools    | 项目源码 | 文件源码
def plot(self, axs=None, fscale='log'):
        """Plot the transfer function and the impulse response of the filter.

        Parameters
        ----------
        axs : sequence or array of matplotlib Axes, optional
            At least two axes. ``axs[0]`` receives the transfer function and
            ``axs[1]`` the impulse response. Nested arrays (e.g. from
            ``plt.subplots(2, 2)``) are flattened first. If omitted, a new
            two-row figure is created and laid out.
        fscale : str
            Frequency scale passed to ``Spectrum.plot`` (default 'log').

        Returns
        -------
        fig : the matplotlib Figure holding the axes.

        Raises
        ------
        TypeError
            If an element of *axs* is not a matplotlib Axes.
        ValueError
            If fewer than two axes are given.
        """
        fig_passed = axs is not None
        if axs is None:
            fig, axs = plt.subplots(nrows=2)
        else:
            axs = np.atleast_1d(axs).ravel()
            bad_types = [type(ax) for ax in axs if not isinstance(ax, plt.Axes)]
            if bad_types:
                # Report the offending element type; type(axs) is always an
                # ndarray at this point.
                raise TypeError('axs must be a list of matplotlib Axes, got {}'
                                ' instead.'.format(bad_types[0]))
            if len(axs) < 2:
                raise ValueError('Passed figure must have at least two axes'
                                 ', given figure has {}.'.format(len(axs)))
            fig = axs[0].figure

        # Transfer function: periodogram of the FIR taps, zero-padded to a
        # power of two (at least 2048 points).
        fft_length = max(int(2 ** np.ceil(np.log2(self.fir.shape[0]))), 2048)
        s = Spectrum(fft_length=fft_length, block_length=self.fir.size,
                     step=None, fs=self.fs, wfunc=np.ones, donorm=False)
        s.periodogram(self.fir)
        s.plot('Transfer function of FIR filter', fscale=fscale,
               axes=axs[0])

        # Impulse response: the taps themselves.
        axs[1].plot(self.fir)
        axs[1].set_title('Impulse response of FIR filter')
        axs[1].set_xlabel('Samples')
        axs[1].set_ylabel('Amplitude')
        if not fig_passed:
            fig.tight_layout()
        return fig
项目:decoding_challenge_cortana_2016_3rd    作者:kingjr    | 项目源码 | 文件源码
def _imshow_tfr(ax, ch_idx, tmin, tmax, vmin, vmax, onselect, ylim=None,
                tfr=None, freq=None, vline=None, x_label=None, y_label=None,
                colorbar=False, picker=True, cmap='RdBu_r', title=None,
                hline=None):
    """Aux function to show one channel's time-freq map on topo.

    Draws ``tfr[ch_idx]`` with ``ax.imshow`` spanning ``tmin..tmax`` on x and
    ``freq[0]..freq[-1]`` on y (origin at the lower left), optionally adds
    axis labels, a colorbar and a title, then attaches a
    ``RectangleSelector`` that calls ``onselect``.

    ``ax`` may be a matplotlib Axes or another object exposing ``imshow``;
    for the latter, labels, title, colorbar and the selector go through the
    pyplot state machine (current axes) instead.

    ``tfr`` and ``freq`` default to None but are required. ``ylim``,
    ``vline`` and ``hline`` are accepted but not used here
    (NOTE(review): presumably kept for call-site compatibility — confirm).
    """
    import matplotlib.pyplot as plt
    from matplotlib.widgets import RectangleSelector

    extent = (tmin, tmax, freq[0], freq[-1])
    img = ax.imshow(tfr[ch_idx], extent=extent, aspect="auto", origin="lower",
                    vmin=vmin, vmax=vmax, picker=picker, cmap=cmap)
    is_axes = isinstance(ax, plt.Axes)
    if is_axes:
        if x_label is not None:
            ax.set_xlabel(x_label)
        if y_label is not None:
            ax.set_ylabel(y_label)
    else:
        if x_label is not None:
            plt.xlabel(x_label)
        if y_label is not None:
            plt.ylabel(y_label)
    if colorbar:
        # Tie the colorbar to ``ax`` when we have a real Axes; otherwise
        # pyplot would steal space from whatever the current axes is.
        if is_axes:
            plt.colorbar(mappable=img, ax=ax)
        else:
            plt.colorbar(mappable=img)
    if title:
        # plt.title targets the current axes, which need not be ``ax``.
        if is_axes:
            ax.set_title(title)
        else:
            plt.title(title)
    if not is_axes:
        ax = plt.gca()
    ax.RS = RectangleSelector(ax, onselect=onselect)  # reference must be kept
项目:githubgraph    作者:0x0FFF    | 项目源码 | 文件源码
def update_animation(self, i):
        """Render animation frame ``i`` and advance the timeline state.

        Clears the figure and draws the current edges, glowing nodes and
        labels on a fresh full-figure axes. Then advances ``self.tick``;
        every ``self.ticks_in_week`` ticks it moves to the next date (clamped
        to the last one) and refreshes edges, node sizes and size increments.
        Finally draws the timeline.

        :param i: frame index (only used in the progress message).
        """
        # Parenthesised single argument prints identically on Python 2 and 3.
        print("Frame %d / %d" % (i, self.iters))
        self.recalc_node_positions()
        plt.clf()
        ax = plt.Axes(self.fig, [0., 0., 1., 1.])
        ax.autoscale(False)
        if hasattr(ax, 'set_facecolor'):
            ax.set_facecolor('#141414')
        else:  # matplotlib < 2.0; set_axis_bgcolor was removed in 2.2
            ax.set_axis_bgcolor('#141414')
        self.fig.add_axes(ax)
        self.g.remove_edges_from(self.g.edges())
        self.g.add_edges_from(self.edges[self.dates[self.current_date]])
        nx.draw_networkx_edges(self.g, self.pos, edge_color='#2F2F2F')
        self.real_sizes = [2000.0 * self.sizes[node] for node in self.g.nodes()]
        self.draw_glowing_nodes(self.real_sizes)
        self.draw_labels(self.g, self.pos, sizes=self.real_sizes, labels=self.nodes)
        self.tick += 1
        self.inc_sizes()
        if self.tick >= self.ticks_in_week:
            self.current_date += 1
            if self.current_date == len(self.dates):
                # Stay on the last date once the timeline is exhausted.
                self.current_date -= 1
            self.tick = 0
            self.curr_edges = self.recalc_edges()
            self.sizes = self.nodesizes[self.dates[self.current_date]]
            self.size_inc = self.calc_size_inc()
        self.draw_timeline()
项目:structured-output-ae    作者:sbelharbi    | 项目源码 | 文件源码
def plot_over_img_seg(self, img, x, y, x_pr, y_pr, bb_gt, tag_oc=None):
        """Plot the landmarks over the (optionally occluded) image.

        A grey band (value 127) is painted over part of the bounding box to
        simulate an occlusion. Then a red segment is drawn from each of the
        first 68 ground-truth landmarks to the matching prediction.

        :param img: image array indexed ``img[row, col, channel]``; it is
            converted with ``cv2.COLOR_BGR2RGB`` for display. The input is
            not modified (a deep copy is occluded).
        :param x, y: ground-truth landmark coordinates (at least 68 each).
        :param x_pr, y_pr: predicted landmark coordinates (at least 68 each).
        :param bb_gt: bounding box ``(x_min, y_min, x_max, y_max)``;
            values are truncated to ints.
        :param tag_oc: occluded region: None (no occlusion), "left",
            "right", "up", "down" or "middle".
        :returns: the matplotlib Figure.
        :raises ValueError: if ``tag_oc`` is not one of the values above.
        """
        # Validate before touching any figure state. Compare by value: the
        # identity test ``is "left"`` only worked for interned strings.
        if tag_oc not in (None, "left", "right", "up", "down", "middle"):
            raise ValueError(
                "tag_oc must be None, 'left', 'right', 'up', 'down' or "
                "'middle', got {!r}".format(tag_oc))
        plt.close("all")
        fig = plt.figure(frameon=False)  # , figsize=(15, 10.8), dpi=200
        ax = plt.Axes(fig, [0., 0., 1., 1.])
        ax.set_axis_off()
        bb_gt = [int(xx) for xx in bb_gt]
        x_min, y_min, x_max, y_max = bb_gt[:4]
        height, width = y_max - y_min, x_max - x_min
        grey = np.uint8(255 / 2.)
        # Edge bands cover 20/50 of the box side; "middle" covers rows
        # 15/50 .. 35/50 of the box height.
        edge_w = int((20 / 50.) * width)
        edge_h = int((20 / 50.) * height)
        img_oc = copy.deepcopy(img)
        if tag_oc == "left":
            img_oc[y_min:y_max, x_min:x_min + edge_w, :] = grey
        elif tag_oc == "right":
            img_oc[y_min:y_max, x_max - edge_w:x_max, :] = grey
        elif tag_oc == "up":
            img_oc[y_min:y_min + edge_h, x_min:x_max, :] = grey
        elif tag_oc == "down":
            img_oc[y_max - edge_h:y_max, x_min:x_max, :] = grey
        elif tag_oc == "middle":
            top = y_min + int((15 / 50.) * height)
            bottom = y_min + int((35 / 50.) * height)
            img_oc[top:bottom, x_min:x_max, :] = grey
        ax.imshow(cv2.cvtColor(img_oc, cv2.COLOR_BGR2RGB), aspect="auto")
        # Red segment from each ground-truth landmark to its prediction.
        for i in range(68):
            ax.plot([x[i], x_pr[i]], [y[i], y_pr[i]], '-r')

        fig.add_axes(ax)

        return fig
项目:F_UNCLE    作者:fraserphysics    | 项目源码 | 文件源码
def plot_basis(self, axes=None, fig=None, labels=None, linstyles=None):
        """Plots the basis functions and their first and second derivatives

        Each degree of freedom is set to 1 in turn (all others 0). The
        resulting spline and its first and second derivatives are evaluated
        at 300 points between the ``spline_min`` and ``spline_max`` options
        and drawn on three stacked subplots. Knot locations are marked with
        black crosses at zero on every subplot.

        Args:
            axes(plt.Axes): A valid axes, *Ignored*
            fig(plt.Figure): A valid figure object on which to plot; a new
                figure is created if None
            labels(list): The labels, *Ignored*
            linstyles(list): The linestyles, *Ignored*
        Return:
            (plt.Figure): The figure
        Raises:
            ValueError: if the model has no degrees of freedom

        """
        if fig is None:
            fig = plt.figure()

        dof_init = copy.deepcopy(self.get_dof())
        n_dof = dof_init.shape[0]
        if n_dof == 0:
            raise ValueError('plot_basis requires at least one degree of '
                             'freedom')

        v_list = np.linspace(self.get_option('spline_min'),
                             self.get_option('spline_max'),
                             300)

        basis = []
        dbasis = []
        ddbasis = []
        for i in range(n_dof):
            # Unit vector selecting basis function i.
            unit_dof = np.zeros(n_dof)
            unit_dof[i] = 1.0
            tmp_spline = self.update_dof(unit_dof)
            basis.append(tmp_spline(v_list))
            dbasis.append(tmp_spline.derivative(n=1)(v_list))
            ddbasis.append(tmp_spline.derivative(n=2)(v_list))

        basis = np.array(basis)
        dbasis = np.array(dbasis)
        ddbasis = np.array(ddbasis)

        ax1 = fig.add_subplot(311)
        ax2 = fig.add_subplot(312)
        ax3 = fig.add_subplot(313)

        for i in range(n_dof):
            ax1.plot(v_list, basis[i, :], label='dof{:02d}'.format(i))
            ax2.plot(v_list, dbasis[i, :])
            ax3.plot(v_list, ddbasis[i, :])

        # The knot markers (taken from the last unit-dof spline) are the same
        # for every curve, so draw them once per subplot, after the curves so
        # they sit on top.
        knots = tmp_spline.get_t()
        knot_zeros = np.zeros(knots.shape)
        for ax in (ax1, ax2, ax3):
            ax.plot(knots, knot_zeros, 'xk')
        ax1.legend(loc='best')
        return fig
项目:F_UNCLE    作者:fraserphysics    | 项目源码 | 文件源码
def plot_fisher_data(self, fisher_data, axes=None, fig=None,
                         linestyles=None, labels=None):
        """Plot the eigenvalues and eigenfunctions of a Fisher decomposition.

        Args:
            fisher_data(tuple): Data from the fisher_decomposition function
                *see docstring for definition*. Used here as follows:
                ``fisher_data[0]`` holds the eigenvalues (top panel, log
                scale); ``fisher_data[2]`` the eigenfunctions, one per row;
                ``fisher_data[3]`` the specific-volume grid they are plotted
                against (bottom panel). ``fisher_data[1]`` is not used.

        Keyword Args:
            axes(plt.Axes): *Ignored*
            fig(plt.Figure): A valid figure to plot on; a new pyplot figure
                is created if None
            linestyles(list): A list of valid linestyles *Ignored*
            labels(list): A list of labels *Ignored*

        Return:
            (plt.Figure): The figure containing both panels
        """
        if fig is None:
            # plt.figure (not plt.Figure) so the new figure is pyplot-managed.
            fig = plt.figure()

        # Draw on ``fig`` itself: plt.subplot would target pyplot's current
        # figure, which need not be ``fig``.
        ax1 = fig.add_subplot(211)
        ax2 = fig.add_subplot(212)

        eigs = fisher_data[0]
        eig_func = fisher_data[2]
        indep = fisher_data[3]

        ax1.semilogy(eigs, 'sk')
        ax1.set_xlabel("Eigenvalue number")
        ax1.set_ylabel(r"Eigenvalue / Pa$^{-2}$")
        ax1.set_xlim(-0.5, len(eigs) - 0.5)
        # Lower limit from the smallest non-zero eigenvalue (log axis).
        ax1.set_ylim([0.1 * min(eigs[np.nonzero(eigs)]), 10 * max(eigs)])
        ax1.xaxis.set_major_locator(MultipleLocator(1))
        ax1.xaxis.set_major_formatter(FormatStrFormatter('%d'))

        styles = ['-g', '-.b', '--m', ':k', '-c', '-.y', '--r']
        for i in range(eig_func.shape[0]):
            # Cycle through the styles when there are more than seven.
            ax2.plot(indep, eig_func[i], styles[i % len(styles)],
                     label="{:d}".format(i))

        ax2.legend(loc='best')
        ax2.get_legend().set_title("Eigen-\nfunctions", prop={'size': 7})
        ax2.set_xlabel(r"Specific volume / cm$^3$ g$^{-1}$")
        ax2.set_ylabel("Eigenfunction response / Pa")

        fig.tight_layout()

        return fig
项目:F_UNCLE    作者:fraserphysics    | 项目源码 | 文件源码
def plot_convergence(self, hist, axes=None, linestyles=None, labels=None):
        """Plot the convergence of the MAP objective on a log scale.

        The MAP history is negated and drawn with ``semilogy`` against the
        iteration number (integer ticks).

        Args:
            hist(tuple): Convergence history, elements
                0. (list): MAP history (negated before plotting)
                1. (list): DOF history (not used here)

        Keyword Args:
            axes(plt.Axes): The axes on which to plot the figure, if None,
                creates a new figure object on which to plot.
            linestyles(list): Strings for the linestyles; only the first is
                used. Defaults to ``['-k']``.
            labels(list): *Ignored*

        Return:
            (plt.Figure): The new figure when ``axes`` is None, otherwise
                the figure that owns ``axes``.

        """
        if axes is None:
            fig = plt.figure()
            ax1 = fig.gca()
        else:
            ax1 = axes
            fig = ax1.figure

        # Fall back to the documented default for None or an empty list.
        style = linestyles[0] if linestyles else '-k'
        ax1.semilogy(-np.array(hist[0]), style)

        ax1.xaxis.set_major_locator(MultipleLocator(1))
        ax1.xaxis.set_major_formatter(FormatStrFormatter('%d'))

        ax1.set_xlabel('Iteration number')
        ax1.set_ylabel('Negative a posteriori log likelihood')

        return fig
项目:F_UNCLE    作者:fraserphysics    | 项目源码 | 文件源码
def plot(self, data, axes=None, fig=None, linestyles=None, labels=None):
        """Plots the object

        Draws six time-series panels on a 3x2 grid of ``fig``, all against
        the times in ``data[0]``.

        Args:
            data(tuple): The output from a call to a Sphere object, used as
                ``data[0]`` (time axis) and ``data[1][k]`` (series). The
                panels, in order, show: k=1 radius, k=0 velocity, k=3
                specific volume, k=4 pressure (with ``data[1][5] * 1e3``
                overlaid in black), k=2 thickness, k=6 strain.

        Keyword Args:
            axes(plt.Axes): The axes on which to plot *Ignored*
            fig(plt.Figure): The figure on which to plot; a new figure is
                created if None
            linestyles(list): Strings for the linestyles; only the first is
                used (radius panel). Defaults to ``['-k']``.
            labels(list): *Ignored*

        Return:
            (plt.Figure): A reference to the figure containing the plot

        """
        if fig is None:
            fig = plt.figure()
        # Fall back to the documented default for None or an empty list.
        style = linestyles[0] if linestyles else '-k'

        ax1 = fig.add_subplot(321)
        ax2 = fig.add_subplot(322)
        ax3 = fig.add_subplot(323)
        ax4 = fig.add_subplot(324)
        ax5 = fig.add_subplot(325)
        ax6 = fig.add_subplot(326)

        ax1.plot(data[0], data[1][1], style)
        ax1.set_xlabel('Time from detonation / s')
        ax1.set_ylabel('Radius of sphere / cm')

        ax2.plot(data[0], data[1][0])
        ax2.set_ylabel(r'Velocity of sphere / cm s$^{-1}$')
        ax3.plot(data[0], data[1][3])
        ax3.set_ylabel(r'Specific volume / cm$^{3}$ g$^{-1}$')
        ax4.plot(data[0], data[1][4])
        ax4.plot(data[0], np.array(data[1][5]) * 1E3, '-k')
        ax4.set_ylabel(r'Pressure / Pa')
        ax5.plot(data[0], data[1][2])
        ax5.set_ylabel(r'Thickness / cm')
        ax6.plot(data[0], data[1][6])
        ax6.set_ylabel(r'Strain in material / cm cm$^{-1}$')

        return fig
项目:GazePointHeatMap    作者:r0ehre    | 项目源码 | 文件源码
def draw_display(dispsize, imagefile=None):
    """Returns a matplotlib.pyplot Figure and its axes, with a size of
    dispsize, a black background colour, and optionally with an image drawn
    onto it

    arguments

    dispsize        -   tuple or list indicating the size of the display,
                    e.g. (1024,768)

    keyword arguments

    imagefile       -   full path to an image file over which the heatmap
                    is to be laid, or None for no image; NOTE: the image
                    may be smaller than the display size, the function
                    assumes that the image was presented at the centre of
                    the display (default = None). Integer-typed images
                    (matplotlib's imread returns these for non-PNG files)
                    are rescaled to [0, 1], greyscale images are replicated
                    to RGB and any alpha channel is dropped.

    returns
    fig, ax     -   matplotlib.pyplot Figure and its axes: field of zeros
                    with a size of dispsize, and an image drawn onto it
                    if an imagefile was passed

    raises
    Exception       -   if imagefile does not exist
    ValueError      -   if the image is larger than the display
    """

    # construct screen (black background); float RGB in [0, 1]
    screen = numpy.zeros((dispsize[1], dispsize[0], 3), dtype='float32')
    # if an image location has been passed, draw the image
    if imagefile is not None:
        # check if the path to the image exists
        if not os.path.isfile(imagefile):
            raise Exception("ERROR in draw_display: imagefile not found at '%s'" % imagefile)
        # load image
        img = image.imread(imagefile)
        # Bring integer images (e.g. 0-255) into the screen's [0, 1] range;
        # adding them unscaled would saturate the display to white.
        if numpy.issubdtype(img.dtype, numpy.integer):
            img = img / float(numpy.iinfo(img.dtype).max)
        # Greyscale -> RGB, and drop alpha so the image has 3 channels.
        if img.ndim == 2:
            img = numpy.dstack((img, img, img))
        img = img[:, :, :3]

        # width and height of the image
        h, w = img.shape[0], img.shape[1]
        if w > dispsize[0] or h > dispsize[1]:
            raise ValueError("ERROR in draw_display: image (%dx%d) is larger "
                             "than the display (%dx%d)"
                             % (w, h, dispsize[0], dispsize[1]))
        # x and y position of the image on the display; floor division keeps
        # the offsets integral (required for slicing on Python 3) and matches
        # the Python 2 integer-division result.
        x = dispsize[0] // 2 - w // 2
        y = dispsize[1] // 2 - h // 2
        # draw the image on the screen
        screen[y:y + h, x:x + w, :] += img
    # dots per inch
    dpi = 100.0
    # determine the figure size in inches
    figsize = (dispsize[0] / dpi, dispsize[1] / dpi)
    # create a figure with a single frameless axes covering all of it
    fig = pyplot.figure(figsize=figsize, dpi=dpi, frameon=False)
    ax = pyplot.Axes(fig, [0, 0, 1, 1])
    ax.set_axis_off()
    fig.add_axes(ax)
    # plot display
    ax.axis([0, dispsize[0], 0, dispsize[1]])
    ax.imshow(screen)  # , origin='upper')

    return fig, ax
项目:physt    作者:janpipek    | 项目源码 | 文件源码
def bar(h1, ax, errors=False, **kwargs):
    """Draw a 1D histogram as a bar plot on ``ax``.

    Parameters
    ----------
    h1: Histogram1D
    ax: plt.Axes
        Axes to draw on.
    errors: bool
        Whether to draw error bars (``ecolor`` defaults to black).
    show_stats: bool
        If True, display a small box with statistical info.
    show_values: bool
        If True, annotate the bars with their values.
    value_format:
        Formatter for the annotated values (a function or a str).
    density, cumulative: bool
        Forwarded to the data (and error) retrieval.
    label: str
        Bar label; defaults to ``h1.name``.

    Returns
    -------
    plt.Axes
    """
    value_format = kwargs.pop("value_format", None)
    show_values = kwargs.pop("show_values", False)
    show_stats = kwargs.pop("show_stats", False)
    data_options = dict(cumulative=kwargs.pop("cumulative", False),
                        density=kwargs.pop("density", False))
    label = kwargs.pop("label", h1.name)

    heights = get_data(h1, **data_options)

    if "cmap" not in kwargs:
        bar_colors = kwargs.pop("color", None)
    else:
        colormap = _get_cmap(kwargs)
        _, mapped_values = _get_cmap_data(heights, kwargs)
        bar_colors = colormap(mapped_values)

    _apply_xy_lims(ax, h1, heights, kwargs)
    _add_ticks(ax, h1, kwargs)

    if errors:
        kwargs["yerr"] = get_err_data(h1, **data_options)
        kwargs.setdefault("ecolor", "black")

    _add_labels(ax, h1, kwargs)
    ax.bar(h1.bin_left_edges, heights, h1.bin_widths, align="edge",
           label=label, color=bar_colors, **kwargs)

    if show_values:
        _add_values(ax, h1, heights, value_format=value_format)
    if show_stats:
        _add_stats_box(h1, ax)

    return ax
项目:physt    作者:janpipek    | 项目源码 | 文件源码
def scatter(h1, ax, errors=False, **kwargs):
    """Scatter plot of 1D histogram (one marker per bin centre).

    Parameters
    ----------
    h1: Histogram1D
    ax: plt.Axes
        Axes to draw on.
    errors: bool
        Whether to draw error bars.
    fmt, ecolor:
        Error-bar format (default "o") and colour (default "black"). Only
        used when ``errors`` is True; they are never forwarded to
        ``ax.scatter``.
    show_stats, show_values, value_format, density, cumulative:
        As for the other 1D plots.

    Returns
    -------
    plt.Axes
    """
    show_stats = kwargs.pop("show_stats", False)
    show_values = kwargs.pop("show_values", False)
    density = kwargs.pop("density", False)
    cumulative = kwargs.pop("cumulative", False)
    value_format = kwargs.pop("value_format", None)

    data = get_data(h1, cumulative=cumulative, density=density)

    if "cmap" in kwargs:
        cmap = _get_cmap(kwargs)
        _, cmap_data = _get_cmap_data(data, kwargs)
        kwargs["color"] = cmap(cmap_data)
    else:
        kwargs.setdefault("color", "blue")

    _apply_xy_lims(ax, h1, data, kwargs)
    _add_ticks(ax, h1, kwargs)
    _add_labels(ax, h1, kwargs)

    # Error-bar-only options: always remove them, otherwise they would leak
    # into ax.scatter (which rejects them) when errors is False.
    fmt = kwargs.pop("fmt", "o")
    ecolor = kwargs.pop("ecolor", "black")
    if errors:
        err_data = get_err_data(h1, cumulative=cumulative, density=density)
        # ms=0: draw only the bars; the markers come from ax.scatter below.
        ax.errorbar(h1.bin_centers, data, yerr=err_data, fmt=fmt,
                    ecolor=ecolor, ms=0)
    ax.scatter(h1.bin_centers, data, **kwargs)

    if show_values:
        _add_values(ax, h1, data, value_format=value_format)
    if show_stats:
        _add_stats_box(h1, ax)
    return ax
项目:elfi    作者:elfi-dev    | 项目源码 | 文件源码
def plot_pairs(samples, selector=None, bins=20, axes=None, **kwargs):
    """Plot pairwise relationships as a matrix with marginals on the diagonal.

    Off-diagonal cell (i, j) scatters parameter j (x) against parameter i
    (y). Diagonal cell (i, i) is a histogram of parameter i. Its bar heights
    are min-max rescaled into that parameter's value range and start at its
    minimum, so they fit the row's shared y-axis.

    Parameters
    ----------
    samples : OrderedDict of np.arrays
    selector : iterable of ints or strings, optional
        Indices or keys to use from samples. Default to all.
    bins : int, optional
        Number of bins in histograms.
    axes : one or an iterable of plt.Axes, optional
    **kwargs
        ``edgecolor`` (default 'none') and ``s`` (default 2) style the
        scatter points. ``sharex``/``sharey`` default to 'col'/'row' and go
        to axes creation. The kwargs returned by ``_create_axes`` are
        forwarded to ``bar`` and ``scatter``.

    Returns
    -------
    axes : np.array of plt.Axes

    """
    samples = _limit_params(samples, selector)
    shape = (len(samples), len(samples))
    edgecolor = kwargs.pop('edgecolor', 'none')
    dot_size = kwargs.pop('s', 2)
    kwargs.setdefault('sharex', 'col')
    kwargs.setdefault('sharey', 'row')
    axes, kwargs = _create_axes(axes, shape, **kwargs)

    for i1, k1 in enumerate(samples):
        min_samples = samples[k1].min()
        max_samples = samples[k1].max()
        for i2, k2 in enumerate(samples):
            if i1 == i2:
                # create a histogram with scaled y-axis
                hist, bin_edges = np.histogram(samples[k1], bins=bins)
                bar_width = bin_edges[1] - bin_edges[0]
                count_span = hist.max() - hist.min()
                if count_span > 0:
                    heights = (hist - hist.min()) * (max_samples - min_samples) / count_span
                else:
                    # All bins hold the same count: min-max scaling would be
                    # 0/0 (NaN bars), so draw a flat, full-height marginal.
                    heights = np.full(hist.shape, max_samples - min_samples, dtype=float)
                axes[i1, i2].bar(bin_edges[:-1], heights, bar_width, bottom=min_samples, **kwargs)
            else:
                axes[i1, i2].scatter(
                    samples[k2], samples[k1], s=dot_size, edgecolor=edgecolor, **kwargs)

        axes[i1, 0].set_ylabel(k1)
        axes[-1, i1].set_xlabel(k1)

    return axes
项目:picasso    作者:jungmannlab    | 项目源码 | 文件源码
def plotPlate(selection, selectioncolors, platename):
    """Draw a 96-well plate (rows A-H, columns 1-12), highlighting wells.

    Wells whose name (e.g. "A1") is in ``selection`` are filled with the
    colour at the same position in ``selectioncolors`` and labelled in
    white. All other wells are drawn white. Row A is at the top.

    :param selection: sequence of well names supporting ``in`` and
        ``.index``.
    :param selectioncolors: face colours, matched to ``selection`` by
        position.
    :param platename: shown in the title as "<platename> - <N> Staples".
    :returns: the matplotlib Figure.
    """
    inch = 25.4
    radius = 4.5/inch  # diameter of 96 well plates is 9mm
    radiusc = 4/inch   # drawn circle radius, slightly smaller than the pitch
    rows = 8
    cols = 12
    colsStr = [str(col) for col in range(1, cols + 1)]
    # Reversed so that row index 0 (lowest y) is "H" and "A" ends up on top.
    rowsStr = list('ABCDEFGH')[::-1]

    fig = plt.figure(frameon=False)
    fig.set_size_inches(5,  8)
    ax = plt.Axes(fig,  [0.,  0.,  1.,  1.],)
    ax.set_axis_off()
    fig.add_axes(ax)

    ax.cla()
    plt.axis('equal')
    for xcord in range(cols):
        for ycord in range(rows):
            string = rowsStr[ycord]+colsStr[xcord]
            xpos = xcord*radius*2+radius
            ypos = ycord*radius*2+radius
            if string in selection:
                circle = plt.Circle((xpos, ypos), radiusc,  facecolor=selectioncolors[selection.index(string)], edgecolor='black')
                ax.text(xpos,  ypos,  string,  fontsize=10,  color='white', horizontalalignment='center',
                        verticalalignment='center')
            else:
                circle = plt.Circle((xpos, ypos), radiusc,  facecolor='white', edgecolor='black')
            ax.add_artist(circle)
    # inner rectangle
    ax.add_patch(patches.Rectangle((0,  0), cols*2*radius, rows*2*radius, fill=False))
    # outer Rectangle (extra margin column on the left, row on top for labels)
    ax.add_patch(patches.Rectangle((0-2*radius,  0), (cols+1)*2*radius, (rows+1)*2*radius, fill=False))

    # add rows and columns
    for xcord in range(cols):
        ax.text(xcord*2*radius+radius,  rows*2*radius+radius,  colsStr[xcord],  fontsize=10,  color='black', horizontalalignment='center',
                verticalalignment='center')
    for ycord in range(rows):
        ax.text(-radius,  ycord*2*radius+radius,  rowsStr[ycord],  fontsize=10,  color='black', horizontalalignment='center',
                verticalalignment='center')

    ax.set_xlim([-2*radius, cols*2*radius])
    ax.set_ylim([0, (rows+1)*2*radius])
    plt.title(platename+' - '+str(len(selection))+' Staples')
    ax.set_xticks([])
    ax.set_yticks([])
    # Figure size matches the outer rectangle: one extra label column/row.
    xsize = (cols + 1)*2*radius
    ysize = (rows + 1)*2*radius
    fig.set_size_inches(xsize,  ysize)

    return fig
项目:artemis    作者:QUVA-Lab    | 项目源码 | 文件源码
def _create_subplot(fig=None, layout=None, position=None, **subplot_args):
    """Append a new subplot to ``fig``, re-tiling the existing ones.

    Existing axes are moved into an ``n_rows x n_cols`` grid with room for
    one more. The new axes is added in the last slot and configured from
    the module-level ``_newplot_settings``: labels, limits, grid, shared
    axes and tick-label visibility.

    :param fig: figure to add to; defaults to ``plt.gcf()``.
    :param layout: 'h'/'horizontal' (one row), 'v'/'vertical' (one column)
        or 'g'/'grid' (tiled via ``vector_length_to_tile_dims``). Defaults
        to ``_newplot_settings['layout']``; other values are passed to
        ``bad_value``.
    :param position: not used here (NOTE(review): accepted but ignored).
    :param subplot_args: forwarded to ``fig.add_subplot``.
    :returns: the new Axes.
    """
    if layout is None:
        layout = _newplot_settings['layout']
    if fig is None:
        fig = plt.gcf()
    n = len(fig.axes)
    if layout in ('h', 'horizontal'):
        n_rows, n_cols = 1, n + 1
    elif layout in ('v', 'vertical'):
        n_rows, n_cols = n + 1, 1
    elif layout in ('g', 'grid'):
        n_rows, n_cols = vector_length_to_tile_dims(n + 1)
    else:
        n_rows, n_cols = bad_value(layout)

    for i in range(n):
        existing = fig.axes[i]
        if hasattr(existing, 'change_geometry'):
            existing.change_geometry(n_rows, n_cols, i + 1)
        else:
            # change_geometry was removed in matplotlib 3.6.
            existing.set_subplotspec(plt.GridSpec(n_rows, n_cols, figure=fig)[i])

    for arg in ('sharex', 'sharey'):
        if isinstance(_newplot_settings[arg], plt.Axes):
            subplot_args[arg] = _newplot_settings[arg]

    ax = fig.add_subplot(n_rows, n_cols, n + 1, **subplot_args)

    if _newplot_settings['xlabel'] is not None:
        ax.set_xlabel(_newplot_settings['xlabel'])
    if _newplot_settings['ylabel'] is not None:
        ax.set_ylabel(_newplot_settings['ylabel'])

    if _newplot_settings['xlim'] is not None:
        ax.set_xlim(_newplot_settings['xlim'])
    if _newplot_settings['ylim'] is not None:
        ax.set_ylim(_newplot_settings['ylim'])

    if _newplot_settings['grid']:
        # Explicitly enable: a bare grid() call toggles the current state.
        ax.grid(True)

    # A True setting means "share with the first axes created": remember it.
    for arg in ('sharex', 'sharey'):
        if _newplot_settings[arg] is True:
            _newplot_settings[arg] = ax

    # Booleans, not 'off': the string 'off' is truthy and would keep the
    # tick labels visible on current matplotlib.
    if not _newplot_settings['show_x']:
        ax.tick_params(axis='x', labelbottom=False)
    if not _newplot_settings['show_y']:
        ax.tick_params(axis='y', labelleft=False)
    return ax