在PyTorch中使用张量索引多维张量

Chu*_*ows 8 pytorch tensor

我有以下代码:

a = torch.randint(0,10,[3,3,3,3])
b = torch.LongTensor([1,1,1,1])
Run Code Online (Sandbox Code Playgroud)

我有一个多维索引,b并希望用它来选择一个单元格a.如果b不是张量,我可以这样做:

a[1,1,1,1]
Run Code Online (Sandbox Code Playgroud)

哪个返回正确的单元格,但是:

a[b]
Run Code Online (Sandbox Code Playgroud)

不起作用,因为它只选择了a[1]四次.

我怎样才能做到这一点?谢谢

Sha*_*hai 6

您可以使用分割b成4个chunk,然后使用分块b索引所需的特定元素:

>> a = torch.arange(3*3*3*3).view(3,3,3,3)
>> b = torch.LongTensor([[1,1,1,1], [2,2,2,2], [0, 0, 0, 0]]).t()
>> a[b.chunk(chunks=4, dim=0)]   # here's the trick!
Out[24]: tensor([[40, 80,  0]])
Run Code Online (Sandbox Code Playgroud)

这样做的好处是可以轻松地将其推广到的任何尺寸a,您只需要使卡盘的数量等于的尺寸即可a


den*_*ger 5

一个更优雅(更简单)的解决方案可能是将其简单地转换b为元组:

a[tuple(b)]
Out[10]: tensor(5.)
Run Code Online (Sandbox Code Playgroud)

我很想知道它如何与“常规” numpy一起工作,因此在这里找到了一篇相关的文章对此进行了很好的解释。

  • 有什么办法可以使此解决方案与索引列表一起使用? (3认同)