Numpy:将每行的最大值更改为1,将所有其他数字更改为0

Mik*_*and 21 python numpy

我正在尝试实现一个numpy函数,它将2D数组的每一行中的max替换为1,并将所有其他数字替换为0:

>>> a = np.array([[0, 1],
...               [2, 3],
...               [4, 5],
...               [6, 7],
...               [9, 8]])
>>> b = some_function(a)
>>> b
[[0. 1.]
 [0. 1.]
 [0. 1.]
 [0. 1.]
 [1. 0.]]
Run Code Online (Sandbox Code Playgroud)

到目前为止我尝试过的

def some_function(x):
    a = np.zeros(x.shape)
    a[:,np.argmax(x, axis=1)] = 1
    return a

>>> b = some_function(a)
>>> b
[[1. 1.]
 [1. 1.]
 [1. 1.]
 [1. 1.]
 [1. 1.]]
Run Code Online (Sandbox Code Playgroud)

DSM*_*DSM 22

方法#1,调整你的:

>>> a = np.array([[0, 1], [2, 3], [4, 5], [6, 7], [9, 8]])
>>> b = np.zeros_like(a)
>>> b[np.arange(len(a)), a.argmax(1)] = 1
>>> b
array([[0, 1],
       [0, 1],
       [0, 1],
       [0, 1],
       [1, 0]])
Run Code Online (Sandbox Code Playgroud)

[实际上,range工作得很好; 我写出arange了习惯.]

方法#2,使用max而不是argmax处理多个元素达到最大值的情况:

>>> a = np.array([[0, 1], [2, 2], [4, 3]])
>>> (a == a.max(axis=1)[:,None]).astype(int)
array([[0, 1],
       [1, 1],
       [1, 0]])
Run Code Online (Sandbox Code Playgroud)

  • 非常感谢.关于方法#1的快速问题:为什么行上的```slice语法与提供带有行索引的数组本身不一样? (2认同)

Cyc*_*one 6

我更喜欢像这样使用 numpy.where:

a[np.where(a==np.max(a))] = 1
Run Code Online (Sandbox Code Playgroud)


Ale*_*ina 5

a==np.max(a)将来会引发错误,因此这里有一个调整后的版本,将继续正确广播。

我知道这个问题很古老,但我认为我有一个不错的解决方案,与其他解决方案有点不同。

# get max by row and convert from (n, ) -> (n, 1) which will broadcast
row_maxes = a.max(axis=1).reshape(-1, 1)
np.where(a == row_maxes, 1, 0)
np.where(a == row_maxes).astype(int)
Run Code Online (Sandbox Code Playgroud)

如果需要进行更新,您可以这样做

a[:] = np.where(a == row_maxes, 1, 0)
Run Code Online (Sandbox Code Playgroud)

  • 您可以执行“a.max(axis=1, keepdim=True)”,而不是“a.max(axis=1).reshape(-1, 1)”。 (4认同)
  • @a_guest 那里有错字。这是“keepdims” (2认同)