np.eye(n)[nparray] 是什么意思?

bho*_*ass 9 numpy

我正在浏览一些代码

y_enc = np.eye(21)[label]
Run Code Online (Sandbox Code Playgroud)

其中 label 是形状 (224, 224) 的 ndarray y_enc 是形状 (224, 224, 21) 的 ndarray

即使打印了形状,我也无法理解这句话。np.eye 应该生成一个维度为 21 x 21 的对角矩阵。 [label] 跟随它是什么意思?

R.A*_*nna 11

文档numpy.eye

返回一个二维数组,对角线上为 1,其他地方为 0。

例子:

>>np.eye(3)
array([[ 1.,  0.,  0.],
   [ 0.,  1.,  0.],
   [ 0.,  0.,  1.]])
>>> np.eye(3)[1]
array([ 0.,  1.,  0.])
Run Code Online (Sandbox Code Playgroud)

[label]是数组元素索引。因此,其中只有一个元素,它以数组形式返回给定的行数元素。

>>> np.eye(3)[1]
array([ 0.,  1.,  0.])
>>> np.eye(3)[2]
array([ 0.,  0.,  1.])
Run Code Online (Sandbox Code Playgroud)

因为它是2d数组,您还可以通过在上提供 2 个索引号来访问特定元素[row_index, column_index]

>>> np.eye(3)[2,1]
0.0
>>> np.eye(3)[2,2]
1.0
>>> np.eye(3)[1,1]
1.0
Run Code Online (Sandbox Code Playgroud)