从numpy数组中的最大值进行掩码,特定轴

fee*_*dMe 6 python numpy

输入示例:

我有一个numpy数组,例如

a=np.array([[0,1], [2, 1], [4, 8]])

期望的输出:

我想生成一个掩码数组,其沿给定轴具有最大值,在我的情况下,轴1为True,其他所有为False.例如在这种情况下

mask = np.array([[False, True], [True, False], [False, True]])

尝试:

我尝试过使用方法,np.amax但这会返回展平列表中的最大值:

>>> np.amax(a, axis=1)
array([1, 2, 8])
Run Code Online (Sandbox Code Playgroud)

并且np.argmax类似地返回沿该轴的最大值的索引.

>>> np.argmax(a, axis=1)
array([1, 0, 1])
Run Code Online (Sandbox Code Playgroud)

我可以通过某种方式迭代这个但是一旦这些数组变得更大,我希望解决方案保持原生于numpy.

Div*_*kar 10

方法#1

使用broadcasting,我们可以使用与最大值的比较,同时保持暗淡以方便broadcasting-

a.max(axis=1,keepdims=1) == a
Run Code Online (Sandbox Code Playgroud)

样品运行 -

In [83]: a
Out[83]: 
array([[0, 1],
       [2, 1],
       [4, 8]])

In [84]: a.max(axis=1,keepdims=1) == a
Out[84]: 
array([[False,  True],
       [ True, False],
       [False,  True]], dtype=bool)
Run Code Online (Sandbox Code Playgroud)

方法#2

另外,argmax对于broadcasted-comparison沿列的指数范围的另一个案例的指数 -

In [92]: a.argmax(axis=1)[:,None] == range(a.shape[1])
Out[92]: 
array([[False,  True],
       [ True, False],
       [False,  True]], dtype=bool)
Run Code Online (Sandbox Code Playgroud)

方法#3

要完成设置,如果我们正在寻找性能,请使用初始化,然后advanced-indexing-

out = np.zeros(a.shape, dtype=bool)
out[np.arange(len(a)), a.argmax(axis=1)] = 1
Run Code Online (Sandbox Code Playgroud)

  • 我最终使用了方法 #2 的变体。当 axis=1 数组包含多个对应于最大值的项目时,它也适用。argmax 仅将第一次出现返回为 True,这是最适合我的问题的行为 (2认同)