假设我有x和y与权重向量的载体wgt.我可以y = a x^3 + b x^2 + c x + d使用np.polyfit如下拟合三次曲线():
y_fit = np.polyfit(x, y, deg=3, w=wgt)
Run Code Online (Sandbox Code Playgroud)
现在,假设我想要做的又发作了,但是这一次,我想拟合通过0(即y = a x^3 + b x^2 + c x,d = 0),我怎么能(即指定一个特定的系数,d在这种情况下)是零?
谢谢
您可以np.linalg.lstsq手动使用和构建系数矩阵。首先,我将创建示例数据x和y,以及“精确拟合” y0:
import numpy as np
import matplotlib.pyplot as plt
x = np.arange(100)
y0 = 0.07 * x ** 3 + 0.3 * x ** 2 + 1.1 * x
y = y0 + 1000 * np.random.randn(x.shape[0])
Run Code Online (Sandbox Code Playgroud)
现在我将创建一个包含常量d列的完整三次多项式“训练”或“自变量”矩阵。
XX = np.vstack((x ** 3, x ** 2, x, np.ones_like(x))).T
Run Code Online (Sandbox Code Playgroud)
让我们看看如果我计算与此数据集的拟合并将其与以下内容进行比较会得到什么polyfit:
p_all = np.linalg.lstsq(X_, y)[0]
pp = np.polyfit(x, y, 3)
print np.isclose(pp, p_all).all()
# Returns True
Run Code Online (Sandbox Code Playgroud)
我使用的地方np.isclose是因为这两种算法确实产生了非常小的差异。
您可能在想“这很好,但我仍然没有回答这个问题”。从这里开始,强制拟合具有零偏移与np.ones从数组中删除列相同:
p_no_offset = np.linalg.lstsq(XX[:, :-1], y)[0] # use [0] to just grab the coefs
Run Code Online (Sandbox Code Playgroud)
好的,让我们看看这个拟合与我们的数据相比是什么样的:
y_fit = np.dot(p_no_offset, XX[:, :-1].T)
plt.plot(x, y0, 'k-', linewidth=3)
plt.plot(x, y_fit, 'y--', linewidth=2)
plt.plot(x, y, 'r.', ms=5)
Run Code Online (Sandbox Code Playgroud)
这给出了这个数字,
警告:当对实际未通过 (x,y)=(0,0) 的数据使用此方法时,您将偏置您对输出解系数 (
p) 的估计,因为lstsq将试图补偿这一事实,即存在在您的数据中偏移。有点像“方钉圆孔”问题。
此外,您还可以仅通过执行以下操作将数据拟合为立方体:
p_ = np.linalg.lstsq(X_[:1, :], y)[0]
Run Code Online (Sandbox Code Playgroud)
此处再次适用上述警告。如果您的数据包含二次项、线性项或常数项,则三次系数的估计将有偏差。有时 - 对于数值算法 - 这种事情是有用的,但出于统计目的,我的理解是包括所有较低的术语很重要。如果测试结果表明较低的项在统计上与零没有区别,那很好,但是为了安全起见,您可能应该在估计三次方时将它们保留下来。
祝你好运!
小智 5
您可以尝试如下操作:
curve_fit从导入scipy,即
from scipy.optimize import curve_fit
import matplotlib.pyplot as plt
import numpy as np
Run Code Online (Sandbox Code Playgroud)
定义曲线拟合函数。就你而言,
def fit_func(x, a, b, c):
# Curve fitting function
return a * x**3 + b * x**2 + c * x # d=0 is implied
Run Code Online (Sandbox Code Playgroud)
执行曲线拟合,
from scipy.optimize import curve_fit
import matplotlib.pyplot as plt
import numpy as np
Run Code Online (Sandbox Code Playgroud)
如果你愿意的话,绘制结果,
plt.plot(x, y, '.r') # Data
plt.plot(x_fit, y_fit, 'k') # Fitted curve
Run Code Online (Sandbox Code Playgroud)
它并没有回答问题,因为它使用numpyspolyfit函数穿过原点,但它解决了问题。
希望有人觉得它有用:)