如何将某些系数约束的多项式拟合?

Jen*_*ars 4 python numpy curve-fitting scipy polynomials

使用NumPy polyfit(或类似的东西)是否有一种简单的方法来获得将一个或多个系数限制为特定值的解决方案?

例如,我们可以使用以下公式找到普通的多项式拟合:

x = np.array([0.0, 1.0, 2.0, 3.0,  4.0,  5.0])
y = np.array([0.0, 0.8, 0.9, 0.1, -0.8, -1.0])
z = np.polyfit(x, y, 3)
Run Code Online (Sandbox Code Playgroud)

屈服

array([ 0.08703704, -0.81349206,  1.69312169, -0.03968254])
Run Code Online (Sandbox Code Playgroud)

但是,如果我想要最合适的多项式,其中第三个系数(在上述情况下z[2])要求为1,该怎么办?还是我需要从头开始编写配件?

Cle*_*leb 5

在这种情况下,我将使用curve_fitlmfit;。我很快就展示了它的第一个。

import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit

def func(x, a, b, c, d):
  return a + b * x + c * x ** 2 + d * x ** 3

x = np.array([0.0, 1.0, 2.0, 3.0,  4.0,  5.0])
y = np.array([0.0, 0.8, 0.9, 0.1, -0.8, -1.0])

print(np.polyfit(x, y, 3))

popt, _ = curve_fit(func, x, y)
print(popt)

popt_cons, _ = curve_fit(func, x, y, bounds=([-np.inf, 2, -np.inf, -np.inf], [np.inf, 2.001, np.inf, np.inf]))
print(popt_cons)

xnew = np.linspace(x[0], x[-1], 1000)

plt.plot(x, y, 'bo')
plt.plot(xnew, func(xnew, *popt), 'k-')
plt.plot(xnew, func(xnew, *popt_cons), 'r-')
plt.show()
Run Code Online (Sandbox Code Playgroud)

这将打印:

[ 0.08703704 -0.81349206  1.69312169 -0.03968254]
[-0.03968254  1.69312169 -0.81349206  0.08703704]
[-0.14331349  2.         -0.95913556  0.10494372]
Run Code Online (Sandbox Code Playgroud)

因此,在无约束的情况下,polyfitcurve_fit给出相同的结果(只是顺序不同),在有约束的情况下,固定参数为2。

该图如下所示:

在此处输入图片说明

在其中,lmfit您还可以选择是否适合某个参数,因此您也可以将其设置为所需的值。


小智 5

抱歉复活了

..但我觉得缺少这个答案。

为了拟合多项式,我们求解以下方程组:

a0*x0^n + a1*x0^(n-1) .. + an*x0^0 = y0
a0*x1^n + a1*x1^(n-1) .. + an*x1^0 = y1
                 ...
a0*xm^n + a1*xm^(n-1) .. + an*xm^0 = ym
Run Code Online (Sandbox Code Playgroud)

这是一个形式问题V @ a = y

其中“V”是范德蒙德矩阵:

[[x0^n  x0^(n-1)  1],
 [x1^n  x1^(n-1)  1],
        ...
 [xm^n  xm^(n-1)  1]]
Run Code Online (Sandbox Code Playgroud)

“y”是保存 y 值的列向量:

[[y0],
 [y1],
  ...
 [ym]]
Run Code Online (Sandbox Code Playgroud)

..“a”是我们正在求解的系数的列向量:

[[a0],
 [a1],
  ...
 [an]]
Run Code Online (Sandbox Code Playgroud)

该问题可以使用线性最小二乘法解决,如下所示:

import numpy as np

x = np.array([0.0, 1.0, 2.0, 3.0,  4.0,  5.0])
y = np.array([0.0, 0.8, 0.9, 0.1, -0.8, -1.0])

deg = 3
V = np.vander(x, deg + 1)
z, *_ = np.linalg.lstsq(V, y, rcond=None)

print(z)
# [ 0.08703704 -0.81349206  1.69312169 -0.03968254]
Run Code Online (Sandbox Code Playgroud)

..它产生与 polyfit 方法相同的解决方案:

z = np.polyfit(x, y, deg)

print(z)
# [ 0.08703704 -0.81349206  1.69312169 -0.03968254]
Run Code Online (Sandbox Code Playgroud)

相反,我们想要一个解决方案a2 = 1

从答案的开头代入a2 = 1方程组,然后将相应项从左移到右,得到:

a0*x0^n + a1*x0^(n-1) + 1*x0^(n-2) .. + an*x0^0 = y0
a0*x1^n + a1*x1^(n-1) + 1*x0^(n-2) .. + an*x1^0 = y1
                 ...
a0*xm^n + a1*xm^(n-1) + 1*x0^(n-2) .. + an*xm^0 = ym

=>

a0*x0^n + a1*x0^(n-1) .. + an*x0^0 = y0 - 1*x0^(n-2)
a0*x1^n + a1*x1^(n-1) .. + an*x1^0 = y1 - 1*x0^(n-2)
                 ...
a0*xm^n + a1*xm^(n-1) .. + an*xm^0 = ym - 1*x0^(n-2)
Run Code Online (Sandbox Code Playgroud)

这对应于从 Vandermonde 矩阵中删除第 2 列并从 y 向量中减去它,如下所示:

y_ = y - V[:, 2]
V_ = np.delete(V, 2, axis=1)
z_, *_ = np.linalg.lstsq(V_, y_, rcond=None)
z_ = np.insert(z_, 2, 1)

print(z_)
# [ 0.04659264 -0.48453866  1.          0.19438046]
Run Code Online (Sandbox Code Playgroud)

请注意,在求解线性最小二乘问题后,我在系数向量中插入了 1,我们不再求解,因为a2我们将其设置为 1 并将其从问题中删除。

为了完整起见,绘制时的解决方案如下所示:

三种不同方法的绘图

以及我使用的完整代码:


import numpy as np

x = np.array([0.0, 1.0, 2.0, 3.0,  4.0,  5.0])
y = np.array([0.0, 0.8, 0.9, 0.1, -0.8, -1.0])

deg = 3
V = np.vander(x, deg + 1)
z, *_ = np.linalg.lstsq(V, y, rcond=None)

print(z)
# [ 0.08703704 -0.81349206  1.69312169 -0.03968254]

z = np.polyfit(x, y, deg)

print(z)
# [ 0.08703704 -0.81349206  1.69312169 -0.03968254]

y_ = y - V[:, 2]
V_ = np.delete(V, 2, axis=1)
z_, *_ = np.linalg.lstsq(V_, y_, rcond=None)
z_ = np.insert(z_, 2, 1)

print(z_)
# [ 0.04659264 -0.48453866  1.          0.19438046]

from matplotlib import pyplot as plt

plt.plot(x, y, 'o', label='data')
plt.plot(x, V @ z, label='polyfit')
plt.plot(x, V @ z_, label='constrained (a2 = 0)')

plt.legend()

plt.show()
Run Code Online (Sandbox Code Playgroud)