在python列表中,我们可以使用list.index(somevalue)
.pytorch怎么做到这一点?
例如:
a=[1,2,3]
print(a.index(2))
Run Code Online (Sandbox Code Playgroud)
然后,1
将输出.pytorch张量如何在不将其转换为python列表的情况下执行此操作?
Man*_*nas 25
我认为没有直接转换list.index()
为pytorch函数.但是,您可以使用tensor==number
然后使用该nonzero()
功能获得类似的结果.例如:
t = torch.Tensor([1, 2, 3])
print ((t == 2).nonzero())
Run Code Online (Sandbox Code Playgroud)
这段代码返回
1
[torch.LongTensor尺寸1x1]
对于多维张量,您可以执行以下操作:
(tensor == target_value).nonzero(as_tuple=True)
Run Code Online (Sandbox Code Playgroud)
生成的张量的形状为number_of_matches x tensor_dimension
。例如,假设tensor
是一个3 x 4
张量(这意味着维度为 2),结果将是一个二维张量,其中包含行中匹配项的索引。
tensor = torch.Tensor([[1, 2, 2, 7], [3, 1, 2, 4], [3, 1, 9, 4]])
(tensor == 2).nonzero(as_tuple=False)
>>> tensor([[0, 1],
[0, 2],
[1, 2]])
Run Code Online (Sandbox Code Playgroud)
归档时间: |
|
查看次数: |
19799 次 |
最近记录: |