小编max*_*max的帖子

缩放在numpy中在3D阵列上广播操作的时间

我试图在两个3D阵列上播放">"的简单操作.一个具有另一个维度(m,1,n)(1,m,n).如果我改变第三维(n)的值,我会天真地期望计算的速度将缩放为n.

然而,当我尝试明确地测量它时,我发现当将n从1增加到2时,计算时间增加约10倍,之后缩放是线性的.

当从n = 1到n = 2时,为什么计算时间会急剧增加?我假设它是numpy中的内存管理工件,但我正在寻找更具体的内容.

代码附在下面,附带结果图.

import numpy as np
import time
import matplotlib.pyplot as plt

def compute_time(n):

    x, y = (np.random.uniform(size=(1, 1000, n)), 
            np.random.uniform(size=(1000, 1, n)))

    t = time.time()
    x > y 
    return time.time() - t

a = [
        [
            n, np.asarray([compute_time(n) 
            for _ in range(100)]).mean()
        ]
        for n in range(1, 30, 1)
    ]

a = np.asarray(a)
plt.plot(a[:, 0], a[:, 1])
plt.xlabel('n')
plt.ylabel('time(ms)')
plt.show()
Run Code Online (Sandbox Code Playgroud)

广播操作的时间图

在此输入图像描述

python numpy broadcasting numpy-ufunc

6
推荐指数
2
解决办法
185
查看次数

标签 统计

broadcasting ×1

numpy ×1

numpy-ufunc ×1

python ×1