pal*_*ago 5 python numpy vectorization
我在Python中编写了一些代码很好但速度很慢; 我认为由于for循环.我希望可以使用numpy命令加快以下操作.让我来定义目标.
让我们假设我有一个all_CMs维度为rowx 的2D numpy数组col.例如,考虑一个6x 11数组(见下图).
我想计算的平均对所有行,即总和 ⱼaᵢⱼ导致数组.这当然可以轻松完成.(我叫这个值CM_tilde)
现在,对于每一行,我想计算一些选定值的平均值,即通过计算它们的总和并将其除以所有列的数量(N)来计算低于某个阈值的所有值.如果该值高于此定义的阈值,则添加CM_tilde值(整行的平均值).调用此值CM
然后,CM从行中的每个元素中减去该值
除此之外,我想要一个numpy数组或列表,其中CM列出了所有这些值.
图:
以下代码工作但很慢(特别是如果数组变大)
CM_tilde = np.mean(data, axis=1)
N = data.shape[1]
data_cm = np.zeros(( data.shape[0], data.shape[1], data.shape[2] ))
all_CMs = np.zeros(( data.shape[0], data.shape[2]))
for frame in range(data.shape[2]):
for row in range(data.shape[0]):
CM=0
for col in range(data.shape[1]):
if data[row, col, frame] < (CM_tilde[row, frame]+threshold):
CM += data[row, col, frame]
else:
CM += CM_tilde[row, frame]
CM = CM/N
all_CMs[row, frame] = CM
# calculate CM corrected value
for col in range(data.shape[1]):
data_cm[row, col, frame] = data[row, col, frame] - CM
print "frame: ", frame
return data_cm, all_CMs
Run Code Online (Sandbox Code Playgroud)
有任何想法吗?
And*_*eak 15
向量化你正在做的事情很容易:
import numpy as np
#generate dummy data
nrows=6
ncols=11
nframes=3
threshold=0.3
data=np.random.rand(nrows,ncols,nframes)
CM_tilde = np.mean(data, axis=1)
N = data.shape[1]
all_CMs2 = np.mean(np.where(data < (CM_tilde[:,None,:]+threshold),data,CM_tilde[:,None,:]),axis=1)
data_cm2 = data - all_CMs2[:,None,:]
Run Code Online (Sandbox Code Playgroud)
将其与您的原件进行比较:
In [684]: (data_cm==data_cm2).all()
Out[684]: True
In [685]: (all_CMs==all_CMs2).all()
Out[685]: True
Run Code Online (Sandbox Code Playgroud)
逻辑是我们[nrows,ncols,nframes]同时使用大小的数组.主要技巧是通过将CM_tilde大小[nrows,nframes]转换CM_tilde[:,None,:]为大小来利用python的广播[nrows,1,nframes].然后,Python将为每列使用相同的值,因为这是此修改后的单个维度CM_tilde.
通过使用np.where我们选择(基于threshold)我们是否想要获得相应的值data,或者再次获得广播值CM_tilde.一个新的用途np.mean允许我们计算all_CMs2.
在最后一步中,我们通过直接all_CMs2从相应的元素中减去这个新元素来利用广播data.
通过查看临时变量的隐式索引,这可能有助于以这种方式向量化代码.我的意思是你的临时变量CM存在于循环中[nrows,nframes],并且每次迭代都会重置其值.这意味着实际上CM是一个数量CM[row,frame](后来明确分配给2d数组all_CMs),从这里很容易看出你可以通过CMtmp[row,col,frames]在其列维度上总结一个合适的数量来构造它.如果有帮助,您可以为此目的命名np.where(...)部件CMtmp,然后np.mean(CMtmp,axis=1)从中进行计算.显然,结果相同,但可能更透明.