在绘制seaborn中的回归时如何获得数值拟合结果?

The*_*ton 40 python seaborn

如果我使用Python中的seaborn库来绘制线性回归的结果,有没有办法找出回归的数值结果?例如,我可能想知道拟合系数或拟合的R 2.

我可以使用底层的statsmodels接口重新运行相同的拟合,但这似乎是不必要的重复工作,无论如何我想要能够比较结果系数,以确保数值结果与我的相同我在情节中看到了.

mwa*_*kom 21

没有办法做到这一点.

在我看来,要求可视化库为您提供统计建模结果是倒退的.statsmodels,一个建模库,可以让您拟合模型,然后绘制一个与您所适合的模型完全对应的图.如果你想要那种确切的对应关系,这个操作顺序对我来说更有意义.

你可能会说"但是这些情节statsmodels没有那么多的审美选择seaborn".但我认为这是有道理的 - statsmodels是一个建模库,有时在建模服务中使用可视化.seaborn是一个可视化库,有时在可视化服务中使用建模.专业化很好,尝试做所有事情都很糟糕.

幸运的是,seabornstatsmodels整齐的数据.这意味着您真的需要很少的重复工作来通过适当的工具获得图表和模型.

  • @mwaskom,我刚收到通知这个问题已经获得了2500次观看.只是一个数据点,以防你想知道有多少人在寻找这个功能. (12认同)
  • @ user333700为什么要运行两次回归?Seaborn已经为你开车了,只是忘了告诉你停在哪里.它只是给你一个快照,祝你好运 (7认同)
  • 即使对于可视化包来说,这似乎也是一个基本要求。在大多数情况下,在不报告 p 值、r^2 值和系数的情况下呈现数字是不可接受的。我不认为这是一个专门的功能。正如其他人在评论中提到的那样,这确实使得 seaborn 回归对于任何合法目的(例如研究文章)毫无用处。 (4认同)
  • 好,谢谢。如果可以的话,我会再次+1。FWIW,我想到的用例是,如果我使用seaborn进行绘制并将其放置在PowerPoint中。如果有人问“适合的R平方是多少?” 或类似的东西,我希望能够自信地给出一个与我向他们展示的情节完全匹配的答案。 (3认同)
  • 同意,@ user333700.我目前没有使用seaborn,因为这个限制,虽然我可能会看看它.如果现在没有办法,我可能会建议一个功能,其中来自statsmodels的fit对象可以用作适当的seaborn绘图函数的输入. (2认同)
  • 仍然相关。我曾经对seaborn的回归感到信任,但是由于我无法检查所使用的参数,因此其中的意义不大……很高兴知道自己做会更好。少用一个库... (2认同)
  • @mwaskom:我很欣赏你的立场,但假设我遵循你的建议:我对我的数据运行统计模型(“sm”)回归,但图中没有给定的估计器,也没有保证显示的估计器是那个估计器在 sm 中找到。作为一个新手,我只是想用 numpy/seaborn 替换 google 表格,并且 google 表格有一个我*一直使用*的“显示公式”选项。在上述范围的另一端,人们需要在研究论文中使用它。http://support.google.com/docs/thread/11313654?msgid=11379847 您愿意接受“show_equation”拉取请求吗? (2认同)

Leg*_*e17 12

不幸的是,Seaborn 的创建者表示他不会添加这样的功能。下面是一些选项。(最后一部分包含我最初的建议,这是一个使用私有实现细节的黑客,seaborn不是特别灵活。)

简单的替代版本 regplot

以下函数在散点图上叠加一条拟合线,并从 返回结果statsmodels。这支持 最简单也可能是最常见的用法sns.regplot,但没有实现任何更高级的功能。

import statsmodels.api as sm


def simple_regplot(
    x, y, n_std=2, n_pts=100, ax=None, scatter_kws=None, line_kws=None, ci_kws=None
):
    """ Draw a regression line with error interval. """
    ax = plt.gca() if ax is None else ax

    # calculate best-fit line and interval
    x_fit = sm.add_constant(x)
    fit_results = sm.OLS(y, x_fit).fit()

    eval_x = sm.add_constant(np.linspace(np.min(x), np.max(x), n_pts))
    pred = fit_results.get_prediction(eval_x)

    # draw the fit line and error interval
    ci_kws = {} if ci_kws is None else ci_kws
    ax.fill_between(
        eval_x[:, 1],
        pred.predicted_mean - n_std * pred.se_mean,
        pred.predicted_mean + n_std * pred.se_mean,
        alpha=0.5,
        **ci_kws,
    )
    line_kws = {} if line_kws is None else line_kws
    h = ax.plot(eval_x[:, 1], pred.predicted_mean, **line_kws)

    # draw the scatterplot
    scatter_kws = {} if scatter_kws is None else scatter_kws
    ax.scatter(x, y, c=h[0].get_color(), **scatter_kws)

    return fit_results
Run Code Online (Sandbox Code Playgroud)

结果来自statsmodels包含丰富的信息,例如

>>> print(fit_results.summary())

                            OLS Regression Results                            
==============================================================================
Dep. Variable:                      y   R-squared:                       0.477
Model:                            OLS   Adj. R-squared:                  0.471
Method:                 Least Squares   F-statistic:                     89.23
Date:                Fri, 08 Jan 2021   Prob (F-statistic):           1.93e-15
Time:                        17:56:00   Log-Likelihood:                -137.94
No. Observations:                 100   AIC:                             279.9
Df Residuals:                      98   BIC:                             285.1
Df Model:                           1                                         
Covariance Type:            nonrobust                                         
==============================================================================
                 coef    std err          t      P>|t|      [0.025      0.975]
------------------------------------------------------------------------------
const         -0.1417      0.193     -0.735      0.464      -0.524       0.241
x1             3.1456      0.333      9.446      0.000       2.485       3.806
==============================================================================
Omnibus:                        2.200   Durbin-Watson:                   1.777
Prob(Omnibus):                  0.333   Jarque-Bera (JB):                1.518
Skew:                          -0.002   Prob(JB):                        0.468
Kurtosis:                       2.396   Cond. No.                         4.35
==============================================================================

Notes:
[1] Standard Errors assume that the covariance matrix of the errors is correctly specified.
Run Code Online (Sandbox Code Playgroud)

直接替代(几乎)用于 sns.regplot

与我下面的原始答案相比,上述方法的优点是很容易将其扩展到更复杂的拟合。

无耻的插件:这是regplot我编写的一个扩展函数,它实现了 的大部分sns.regplot功能:https : //github.com/ttesileanu/pydove

虽然仍然缺少一些功能,但我编写的功能

  • 通过将绘图与统计建模分开来提供灵活性(并且您还可以轻松访问拟合结果)。
  • 对于大型数据集,速度要快得多,因为它可以statsmodels计算置信区间而不是使用引导。
  • 允许稍微多样化的拟合(例如, 中的多项式log(x))。
  • 允许稍微细粒度的绘图选项。

旧答案

不幸的是,Seaborn 的创建者表示他不会添加这样的功能,所以这里有一个解决方法。

def regplot(
    *args,
    line_kws=None,
    marker=None,
    scatter_kws=None,
    **kwargs
):
    # this is the class that `sns.regplot` uses
    plotter = sns.regression._RegressionPlotter(*args, **kwargs)

    # this is essentially the code from `sns.regplot`
    ax = kwargs.get("ax", None)
    if ax is None:
        ax = plt.gca()

    scatter_kws = {} if scatter_kws is None else copy.copy(scatter_kws)
    scatter_kws["marker"] = marker
    line_kws = {} if line_kws is None else copy.copy(line_kws)

    plotter.plot(ax, scatter_kws, line_kws)

    # unfortunately the regression results aren't stored, so we rerun
    grid, yhat, err_bands = plotter.fit_regression(plt.gca())

    # also unfortunately, this doesn't return the parameters, so we infer them
    slope = (yhat[-1] - yhat[0]) / (grid[-1] - grid[0])
    intercept = yhat[0] - slope * grid[0]
    return slope, intercept
Run Code Online (Sandbox Code Playgroud)

请注意,这仅适用于线性回归,因为它只是从回归结果中推断出斜率和截距。好处是它使用了seaborn自己的回归类,因此可以保证结果与显示的一致。缺点当然是我们使用了一个私有的实现细节,seaborn它可以在任何时候中断。