Pra*_*ani 4 python arrays numpy linear-algebra jax
这是我的问题。我有两个矩阵A和,分别具有维度和B的复杂条目。(n,n,m,m)(n,n)
以下是我为获取矩阵而执行的操作C-
C = np.sum(B[:,:,None,None]*A, axis=(0,1))
Run Code Online (Sandbox Code Playgroud)
上述计算一次大约需要6-8秒。由于我必须计算很多这样的Cs,因此需要花费很多时间。有没有更快的方法来做到这一点?(我在多核 CPU 上使用 JAX NumPy 执行这些操作;普通 NumPy 需要更长的时间)
n=77并且m=512,如果您想知道的话。当我在集群上工作时,我可以并行化,但是数组的巨大大小会消耗大量内存。
看起来你想要einsum:
C = np.einsum('ijkl,ij->kl', A, B)
Run Code Online (Sandbox Code Playgroud)
在 Colab CPU 上使用 numpy 我得到这个:
import numpy as np
x = np.random.rand(50, 50, 500, 500)
y = np.random.rand(50, 50)
def f1(x, y):
return np.sum(y[:,:,None,None]*x, axis=(0,1))
def f2(x, y):
return np.einsum('ijkl,ij->kl', x, y)
np.testing.assert_allclose(f1(x, y), f2(x, y))
%timeit f1(x, y)
# 1 loop, best of 5: 1.52 s per loop
%timeit f2(x, y)
# 1 loop, best of 5: 620 ms per loop
Run Code Online (Sandbox Code Playgroud)