小编BPr*_*ent的帖子

计算 ~1m Hermitian 矩阵的谱范数:“numpy.linalg.norm”太慢

我想计算N个 8x8 Hermitian 矩阵的谱范数,其中N接近 1E6。以这 100 万个随机复数 8x8 矩阵为例:

import numpy as np

array = np.random.rand(8,8,1e6)  + 1j*np.random.rand(8,8,1e6)
Run Code Online (Sandbox Code Playgroud)

目前我使用以下命令需要花费近 10 秒的时间numpy.linalg.norm

np.linalg.norm(array, ord=2, axis=(0,1))
Run Code Online (Sandbox Code Playgroud)

我尝试使用下面的 Cython 代码,但这只给我带来了可以忽略不计的性能改进:

import numpy as np
cimport numpy as np
cimport cython

np.import_array()

DTYPE = np.complex64

@cython.boundscheck(False)
@cython.wraparound(False)
def function(np.ndarray[np.complex64_t, ndim=3] Array):
    assert Array.dtype == DTYPE
    cdef int shape0 = Array.shape[2]
    cdef np.ndarray[np.float32_t, ndim=1] normarray = np.zeros(shape0, dtype=np.float32)
    normarray = np.linalg.norm(Array, ord=2, axis=(0, 1))
    return normarray
Run Code Online (Sandbox Code Playgroud)

我还尝试了 numba 和其他一些 …

python numpy linear-algebra cython numba

5
推荐指数
1
解决办法
3204
查看次数

标签 统计

cython ×1

linear-algebra ×1

numba ×1

numpy ×1

python ×1