在Python中以高精度找到由(x,y)数据给出的两条曲线的交点

Dan*_*ein 5 python algorithm numpy scipy

我有两个数据集:(x,y1)和(x,y2)。我想找到两条曲线相互交叉的位置。目标类似于此问题:在Python中将两个图形相交,找到x值:

但是,此处描述的方法仅找到最接近的数据点的交点。我想找到比原始数据间距更高精度的曲线的交点。一种选择是简单地重新插值到更细的网格。这可行,但是精度由我为重新插值选择的点数确定,这是任意的,并且需要在精度和效率之间进行权衡。

或者,我可以scipy.optimize.fsolve用来查找数据集的两个样条插值的精确交集。这很好用,但是它不能轻易找到多个相交点,需要我对相交点做出合理的猜测,并且可能缩放效果不好。(最终,我想找到几千套(x,y1,y2)的交集,所以一种有效的算法会很好。)

这是我到目前为止所拥有的。有什么改善的想法吗?

import numpy as np
import matplotlib.pyplot as plt
import scipy.interpolate, scipy.optimize

x  = np.linspace(1, 4, 20)
y1 = np.sin(x)
y2 = 0.05*x

plt.plot(x, y1, marker='o', mec='none', ms=4, lw=1, label='y1')
plt.plot(x, y2, marker='o', mec='none', ms=4, lw=1, label='y2')

idx = np.argwhere(np.diff(np.sign(y1 - y2)) != 0)

plt.plot(x[idx], y1[idx], 'ms', ms=7, label='Nearest data-point method')

interp1 = scipy.interpolate.InterpolatedUnivariateSpline(x, y1)
interp2 = scipy.interpolate.InterpolatedUnivariateSpline(x, y2)

new_x = np.linspace(x.min(), x.max(), 100)
new_y1 = interp1(new_x)
new_y2 = interp2(new_x)
idx = np.argwhere(np.diff(np.sign(new_y1 - new_y2)) != 0)
plt.plot(new_x[idx], new_y1[idx], 'ro', ms=7, label='Nearest data-point method, with re-interpolated data')

def difference(x):
    return np.abs(interp1(x) - interp2(x))

x_at_crossing = scipy.optimize.fsolve(difference, x0=3.0)
plt.plot(x_at_crossing, interp1(x_at_crossing), 'cd', ms=7, label='fsolve method')

plt.legend(frameon=False, fontsize=10, numpoints=1, loc='lower left')

plt.savefig('curve crossing.png', dpi=200)
plt.show()
Run Code Online (Sandbox Code Playgroud)

在此处输入图片说明

Dan*_*ein 5

最好(也是最有效)的答案可能取决于数据集及其采样方式。但是,许多数据集的一个很好的近似是它们在数据点之间几乎是线性的。因此,我们可以通过原帖中显示的“最近数据点”方法找到交点的大致位置。然后,我们可以使用线性插值细化最近的两个数据点之间的交点位置。

这种方法非常快,并且适用于 2D numpy 数组,以防您想一次计算多条曲线的交叉(正如我想在我的应用程序中所做的那样)。

(我从“如何计算 Python 中两条线的交点? ”中借用代码用于线性插值。)

from __future__ import division 
import numpy as np
import matplotlib.pyplot as plt

def interpolated_intercept(x, y1, y2):
    """Find the intercept of two curves, given by the same x data"""

    def intercept(point1, point2, point3, point4):
        """find the intersection between two lines
        the first line is defined by the line between point1 and point2
        the first line is defined by the line between point3 and point4
        each point is an (x,y) tuple.

        So, for example, you can find the intersection between
        intercept((0,0), (1,1), (0,1), (1,0)) = (0.5, 0.5)

        Returns: the intercept, in (x,y) format
        """    

        def line(p1, p2):
            A = (p1[1] - p2[1])
            B = (p2[0] - p1[0])
            C = (p1[0]*p2[1] - p2[0]*p1[1])
            return A, B, -C

        def intersection(L1, L2):
            D  = L1[0] * L2[1] - L1[1] * L2[0]
            Dx = L1[2] * L2[1] - L1[1] * L2[2]
            Dy = L1[0] * L2[2] - L1[2] * L2[0]

            x = Dx / D
            y = Dy / D
            return x,y

        L1 = line([point1[0],point1[1]], [point2[0],point2[1]])
        L2 = line([point3[0],point3[1]], [point4[0],point4[1]])

        R = intersection(L1, L2)

        return R

    idx = np.argwhere(np.diff(np.sign(y1 - y2)) != 0)
    xc, yc = intercept((x[idx], y1[idx]),((x[idx+1], y1[idx+1])), ((x[idx], y2[idx])), ((x[idx+1], y2[idx+1])))
    return xc,yc

def main():
    x  = np.linspace(1, 4, 20)
    y1 = np.sin(x)
    y2 = 0.05*x

    plt.plot(x, y1, marker='o', mec='none', ms=4, lw=1, label='y1')
    plt.plot(x, y2, marker='o', mec='none', ms=4, lw=1, label='y2')

    idx = np.argwhere(np.diff(np.sign(y1 - y2)) != 0)

    plt.plot(x[idx], y1[idx], 'ms', ms=7, label='Nearest data-point method')

    # new method!
    xc, yc = interpolated_intercept(x,y1,y2)
    plt.plot(xc, yc, 'co', ms=5, label='Nearest data-point, with linear interpolation')


    plt.legend(frameon=False, fontsize=10, numpoints=1, loc='lower left')

    plt.savefig('curve crossing.png', dpi=200)
    plt.show()

if __name__ == '__main__': 
    main()
Run Code Online (Sandbox Code Playgroud)

曲线交叉

2018 年 12 月 13 日更新:如果需要找到几个拦截,这里是代码的修改版本:

from __future__ import division 
import numpy as np
import matplotlib.pyplot as plt

def interpolated_intercepts(x, y1, y2):
    """Find the intercepts of two curves, given by the same x data"""

    def intercept(point1, point2, point3, point4):
        """find the intersection between two lines
        the first line is defined by the line between point1 and point2
        the first line is defined by the line between point3 and point4
        each point is an (x,y) tuple.

        So, for example, you can find the intersection between
        intercept((0,0), (1,1), (0,1), (1,0)) = (0.5, 0.5)

        Returns: the intercept, in (x,y) format
        """    

        def line(p1, p2):
            A = (p1[1] - p2[1])
            B = (p2[0] - p1[0])
            C = (p1[0]*p2[1] - p2[0]*p1[1])
            return A, B, -C

        def intersection(L1, L2):
            D  = L1[0] * L2[1] - L1[1] * L2[0]
            Dx = L1[2] * L2[1] - L1[1] * L2[2]
            Dy = L1[0] * L2[2] - L1[2] * L2[0]

            x = Dx / D
            y = Dy / D
            return x,y

        L1 = line([point1[0],point1[1]], [point2[0],point2[1]])
        L2 = line([point3[0],point3[1]], [point4[0],point4[1]])

        R = intersection(L1, L2)

        return R

    idxs = np.argwhere(np.diff(np.sign(y1 - y2)) != 0)

    xcs = []
    ycs = []

    for idx in idxs:
        xc, yc = intercept((x[idx], y1[idx]),((x[idx+1], y1[idx+1])), ((x[idx], y2[idx])), ((x[idx+1], y2[idx+1])))
        xcs.append(xc)
        ycs.append(yc)
    return np.array(xcs), np.array(ycs)

def main():
    x  = np.linspace(1, 10, 50)
    y1 = np.sin(x)
    y2 = 0.02*x

    plt.plot(x, y1, marker='o', mec='none', ms=4, lw=1, label='y1')
    plt.plot(x, y2, marker='o', mec='none', ms=4, lw=1, label='y2')

    idx = np.argwhere(np.diff(np.sign(y1 - y2)) != 0)

    plt.plot(x[idx], y1[idx], 'ms', ms=7, label='Nearest data-point method')

    # new method!
    xcs, ycs = interpolated_intercepts(x,y1,y2)
    for xc, yc in zip(xcs, ycs):
        plt.plot(xc, yc, 'co', ms=5, label='Nearest data-point, with linear interpolation')


    plt.legend(frameon=False, fontsize=10, numpoints=1, loc='lower left')

    plt.savefig('curve crossing.png', dpi=200)
    plt.show()

if __name__ == '__main__': 
    main()
Run Code Online (Sandbox Code Playgroud)

`` 在此处输入图片说明