Oka*_*ano 1 python numpy julia
我有一个简单的代码:
读取轨迹文件,该文件可以看作存储在 Y 中的二维数组列表(空间中的位置列表)
然后我想计算每对(scipy.pdist 样式)的 RMSD
我的代码工作正常:
trajectory = read("test.lammpstrj", index="::")
m = len(trajectory)
#.get_positions() return a 2d numpy array
Y = np.array([snapshot.get_positions() for snapshot in trajectory])
b = [np.sqrt(((((Y[i]- Y[j])**2))*3).mean()) for i in range(m) for j in range(i + 1, m)]
Run Code Online (Sandbox Code Playgroud)
此代码使用 python3.10 在 0.86 秒内执行,使用 Julia1.8 同类代码在 0.46 秒内执行
我计划拥有更大的轨迹(~ 200,000 个元素),是否可以使用 python 获得加速,还是应该坚持使用 Julia?
您已经提到snapshot.get_positions()返回一些二维数组,假设为 shape (p, q)。所以我期望这Y是一个具有某种形状的 3D 数组(m, p, q),其中m是轨迹中快照的数量。您还期望m规模相当高。
让我们看看在设置上加快距离计算的基本方法m=1000:
import numpy as np\n\n# dummy inputs\nm = 1000\np, q = 4, 5\nY = np.random.randn(m, p, q)\n\n# your current method\ndef foo():\n return [np.sqrt(((((Y[i]- Y[j])**2))*3).mean()) for i in range(m) for j in range(i + 1, m)]\n\n# vectorized approach -> compute the upper triangle of the pairwise distance matrix\ndef bar():\n u, v = np.triu_indices(Y.shape[0], 1)\n return np.sqrt((3 * (Y[u] - Y[v]) ** 2).mean(axis=(-1, -2)))\n\n# Check for correctness\n\nout_1 = foo()\nout_2 = bar()\nprint(np.allclose(out_1, out_2))\n# True\nRun Code Online (Sandbox Code Playgroud)\n如果我们测试所需的时间:
\n%timeit -n 10 -r 3 foo()\n# 3.16 s \xc2\xb1 50.3 ms per loop (mean \xc2\xb1 std. dev. of 3 runs, 10 loops each)\nRun Code Online (Sandbox Code Playgroud)\n第一种方法真的很慢,这个计算需要3秒多。我们来检查一下第二种方法:
\n%timeit -n 10 -r 3 bar()\n# 97.5 ms \xc2\xb1 405 \xc2\xb5s per loop (mean \xc2\xb1 std. dev. of 3 runs, 10 loops each)\nRun Code Online (Sandbox Code Playgroud)\n所以我们这里有大约 30 倍的加速,这将使你在 python 中的大型计算比使用原始代码更可行。请随意测试其他尺寸,Y看看它与原始尺寸相比如何缩放。
另外,你还可以尝试JIT,主要是jax或numba。bar使用移植该函数相当简单jax.numpy,例如:
import jax\nimport jax.numpy as jnp\n\n@jax.jit\ndef jit_bar(Y):\n u, v = jnp.triu_indices(Y.shape[0], 1)\n return jnp.sqrt((3 * (Y[u] - Y[v]) ** 2).mean(axis=(-1, -2)))\n\n# check for correctness\n\nprint(np.allclose(bar(), jit_bar(Y)))\n# True\nRun Code Online (Sandbox Code Playgroud)\n如果我们测试 jitted jnp 操作的时间:
\n%timeit -n 10 -r 3 jit_bar(Y)\n# 10.6 ms \xc2\xb1 678 \xc2\xb5s per loop (mean \xc2\xb1 std. dev. of 3 runs, 10 loops each)\nRun Code Online (Sandbox Code Playgroud)\n因此,与原始版本相比,我们甚至可以达到约 300 倍的速度。
\n请注意,并非每个操作都可以如此轻松地转换为 jax/jit(这个特定问题很方便),因此一般建议是简单地避免 python 循环并使用 numpy\ 的广播/矢量化功能,例如bar().