numpy: How to use the argmax result to get the actual max values?

Max*_*axB 5 python numpy

Suppose I have a 3D array:

>>> a
array([[[7, 0],
        [3, 6]],

       [[2, 4],
        [5, 1]]])

I can get its argmax along axis=1 using

>>> m = np.argmax(a, axis=1)
>>> m
array([[0, 1],
       [1, 0]])

How can I use m as an index into a, so that the result is equivalent to simply using max?

>>> a.max(axis=1)
array([[7, 6],
       [5, 4]])

(This is for the case where m is applied to other arrays of the same shape.)

Psi*_*dom 5

You can do this with advanced indexing and numpy broadcasting:

m = np.argmax(a, axis=1)
a[np.arange(a.shape[0])[:,None], m, np.arange(a.shape[2])]

#array([[7, 6],
#       [5, 4]])

Create index arrays for the 1st, 2nd and 3rd dimensions:

ind1, ind2, ind3 = np.arange(a.shape[0])[:,None], m, np.arange(a.shape[2])

Because their shapes differ ((2, 1), (2, 2) and (2,)), the three arrays are broadcast to the common shape (2, 2), so each one effectively becomes:

for x in np.broadcast_arrays(ind1, ind2, ind3):
    print(x, '\n')

#[[0 0]
# [1 1]] 

#[[0 1]
# [1 0]] 

#[[0 1]
# [0 1]] 
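
To make the broadcasting step easier to check, here is a small sketch that prints the shapes involved and then reproduces the same selection with explicit loops. The names common_shape and looped are only illustrative and are not part of the answer above.

```python
import numpy as np

a = np.array([[[7, 0], [3, 6]],
              [[2, 4], [5, 1]]])
m = np.argmax(a, axis=1)

ind1 = np.arange(a.shape[0])[:, None]   # shape (2, 1): position along axis 0
ind2 = m                                # shape (2, 2): argmax position along axis 1
ind3 = np.arange(a.shape[2])            # shape (2,):   position along axis 2

# np.broadcast reports the common shape without building the expanded arrays.
common_shape = np.broadcast(ind1, ind2, ind3).shape
print(ind1.shape, ind2.shape, ind3.shape, "->", common_shape)

# The same selection written as loops: for every (i, k) pair,
# take the element whose axis-1 position is m[i, k].
looped = np.empty(common_shape, dtype=a.dtype)
for i in range(a.shape[0]):
    for k in range(a.shape[2]):
        looped[i, k] = a[i, m[i, k], k]

print(looped)
```

The loop version is slower but shows what the three broadcast index arrays encode: one (axis 0, axis 1, axis 2) triple per output position.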

Since all three indices are integer arrays, this triggers advanced indexing: each position of the result is looked up using one value from each broadcast array, combined into a single index. The elements at (0, 0, 0), (0, 1, 1), (1, 1, 0) and (1, 0, 1) are therefore picked.
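
As a possible alternative, assuming NumPy 1.15 or later, np.take_along_axis can do the same lookup without building the index arrays by hand. The sketch below also reuses m on a second array of the same shape, as asked in the question; b here is a made-up example array, not something from the original post.

```python
import numpy as np

a = np.array([[[7, 0], [3, 6]],
              [[2, 4], [5, 1]]])
b = a * 10  # hypothetical second array with the same shape as a

m = np.argmax(a, axis=1)

# take_along_axis expects the index array to have the same number of
# dimensions as the source, so put the reduced axis back first.
m_expanded = np.expand_dims(m, axis=1)   # shape (2, 1, 2)

# Select along axis 1, then drop the length-1 axis again.
max_of_a = np.take_along_axis(a, m_expanded, axis=1).squeeze(axis=1)
picked_from_b = np.take_along_axis(b, m_expanded, axis=1).squeeze(axis=1)

print(max_of_a)
print(picked_from_b)
```

This form works for any choice of axis by changing the axis argument in all three calls, whereas the advanced-indexing version needs one np.arange per remaining dimension.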