对于给定条件,获取 2D 张量 A 中的值的索引,使用这些索引来索引 3D 张量 B

Bra*_*roy 1 python multidimensional-array pytorch tensor tensor-indexing

对于给定的二维张量,我想检索值为 的所有索引1。我期望能够简单地使用torch.nonzero(a == 1).squeeze(),这将返回tensor([1, 3, 2])。然而,相反,torch.nonzero(a == 1)返回一个 2D 张量(没关系),每行有两个值(这不是我所期望的)。然后,应使用返回的索引来索引 3D 张量的第二个维度(索引 1),再次返回 2D 张量。

import torch

a = torch.Tensor([[12, 1, 0, 0],
                  [4, 9, 21, 1],
                  [10, 2, 1, 0]])

b = torch.rand(3, 4, 8)

print('a_size', a.size())
# a_size torch.Size([3, 4])
print('b_size', b.size())
# b_size torch.Size([3, 4, 8])

idxs = torch.nonzero(a == 1)
print('idxs_size', idxs.size())
# idxs_size torch.Size([3, 2])

print(b.gather(1, idxs))
Run Code Online (Sandbox Code Playgroud)

显然,这不起作用,导致运行时错误:

RuntimeError:无效参数 4:索引张量必须与 C:\w\1\s\windows\pytorch\aten\src\TH/generic/THTensorEvenMoreMath.cpp:453 处的输入张量具有相同的维度

看来idxs不是我想象的那样,也不能按照我想的方式使用。idxs

tensor([[0, 1],
        [1, 3],
        [2, 2]])
Run Code Online (Sandbox Code Playgroud)

但通读文档后,我不明白为什么我还要返回结果张量中的行索引。现在,我知道我可以通过切片获得正确的 idx idxs[:, 1],但我仍然无法使用这些值作为 3D 张量的索引,因为会引发与之前相同的错误。是否可以使用索引的一维张量来选择给定维度上的项目?

kma*_*o23 5

您可以简单地对它们进行切片并将其作为索引传递,如下所示:

In [193]: idxs = torch.nonzero(a == 1)     
In [194]: c = b[idxs[:, 0], idxs[:, 1]]  

In [195]: c   
Out[195]: 
tensor([[0.3411, 0.3944, 0.8108, 0.3986, 0.3917, 0.1176, 0.6252, 0.4885],
        [0.5698, 0.3140, 0.6525, 0.7724, 0.3751, 0.3376, 0.5425, 0.1062],
        [0.7780, 0.4572, 0.5645, 0.5759, 0.5957, 0.2750, 0.6429, 0.1029]])
Run Code Online (Sandbox Code Playgroud)

或者,一种更简单且我更喜欢的方法是仅使用张量torch.where(),然后直接索引到张量中,b如下所示:

In [196]: b[torch.where(a == 1)]  
Out[196]: 
tensor([[0.3411, 0.3944, 0.8108, 0.3986, 0.3917, 0.1176, 0.6252, 0.4885],
        [0.5698, 0.3140, 0.6525, 0.7724, 0.3751, 0.3376, 0.5425, 0.1062],
        [0.7780, 0.4572, 0.5645, 0.5759, 0.5957, 0.2750, 0.6429, 0.1029]])
Run Code Online (Sandbox Code Playgroud)

关于上述使用方法的更多解释torch.where():它基于高级索引的概念。也就是说,当我们使用序列对象的元组(例如张量的元组、列表的元组、元组的元组等)对张量进行索引时。

# some input tensor
In [207]: a  
Out[207]: 
tensor([[12.,  1.,  0.,  0.],
        [ 4.,  9., 21.,  1.],
        [10.,  2.,  1.,  0.]])
Run Code Online (Sandbox Code Playgroud)

对于基本切片,我们需要一个整数索引元组:

   In [212]: a[(1, 2)] 
   Out[212]: tensor(21.)
Run Code Online (Sandbox Code Playgroud)

为了使用高级索引实现相同的目的,我们需要一个序列对象的元组:

# adv. indexing using a tuple of lists
In [213]: a[([1,], [2,])] 
Out[213]: tensor([21.])

# adv. indexing using a tuple of tuples
In [215]: a[((1,), (2,))]  
Out[215]: tensor([21.])

# adv. indexing using a tuple of tensors
In [214]: a[(torch.tensor([1,]), torch.tensor([2,]))] 
Out[214]: tensor([21.])
Run Code Online (Sandbox Code Playgroud)

并且返回张量的维度始终比输入张量的维度小一维。