我有以下代码:
a = torch.randint(0,10,[3,3,3,3])
b = torch.LongTensor([1,1,1,1])
Run Code Online (Sandbox Code Playgroud)
我有一个多维索引,b
并希望用它来选择一个单元格a
.如果b不是张量,我可以这样做:
a[1,1,1,1]
Run Code Online (Sandbox Code Playgroud)
哪个返回正确的单元格,但是:
a[b]
Run Code Online (Sandbox Code Playgroud)
不起作用,因为它只选择了a[1]
四次.
我怎样才能做到这一点?谢谢
您可以使用分割b
成4个chunk
,然后使用分块b
索引所需的特定元素:
>> a = torch.arange(3*3*3*3).view(3,3,3,3)
>> b = torch.LongTensor([[1,1,1,1], [2,2,2,2], [0, 0, 0, 0]]).t()
>> a[b.chunk(chunks=4, dim=0)] # here's the trick!
Out[24]: tensor([[40, 80, 0]])
Run Code Online (Sandbox Code Playgroud)
这样做的好处是可以轻松地将其推广到的任何尺寸a
,您只需要使卡盘的数量等于的尺寸即可a
。
一个更优雅(更简单)的解决方案可能是将其简单地转换b
为元组:
a[tuple(b)]
Out[10]: tensor(5.)
Run Code Online (Sandbox Code Playgroud)
我很想知道它如何与“常规” numpy一起工作,因此在这里找到了一篇相关的文章对此进行了很好的解释。
归档时间: |
|
查看次数: |
5771 次 |
最近记录: |