我有两个不同长度的二维张量,两者都是同一原始二维张量的不同子集,我想找到所有匹配的“行”,
例如
A = [[1,2,3],[4,5,6],[7,8,9],[3,3,3]
B = [[1,2,3],[7,8,9],[4,4,4]]
torch.2dintersect(A,B) -> [0,2] (the indecies of A that B also have)
Run Code Online (Sandbox Code Playgroud)
我只看到 numpy 解决方案,使用 dtype 作为字典,并且不适用于 pytorch。
这是我在 numpy 中的做法
arr1 = edge_index_dense.numpy().view(np.int32)
arr2 = edge_index2_dense.numpy().view(np.int32)
arr1_view = arr1.view([('', arr1.dtype)] * arr1.shape[1])
arr2_view = arr2.view([('', arr2.dtype)] * arr2.shape[1])
intersected = np.intersect1d(arr1_view, arr2_view, return_indices=True)
Run Code Online (Sandbox Code Playgroud)
这个答案是在OP用其他限制更新问题之前发布的,这些限制大大改变了问题。
TL;DR你可以这样做:
torch.where((A == B).all(dim=1))[0]
Run Code Online (Sandbox Code Playgroud)
首先,假设您有:
import torch
A = torch.Tensor([[1,2,3],[4,5,6],[7,8,9]])
B = torch.Tensor([[1,2,3],[4,4,4],[7,8,9]])
Run Code Online (Sandbox Code Playgroud)
我们可以检查A == B返回:
>>> A == B
tensor([[ True, True, True],
[ True, False, False],
[ True, True, True]])
Run Code Online (Sandbox Code Playgroud)
所以,我们想要的是:它们所在的行True。为此,我们可以使用该.all()操作并指定感兴趣的维度,在我们的例子中1:
>>> (A == B).all(dim=1)
tensor([ True, False, True])
Run Code Online (Sandbox Code Playgroud)
您真正想知道的是Trues 在哪里。为此,我们可以获得torch.where()函数的第一个输出:
>>> torch.where((A == B).all(dim=1))[0]
tensor([0, 2])
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
5266 次 |
| 最近记录: |