当输入是许多相同的数组时,使 np.einsum 更快?(或任何其他更快的方法)

sem*_*mmo 6 python numpy numpy-einsum

我有一段代码类型:

nnt = np.real(np.einsum('xa,xb,yc,yd,abcde->exy',evec,evec,evec,evec,quartic))
Run Code Online (Sandbox Code Playgroud)

其中evec是(比如说)一个 L x Lnp.float32阵列,并且quartic是一个 L x L x L x L x Tnp.complex64阵列。

我发现这个例程相当慢。

我认为由于所有evec的都是相同的,可能有更快的方法吗?

提前致谢。

Ehs*_*san 3

首先,您可以重复使用第一个计算:

evec2 = np.real(np.einsum('xa,xb->xab',evec,evec))
nnt = np.real(np.einsum('xab,ycd,abcde->exy',evec2,evec2,quartic))
Run Code Online (Sandbox Code Playgroud)

如果您不关心内存而只需要性能:

evec2 = np.real(np.einsum('xa,xb->xab',evec,evec))
nnt = np.real(np.einsum('xab,ycd,abcde->exy',evec2,evec2,quartic,optimize=True))
Run Code Online (Sandbox Code Playgroud)