我有一个1650行和1275列的numpy数组,包含0和255.我想得到行中每个第一个零的索引并将其存储在一个数组中.我用循环来实现这一点.这是示例代码
#new_arr is a numpy array and k is an empty array
for i in range(new_arr.shape[0]):
if not np.all(new_arr[i,:]) == 255:
x = np.where(new_arr[i,:]==0)[0][0]
k.append(x)
else:
k.append(-1)
Run Code Online (Sandbox Code Playgroud)
1650行需要大约1.3秒.有没有其他方法或函数以更快的方式获取索引数组?
一种方法是获取匹配的掩码,==0然后argmax相互argmax(axis=1)匹配,即为每行提供第一个匹配索引 -
(arr==0).argmax(axis=1)
Run Code Online (Sandbox Code Playgroud)
样品运行 -
In [443]: arr
Out[443]:
array([[0, 1, 0, 2, 2, 1, 2, 2],
[1, 1, 2, 2, 2, 1, 0, 1],
[2, 1, 0, 1, 0, 0, 2, 0],
[2, 2, 1, 0, 1, 2, 1, 0]])
In [444]: (arr==0).argmax(axis=1)
Out[444]: array([0, 6, 2, 3])
Run Code Online (Sandbox Code Playgroud)
捕获非零行(如果可以!)
为了方便没有任何零的行,我们需要再做一步工作,一些掩盖 -
In [445]: arr[2] = 9
In [446]: arr
Out[446]:
array([[0, 1, 0, 2, 2, 1, 2, 2],
[1, 1, 2, 2, 2, 1, 0, 1],
[9, 9, 9, 9, 9, 9, 9, 9],
[2, 2, 1, 0, 1, 2, 1, 0]])
In [447]: mask = arr==0
In [448]: np.where(mask.any(1), mask.argmax(1), -1)
Out[448]: array([ 0, 6, -1, 3])
Run Code Online (Sandbox Code Playgroud)