如何获得 pytorch 张量的独特元素及其首次出现的索引?

oji*_*son 4 python indexing reduction pytorch

假设 2*X(总是 2 行)pytorch 张量:

A = tensor([[ 1.,  2.,  2.,  3.,  3.,  3.,  4.,  4.,  4.],
            [43., 33., 43., 76., 33., 76., 55., 55., 55.]])
Run Code Online (Sandbox Code Playgroud)

torch.unique(A, dim=1)将返回:

tensor([[ 1.,  2.,  2.,  3.,  3.,  4.],
        [43., 33., 43., 33., 76., 55.]])
Run Code Online (Sandbox Code Playgroud)

但我还需要每个独特元素的索引,它们首先出现在原始输入中。在这种情况下,索引应该是这样的:

tensor([0, 1, 2, 3, 4, 6])

# Explanation
# A = tensor([[ 1.,  2.,  2.,  3.,  3.,  3.,  4.,  4.,  4.],
#             [43., 33., 43., 76., 33., 76., 55., 55., 55.]])
#              (0)  (1)  (2)  (3)  (4)       (6) 
Run Code Online (Sandbox Code Playgroud)

这对我来说很复杂,因为第二行张量A可能没有很好地排序:

A = tensor([[ 1.,  2.,  2.,  3.,  3.,  3.,  4.,  4.,  4.],
            [43., 33., 43., 76., 33., 76., 55., 55., 55.]])
                             ^         ^
Run Code Online (Sandbox Code Playgroud)

有没有一种简单有效的方法来获得所需的指数?

PS 张量的第一行始终按升序排列可能很有用。

小智 5

获得此类指数的一种可能方法:

unique, idx, counts = torch.unique(A, dim=1, sorted=True, return_inverse=True, return_counts=True)
_, ind_sorted = torch.sort(idx, stable=True)
cum_sum = counts.cumsum(0)
cum_sum = torch.cat((torch.tensor([0]), cum_sum[:-1]))
first_indicies = ind_sorted[cum_sum]
Run Code Online (Sandbox Code Playgroud)

对于A上面代码片段中的张量:

print(first_indicies)
# tensor([0, 1, 2, 4, 3, 6])
Run Code Online (Sandbox Code Playgroud)

请注意,unique在这种情况下等于:

 tensor([[ 1.,  2.,  2.,  3.,  3.,  4.],
         [43., 33., 43., 33., 76., 55.]])
Run Code Online (Sandbox Code Playgroud)