PyTorch 张量索引或条件选择?

ech*_*Lee 2 python pytorch

    for c in range(self.n_class):
        target[c][label == c] = 1
Run Code Online (Sandbox Code Playgroud)

self.n_class 是 32。目标是 32 x 1024 x 2048 张量。

我知道 target[c] 选择 1 x 1024 x 2048 中的每一个。但我不明白 [label == c]。

因为根据经验,正方形 [] 中应包含整数。

有人能解释一下第二个方块的作用以及它的意义吗?

nai*_*rbv 7

PyTorch 支持“高级索引”。它实现了接受运算符张量参数的能力[]

运算符的结果==是一个布尔掩码。操作[]员使用该掩码来选择元素。下面的这个例子可能有助于澄清:

>>> x=torch.arange(0,10)
>>> x
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
>>> x < 5
tensor([ True,  True,  True,  True,  True, False, False, False, False, False])
>>> x[x < 5]
tensor([0, 1, 2, 3, 4])
>>> x[x > 5]
tensor([6, 7, 8, 9])
>>>
Run Code Online (Sandbox Code Playgroud)

一些一般文档: https://www.pythonlikeyoumeanit.com/Module3_IntroducingNumpy/BasicIndexing.html

numpy 中的高级索引: https://numpy.org/doc/1.18/reference/arrays.indexing.html