小编Int*_*ral的帖子

由于大量 Numpy 点调用而最小化开销

我的问题如下,我有一个迭代算法,这样在每次迭代时它需要执行几个矩阵乘法点(A_iB_i),对于 i = 1 ... k。由于这些乘法是用 Numpy 的点执行的,我知道他们正在调用 BLAS-3 实现,这非常快。问题是调用数量巨大,结果证明是我程序中的瓶颈。我想通过制作更少的产品但使用更大的矩阵来最小化所有这些调用的开销。

为简单起见,考虑所有矩阵都是 nxn(通常 n 不大,范围在 1 到 1000 之间)。解决我的问题的一种方法是考虑块对角矩阵 diag( A_i ) 并执行下面的乘积。

diag_blk

这只是对函数 dot 的一次调用,但现在程序浪费了很多时间执行与零的乘法。这个想法似乎不起作用,但它给出了结果 [ A_1 B_1 , ..., A_k B_k ],即所有产品堆叠在一个大矩阵中。

我的问题是,有没有办法通过单个函数调用计算 [ A_1 B_1 , ..., A_k B_k ] ?或者更重要的是,如何比制作 Numpy 点循环更快地计算这些产品?

performance numpy linear-algebra matrix-multiplication

1
推荐指数
1
解决办法
474
查看次数