获取张量 a 中存在于张量 b 中的元素的索引

Dou*_*tti 5 python pytorch tensor

例如,我想获取 tensor 中值为 0 和 2 的元素的索引a。这些值(0 和 2)存储在 tensor 中b。我已经设计了一种 pythonic 方法来这样做(如下所示),但我认为列表推导式没有被优化为在 GPU 上运行,或者也许有一种我不知道的更多 PyTorchy 方法来做到这一点。

import torch
a = torch.tensor([0, 1, 0, 1, 1, 0, 2])
b = torch.tensor([0, 2])
torch.tensor([x in b for x in a]).nonzero()

>>>> tensor([[0],
             [2],
             [5],
             [6]])
Run Code Online (Sandbox Code Playgroud)

任何其他建议或者这是一种可以接受的方式?

And*_*dyK 9

这是一种更有效的方法(如 jodag 在评论中发布的链接中所建议的...):

(a[..., None] == b).any(-1).nonzero()
Run Code Online (Sandbox Code Playgroud)