How to plot a line on a scatter plot in Python?

gol*_*ine 58 python numpy matplotlib

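I have two vectors of data, and I've put them into matplotlib.scatter(). Now I'd like to overplot a linear fit to this data. How would I do this? I've tried using scikitlearn and np.scatter.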

Gre*_*ier 100

import numpy as np
from numpy.polynomial.polynomial import polyfit
import matplotlib.pyplot as plt

# Sample data
x = np.arange(10)
y = 5 * x + 10

# Fit with polyfit
b, m = polyfit(x, y, 1)

plt.plot(x, y, '.')
plt.plot(x, b + m * x, '-')
plt.show()
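One detail worth checking in the answer above is the order of the returned coefficients. numpy.polynomial.polynomial.polyfit returns them lowest degree first (intercept, then slope), which is why the code unpacks b, m. The older np.polyfit returns them highest degree first (slope, then intercept). The sketch below runs both on the same sample data so the two orderings can be compared side by side. The name poly_polyfit is just a local alias chosen here to avoid a name clash.

```python
import numpy as np
from numpy.polynomial.polynomial import polyfit as poly_polyfit

x = np.arange(10)
y = 5 * x + 10

# numpy.polynomial.polynomial.polyfit: lowest degree first -> (intercept, slope)
b, m = poly_polyfit(x, y, 1)

# np.polyfit: highest degree first -> (slope, intercept)
m2, b2 = np.polyfit(x, y, 1)

print(f"polynomial.polyfit: b={b:.2f}, m={m:.2f}")
print(f"np.polyfit:         m={m2:.2f}, b={b2:.2f}")
```

Mixing up the two orderings swaps the slope and intercept, which produces a line in the wrong place.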


pco*_*ing 30

I'm partial to scikits.statsmodels (now simply the statsmodels package). Here's an example:

import statsmodels.api as sm
import numpy as np
import matplotlib.pyplot as plt

X = np.random.rand(100)
Y = X + np.random.rand(100)*0.1

results = sm.OLS(Y, sm.add_constant(X)).fit()

print(results.summary())

plt.scatter(X, Y)

X_plot = np.linspace(0, 1, 100)
# add_constant prepends the column of ones by default,
# so params[0] is the intercept and params[1] is the slope
plt.plot(X_plot, X_plot*results.params[1] + results.params[0])

plt.show()

The only tricky part is that sm.add_constant(X) adds a column of ones to X so that the model gets an intercept term.

     Summary of Regression Results
=======================================
| Dependent Variable:            ['y']|
| Model:                           OLS|
| Method:                Least Squares|
| Date:               Sat, 28 Sep 2013|
| Time:                       09:22:59|
| # obs:                         100.0|
| Df residuals:                   98.0|
| Df model:                        1.0|
==============================================================================
|                   coefficient     std. error    t-statistic          prob. |
------------------------------------------------------------------------------
| x1                      1.007       0.008466       118.9032         0.0000 |
| const                 0.05165       0.005138        10.0515         0.0000 |
==============================================================================
|                          Models stats                      Residual stats  |
------------------------------------------------------------------------------
| R-squared:                     0.9931   Durbin-Watson:              1.484  |
| Adjusted R-squared:            0.9930   Omnibus:                    12.16  |
| F-statistic:                1.414e+04   Prob(Omnibus):           0.002294  |
| Prob (F-statistic):        9.137e-108   JB:                        0.6818  |
| Log likelihood:                 223.8   Prob(JB):                  0.7111  |
| AIC criterion:                 -443.7   Skew:                     -0.2064  |
| BIC criterion:                 -438.5   Kurtosis:                   2.048  |
------------------------------------------------------------------------------

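  • @David: The params array is wrong. Try: plt.plot(X_plot, X_plot*results.params[1] + results.params[0]). Or, even better: plt.plot(X, results.fittedvalues), since the first formula assumes that y is linear in x, which, while true here, is not always the case. (4 upvotes)
  • My figure looks different; the line is in the wrong place, above the points. (3 upvotes)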
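To make the add_constant step less opaque, here is a NumPy-only sketch of the same fit. It builds the design matrix [1, X] by hand (ones column first, which is the layout sm.add_constant produces with its default prepend=True) and solves the least-squares problem with np.linalg.lstsq. It shows why index 0 holds the intercept and index 1 the slope. The variable names are illustrative, and this is an equivalent formulation rather than what statsmodels does internally.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.random(100)
Y = X + rng.random(100) * 0.1

# Design matrix with a leading column of ones (intercept column first)
A = np.column_stack([np.ones_like(X), X])

# Ordinary least squares: coef[0] is the intercept, coef[1] the slope
coef, *_ = np.linalg.lstsq(A, Y, rcond=None)
intercept, slope = coef

# Fitted values at the observed X, analogous to results.fittedvalues
fitted = A @ coef
print(f"intercept={intercept:.3f}, slope={slope:.3f}")
```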

Anonymous user 27

I like Seaborn's regplot or lmplot:


  • import seaborn as sns; sns.regplot(x=x, y=y) (2 upvotes)

1''*_*1'' 20

A one-line version of this excellent answer for plotting the line of best fit is:

plt.plot(np.unique(x), np.poly1d(np.polyfit(x, y, 1))(np.unique(x)))

Using np.unique(x) instead of x handles the case where x is unsorted or has duplicate values.
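A small illustration of what np.unique contributes in the one-liner: it returns the distinct x values in ascending order, so the fitted line is evaluated left to right with no repeated points. The data below is made up purely to have unsorted input with one duplicate.

```python
import numpy as np

# Unsorted x with a repeated value (3.0 appears twice)
x = np.array([3.0, 1.0, 2.0, 3.0, 0.0])
y = np.array([7.1, 2.9, 5.2, 6.8, 1.0])

line = np.poly1d(np.polyfit(x, y, 1))

# Distinct x values, sorted ascending: [0. 1. 2. 3.]
x_line = np.unique(x)
y_line = line(x_line)
print(x_line)
```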

The call to poly1d is an alternative to writing out m*x + b, as in this other excellent answer.
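To confirm the equivalence, the sketch below fits a line with np.polyfit (which returns the slope first, then the intercept), wraps the coefficients in np.poly1d, and compares its output with m*x + b computed directly. The sample data is invented for illustration.

```python
import numpy as np

x = np.array([0.0, 1.0, 2.0, 3.0, 4.0])
y = np.array([1.1, 2.9, 5.2, 7.1, 8.8])

m, b = np.polyfit(x, y, 1)   # highest degree first: slope, intercept
line = np.poly1d([m, b])     # callable polynomial representing m*x + b

print(line(x))
print(m * x + b)             # same values as the line above
```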


Fra*_*urt 12

Another approach, using axes.get_xlim():

import matplotlib.pyplot as plt
import numpy as np

def scatter_plot_with_correlation_line(x, y, graph_filepath):
    '''
    http://stackoverflow.com/a/34571821/395857
    x does not have to be ordered.
    '''
    # Scatter plot
    plt.scatter(x, y)

    # Add correlation line
    axes = plt.gca()
    m, b = np.polyfit(x, y, 1)
    X_plot = np.linspace(axes.get_xlim()[0],axes.get_xlim()[1],100)
    plt.plot(X_plot, m*X_plot + b, '-')

    # Save figure
    plt.savefig(graph_filepath, dpi=300, format='png', bbox_inches='tight')

def main():
    # Data
    x = np.random.rand(100)
    y = x + np.random.rand(100)*0.1

    # Plot
    scatter_plot_with_correlation_line(x, y, 'scatter_plot.png')

if __name__ == "__main__":
    main()
    #cProfile.run('main()') # if you want to do some profiling
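One subtlety with the get_xlim() approach: plotting the line over the current x limits makes matplotlib autoscale again, and the default margins widen the view, so the line usually stops short of the plot edges. A possible tweak, sketched below, is to pin the limits back with set_xlim after drawing. The Agg backend and the output filename are assumptions chosen so the sketch runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
x = rng.random(100)
y = x + rng.random(100) * 0.1

fig, ax = plt.subplots()
ax.scatter(x, y)

m, b = np.polyfit(x, y, 1)
xmin, xmax = ax.get_xlim()      # limits chosen by autoscaling for the scatter
x_plot = np.linspace(xmin, xmax, 100)
ax.plot(x_plot, m * x_plot + b, "-")

# The new line would otherwise trigger extra margins; pin the limits so
# the fitted line runs all the way to both edges of the axes
ax.set_xlim(xmin, xmax)
fig.savefig("scatter_plot_full_width.png", dpi=150, bbox_inches="tight")
```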


tdy*_*tdy 9

New in matplotlib 3.3

Use the new plt.axline function to plot y = m*x + b, given the slope m and the intercept b:

plt.axline(xy1=(0, b), slope=m)

Example of plt.axline with np.polyfit:

import numpy as np
import matplotlib.pyplot as plt

# generate random vectors
rng = np.random.default_rng(0)
x = rng.random(100)
y = 5*x + rng.rayleigh(1, x.shape)
plt.scatter(x, y, alpha=0.5)

# compute slope m and intercept b
m, b = np.polyfit(x, y, deg=1)

# plot fitted y = m*x + b
plt.axline(xy1=(0, b), slope=m, color='r', label=f'$y = {m:.2f}x {b:+.2f}$')

plt.legend()
plt.show()

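Here the equation is shown as a legend entry, but if you want to draw the equation along the line itself, see How to rotate annotations to match lines.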
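axline also accepts a second point instead of a slope. Because (0, b) and (1, m + b) both lie on y = m*x + b, the two-point call below should draw the same infinite line as axline(xy1=(0, b), slope=m). This is a sketch using the object-oriented Axes.axline (also matplotlib 3.3+); the Agg backend, the synthetic data, and the output filename are assumptions made so it runs headless.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
x = rng.random(50)
y = 2.0 * x + 1.0 + rng.normal(0, 0.1, x.shape)

m, b = np.polyfit(x, y, deg=1)

fig, ax = plt.subplots()
ax.scatter(x, y, alpha=0.5)
# Two-point form: both points satisfy y = m*x + b, so this is the same
# infinite line as ax.axline((0, b), slope=m)
line = ax.axline((0, b), (1, b + m), color="r")
fig.savefig("axline_two_points.png")
```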