如何从numpy数组的每一行只获取第一个True值?

Chr*_*per 4 python arrays numpy

我有一个4x3布尔numpy数组,我试图返回一个大小相同的数组,除了原始每一行上第一个True值的位置外,它都是False.所以如果我有一个起始数组

all_bools = np.array([[False, True, True],[True, True, True],[False, False, True],[False,False,False]])
all_bools
array([[False,  True,  True], # First true value = index 1
       [ True,  True,  True], # First true value = index 0
       [False, False,  True], # First true value = index 2
       [False, False, False]]) # No True Values
Run Code Online (Sandbox Code Playgroud)

然后我想回来

[[False, True, False],
 [True, False, False],
 [False, False, True],
 [False, False, False]]
Run Code Online (Sandbox Code Playgroud)

所以前三行的索引1,0和2已经设置为True而没有别的.基本上,原始方式中的任何True值(超出每行的第一个值)都设置为False.

我一直在用np.where和np.argmax来摆弄这个问题,我还没有找到一个好的解决方案 - 任何帮助都会感激不尽.这需要运行很多次,所以我想避免迭代.

cs9*_*s95 6

您可以使用cumsum,并通过将结果与1进行比较来查找第一个bool.

all_bools.cumsum(axis=1).cumsum(axis=1) == 1 
array([[False,  True, False],
       [ True, False, False],
       [False, False,  True],
       [False, False, False]])
Run Code Online (Sandbox Code Playgroud)

这也解释了@a_guest指出的问题.cumsum需要进行第二次调用以避免匹配False第一个和第二个值之间的所有值True.


如果性能很重要,请使用argmax和设置值:

y = np.zeros_like(all_bools, dtype=bool)
idx = np.arange(len(x)), x.argmax(axis=1)
y[idx] = x[idx]

y
array([[False,  True, False],
       [ True, False, False],
       [False, False,  True],
       [False, False, False]])
Run Code Online (Sandbox Code Playgroud)

Perfplot Performance Timings
我将借此机会展示perfplot一些时间,因为很高兴看到我们的解决方案如何随着不同大小的输入而变化.

import numpy as np
import perfplot

def cs1(x):
    return  x.cumsum(axis=1).cumsum(axis=1) == 1 

def cs2(x):
    y = np.zeros_like(x, dtype=bool)
    idx = np.arange(len(x)), x.argmax(axis=1)
    y[idx] = x[idx]
    return y

def a_guest(x):
    b = np.zeros_like(x, dtype=bool)
    i = np.argmax(x, axis=1)
    b[np.arange(i.size), i] = np.logical_or.reduce(x, axis=1)
    return b

perfplot.show(
    setup=lambda n: np.random.randint(0, 2, size=(n, n)).astype(bool),
    kernels=[cs1, cs2, a_guest],
    labels=['cs1', 'cs2', 'a_guest'],
    n_range=[2**k for k in range(1, 8)],
    xlabel='N'
)
Run Code Online (Sandbox Code Playgroud)

在此输入图像描述

趋势cumsum延续到更大的N. 是非常昂贵的,而我的第二个解决方案和@ a_guest之间存在恒定的时间差.