Dou*_*tti 5 python pytorch tensor
例如,我想获取 tensor 中值为 0 和 2 的元素的索引a。这些值(0 和 2)存储在 tensor 中b。我已经设计了一种 pythonic 方法来这样做(如下所示),但我认为列表推导式没有被优化为在 GPU 上运行,或者也许有一种我不知道的更多 PyTorchy 方法来做到这一点。
import torch
a = torch.tensor([0, 1, 0, 1, 1, 0, 2])
b = torch.tensor([0, 2])
torch.tensor([x in b for x in a]).nonzero()
>>>> tensor([[0],
[2],
[5],
[6]])
Run Code Online (Sandbox Code Playgroud)
任何其他建议或者这是一种可以接受的方式?
这是一种更有效的方法(如 jodag 在评论中发布的链接中所建议的...):
(a[..., None] == b).any(-1).nonzero()
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
829 次 |
| 最近记录: |