use*_*043 4 python pytorch tensor
我有两个这样的张量:
1st tensor
[[0,0],[0,1],[0,2],[1,3],[1,4],[2,1],[2,4]]
2nd tensor
[[0,1],[0,2],[1,4],[2,4]]
Run Code Online (Sandbox Code Playgroud)
我希望结果张量是这样的:
[[0,0],[1,3],[2,1]] # differences between 1st tensor and 2nd tensor
Run Code Online (Sandbox Code Playgroud)
我尝试使用 set、list、torch.where 等,但找不到任何好方法来实现这一点。有没有办法在两个不同大小的张量之间获得不同的行?(需要高效)
您可以执行成对比较以查看第一个张量的哪些元素存在于第二个向量中。
a = torch.as_tensor([[0,0],[0,1],[0,2],[1,3],[1,4],[2,1],[2,4]])
b = torch.as_tensor([[0,1],[0,2],[1,4],[2,4]])
# Expand a to (7, 1, 2) to broadcast to all b
a_exp = a.unsqueeze(1)
# c: (7, 4, 2)
c = a_exp == b
# Since we want to know that all components of the vector are equal, we reduce over the last fim
# c: (7, 4)
c = c.all(-1)
print(c)
# Out: Each row i compares the ith element of a against all elements in b
# Therefore, if all row is false means that the a element is not present in b
tensor([[False, False, False, False],
[ True, False, False, False],
[False, True, False, False],
[False, False, False, False],
[False, False, True, False],
[False, False, False, False],
[False, False, False, True]])
non_repeat_mask = ~c.any(-1)
# Apply the mask to a
print(a[non_repeat_mask])
tensor([[0, 0],
[1, 3],
[2, 1]])
Run Code Online (Sandbox Code Playgroud)
如果你觉得很酷,你可以做一个班轮:)
a[~a.unsqueeze(1).eq(b).all(-1).any(-1)]
Run Code Online (Sandbox Code Playgroud)