小编Dum*_*r21的帖子

为什么这个函数在 JAX 和 numpy 中变慢?

我有以下 numpy 函数,如下所示,我正在尝试使用 JAX 进行优化,但无论出于何种原因,它都变慢了。

有人可以指出我可以做些什么来提高这里的性能吗?我怀疑这与 Cg_new 发生的列表理解有关,但将其分开并不会在 JAX 中产生任何进一步的性能提升。

import numpy as np 

def testFunction_numpy(C, Mi, C_new, Mi_new):
    Wg_new = np.zeros((len(Mi_new[:,0]), len(Mi[0])))
    Cg_new = np.zeros((1, len(Mi[0])))
    invertCsensor_new = np.linalg.inv(C_new)

    Wg_new = np.dot(invertCsensor_new, Mi_new)
    Cg_new = [np.dot(((-0.5*(Mi_new[:,m].conj().T))), (Wg_new[:,m])) for m in range(0, len(Mi[0]))] 

    return C_new, Mi_new, Wg_new, Cg_new

C = np.random.rand(483,483)
Mi = np.random.rand(483,8)
C_new = np.random.rand(198,198)
Mi_new = np.random.rand(198,8)

%timeit testFunction_numpy(C, Mi, C_new, Mi_new)
#1000 loops, best of 3: 1.73 ms per loop
Run Code Online (Sandbox Code Playgroud)

这是 JAX 等效项:

import jax.numpy as …
Run Code Online (Sandbox Code Playgroud)

python optimization performance numpy jax

4
推荐指数
1
解决办法
1243
查看次数

标签 统计

jax ×1

numpy ×1

optimization ×1

performance ×1

python ×1