oji*_*son 4 python indexing reduction pytorch
假设 2*X(总是 2 行)pytorch 张量:
A = tensor([[ 1., 2., 2., 3., 3., 3., 4., 4., 4.],
[43., 33., 43., 76., 33., 76., 55., 55., 55.]])
Run Code Online (Sandbox Code Playgroud)
torch.unique(A, dim=1)将返回:
tensor([[ 1., 2., 2., 3., 3., 4.],
[43., 33., 43., 33., 76., 55.]])
Run Code Online (Sandbox Code Playgroud)
但我还需要每个独特元素的索引,它们首先出现在原始输入中。在这种情况下,索引应该是这样的:
tensor([0, 1, 2, 3, 4, 6])
# Explanation
# A = tensor([[ 1., 2., 2., 3., 3., 3., 4., 4., 4.],
# [43., 33., 43., 76., 33., 76., 55., 55., 55.]])
# (0) (1) (2) (3) (4) (6)
Run Code Online (Sandbox Code Playgroud)
这对我来说很复杂,因为第二行张量A可能没有很好地排序:
A = tensor([[ 1., 2., 2., 3., 3., 3., 4., 4., 4.],
[43., 33., 43., 76., 33., 76., 55., 55., 55.]])
^ ^
Run Code Online (Sandbox Code Playgroud)
有没有一种简单有效的方法来获得所需的指数?
PS 张量的第一行始终按升序排列可能很有用。
小智 5
获得此类指数的一种可能方法:
unique, idx, counts = torch.unique(A, dim=1, sorted=True, return_inverse=True, return_counts=True)
_, ind_sorted = torch.sort(idx, stable=True)
cum_sum = counts.cumsum(0)
cum_sum = torch.cat((torch.tensor([0]), cum_sum[:-1]))
first_indicies = ind_sorted[cum_sum]
Run Code Online (Sandbox Code Playgroud)
对于A上面代码片段中的张量:
print(first_indicies)
# tensor([0, 1, 2, 4, 3, 6])
Run Code Online (Sandbox Code Playgroud)
请注意,unique在这种情况下等于:
tensor([[ 1., 2., 2., 3., 3., 4.],
[43., 33., 43., 33., 76., 55.]])
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
2392 次 |
| 最近记录: |