是否有在matplotlib中制作散点图矩阵的函数?

hat*_*rix 52 python matplotlib scatter-plot

散点图矩阵的示例

在此输入图像描述

matplotlib.pyplot中有这样的功能吗?

Rom*_*kar 100

对于那些不想定义自己的函数的人来说,Python中有一个很棒的数据分析库,叫做Pandas,可以找到scatter_matrix()方法:

from pandas.plotting import scatter_matrix
df = pd.DataFrame(np.random.randn(1000, 4), columns = ['a', 'b', 'c', 'd'])
scatter_matrix(df, alpha = 0.2, figsize = (6, 6), diagonal = 'kde')
Run Code Online (Sandbox Code Playgroud)

在此输入图像描述

  • +1这将教我去寻找Python功能,然后再看看它是否已经在熊猫中了.第1步:总是问,熊猫中是否已经存在?`pd.scatter_matrix(DF); plt.show()`.难以置信. (5认同)
  • 嗨,为什么只有部分子图中有一个网格?可以修改(全部或全部)?谢谢 (2认同)
  • 在 matplotlib 散点图矩阵中放置 kde 是一项极限运动。我爱熊猫。 (2认同)

Joe*_*ton 22

一般来说,matplotlib通常不包含在多个轴对象上操作的绘图函数(在本例中为子图).期望的是你会编写一个简单的函数来将所有东西串在一起然而你喜欢.

我不太确定你的数据是什么样的,但是只需构建一个从头开始执行此操作的函数就很简单了.如果您总是要使用结构化数据或rec数组,那么您可以简化这一过程.(即,每个数据系列始终都有一个名称,因此您可以省略必须指定名称.)

举个例子:

import itertools
import numpy as np
import matplotlib.pyplot as plt

def main():
    np.random.seed(1977)
    numvars, numdata = 4, 10
    data = 10 * np.random.random((numvars, numdata))
    fig = scatterplot_matrix(data, ['mpg', 'disp', 'drat', 'wt'],
            linestyle='none', marker='o', color='black', mfc='none')
    fig.suptitle('Simple Scatterplot Matrix')
    plt.show()

def scatterplot_matrix(data, names, **kwargs):
    """Plots a scatterplot matrix of subplots.  Each row of "data" is plotted
    against other rows, resulting in a nrows by nrows grid of subplots with the
    diagonal subplots labeled with "names".  Additional keyword arguments are
    passed on to matplotlib's "plot" command. Returns the matplotlib figure
    object containg the subplot grid."""
    numvars, numdata = data.shape
    fig, axes = plt.subplots(nrows=numvars, ncols=numvars, figsize=(8,8))
    fig.subplots_adjust(hspace=0.05, wspace=0.05)

    for ax in axes.flat:
        # Hide all ticks and labels
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)

        # Set up ticks only on one side for the "edge" subplots...
        if ax.is_first_col():
            ax.yaxis.set_ticks_position('left')
        if ax.is_last_col():
            ax.yaxis.set_ticks_position('right')
        if ax.is_first_row():
            ax.xaxis.set_ticks_position('top')
        if ax.is_last_row():
            ax.xaxis.set_ticks_position('bottom')

    # Plot the data.
    for i, j in zip(*np.triu_indices_from(axes, k=1)):
        for x, y in [(i,j), (j,i)]:
            axes[x,y].plot(data[x], data[y], **kwargs)

    # Label the diagonal subplots...
    for i, label in enumerate(names):
        axes[i,i].annotate(label, (0.5, 0.5), xycoords='axes fraction',
                ha='center', va='center')

    # Turn on the proper x or y axes ticks.
    for i, j in zip(range(numvars), itertools.cycle((-1, 0))):
        axes[j,i].xaxis.set_visible(True)
        axes[i,j].yaxis.set_visible(True)

    return fig

main()
Run Code Online (Sandbox Code Playgroud)

在此输入图像描述

  • 哇,很多新功能!是的,当你掌握了模块时并不太难......但不像在R.中调用`pair`那么简单.:) (3认同)
  • 在我看来,最大的优势是python的灵活性.R是一种非常出色的领域特定语言,如果您只是想进行统计分析,那就是无与伦比的.Python是一种很好的通用编程语言,你会真正开始看到更大程序的好处.一旦你开始想要一个带有交互式gui的程序,它可以从网上抓取数据,解析一些随机二进制文件格式,进行分析,并将其全部绘制出来,一般的编程语言可以显示出很多优点.当然,对于很多语言来说都是如此,但我更喜欢python.:) (2认同)

sus*_*mit 13

您还可以使用Seaborn的pairplot功能:

import seaborn as sns
sns.set()
df = sns.load_dataset("iris")
sns.pairplot(df, hue="species")
Run Code Online (Sandbox Code Playgroud)


小智 10

感谢您分享您的代码!你为我们找到了所有困难的东西.当我使用它时,我注意到一些看起来不太正确的小事.

  1. [FIX#1]轴抽搐没有像我期望的那样排列(也就是说,在上面的例子中,你应该能够在所有图中的任何点绘制垂直和水平线,并且线应该穿过相应的指向其他图,但现在它不会发生.

  2. [FIX#2]如果您正在绘制奇数个变量,则右下角轴不会拉出正确的xtics或ytics.它只是将其保留为默认的0..1滴答.

  3. 不是修复,但我明确输入它是可选的names,因此它将xi变量i 的默认值放在对角线位置.

您将在下面找到代码的更新版本,以解决这两点,否则将保留您的代码之美.

import itertools
import numpy as np
import matplotlib.pyplot as plt

def scatterplot_matrix(data, names=[], **kwargs):
    """
    Plots a scatterplot matrix of subplots.  Each row of "data" is plotted
    against other rows, resulting in a nrows by nrows grid of subplots with the
    diagonal subplots labeled with "names".  Additional keyword arguments are
    passed on to matplotlib's "plot" command. Returns the matplotlib figure
    object containg the subplot grid.
    """
    numvars, numdata = data.shape
    fig, axes = plt.subplots(nrows=numvars, ncols=numvars, figsize=(8,8))
    fig.subplots_adjust(hspace=0.0, wspace=0.0)

    for ax in axes.flat:
        # Hide all ticks and labels
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)

        # Set up ticks only on one side for the "edge" subplots...
        if ax.is_first_col():
            ax.yaxis.set_ticks_position('left')
        if ax.is_last_col():
            ax.yaxis.set_ticks_position('right')
        if ax.is_first_row():
            ax.xaxis.set_ticks_position('top')
        if ax.is_last_row():
            ax.xaxis.set_ticks_position('bottom')

    # Plot the data.
    for i, j in zip(*np.triu_indices_from(axes, k=1)):
        for x, y in [(i,j), (j,i)]:
            # FIX #1: this needed to be changed from ...(data[x], data[y],...)
            axes[x,y].plot(data[y], data[x], **kwargs)

    # Label the diagonal subplots...
    if not names:
        names = ['x'+str(i) for i in range(numvars)]

    for i, label in enumerate(names):
        axes[i,i].annotate(label, (0.5, 0.5), xycoords='axes fraction',
                ha='center', va='center')

    # Turn on the proper x or y axes ticks.
    for i, j in zip(range(numvars), itertools.cycle((-1, 0))):
        axes[j,i].xaxis.set_visible(True)
        axes[i,j].yaxis.set_visible(True)

    # FIX #2: if numvars is odd, the bottom right corner plot doesn't have the
    # correct axes limits, so we pull them from other axes
    if numvars%2:
        xlimits = axes[0,-1].get_xlim()
        ylimits = axes[-1,0].get_ylim()
        axes[-1,-1].set_xlim(xlimits)
        axes[-1,-1].set_ylim(ylimits)

    return fig

if __name__=='__main__':
    np.random.seed(1977)
    numvars, numdata = 4, 10
    data = 10 * np.random.random((numvars, numdata))
    fig = scatterplot_matrix(data, ['mpg', 'disp', 'drat', 'wt'],
            linestyle='none', marker='o', color='black', mfc='none')
    fig.suptitle('Simple Scatterplot Matrix')
    plt.show()
Run Code Online (Sandbox Code Playgroud)

再次感谢您与我们分享.我已经多次使用它了!哦,我重新安排了main()代码的一部分,以便它可以是一个正式的示例代码,如果它被导入到另一段代码中则不会被调用.


omu*_*mun 5

在阅读问题时,我希望看到包括rpy 的答案。我认为这是一个利用两种美丽语言的不错选择。所以这里是:

import rpy
import numpy as np

def main():
    np.random.seed(1977)
    numvars, numdata = 4, 10
    data = 10 * np.random.random((numvars, numdata))
    mpg = data[0,:]
    disp = data[1,:]
    drat = data[2,:]
    wt = data[3,:]
    rpy.set_default_mode(rpy.NO_CONVERSION)

    R_data = rpy.r.data_frame(mpg=mpg,disp=disp,drat=drat,wt=wt)

    # Figure saved as eps
    rpy.r.postscript('pairsPlot.eps')
    rpy.r.pairs(R_data,
       main="Simple Scatterplot Matrix Via RPy")
    rpy.r.dev_off()

    # Figure saved as png
    rpy.r.png('pairsPlot.png')
    rpy.r.pairs(R_data,
       main="Simple Scatterplot Matrix Via RPy")
    rpy.r.dev_off()

    rpy.set_default_mode(rpy.BASIC_CONVERSION)


if __name__ == '__main__': main()
Run Code Online (Sandbox Code Playgroud)

我无法发布图像来显示结果:( 抱歉!