PyTorch - A better way to get back the original tensor order after torch.sort

Bob*_*bby 8 python pytorch

After operating on a tensor sorted with torch.sort and making some other modifications, I want to get back the original tensor order, so that the tensor is no longer sorted. This is best explained with an example:

x = torch.tensor([30., 40., 20.])
ordered, indices = torch.sort(x)
# ordered is [20., 30., 40.]
# indices is [2, 0, 1]
ordered = torch.tanh(ordered) # it doesn't matter what the operation is
final = original_order(ordered, indices) 
# final must be equal to torch.tanh(x)

I implemented the function this way:

def original_order(ordered, indices):
    z = torch.empty_like(ordered)
    for i in range(ordered.size(0)):
        z[indices[i]] = ordered[i]
    return z
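As a quick sanity check (my own snippet, not part of the question), the loop version does recover torch.tanh(x) in the example above:

```python
import torch

def original_order(ordered, indices):
    # restore original order element by element
    z = torch.empty_like(ordered)
    for i in range(ordered.size(0)):
        z[indices[i]] = ordered[i]
    return z

x = torch.tensor([30., 40., 20.])
ordered, indices = torch.sort(x)
final = original_order(torch.tanh(ordered), indices)
print(torch.allclose(final, torch.tanh(x)))  # True
```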

Is there a better way to do this? In particular, is it possible to avoid the loop and compute the operation more efficiently?

In my case I have a tensor of size torch.Size([B, N]) and I sort each of the B rows separately with a single call to torch.sort. So, I have to call original_order B times in another loop.

Any more PyTorch-idiomatic ideas?

EDIT 1 - Get rid of inner loop

I solved part of the problem by simply indexing z with indices in this way:

def original_order(ordered, indices):
    z = torch.empty_like(ordered)
    z[indices] = ordered
    return z
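A quick self-contained check (my own snippet) that the vectorized indexing restores the original order in the one-dimensional example:

```python
import torch

def original_order(ordered, indices):
    # scatter each element back to its original position in one shot
    z = torch.empty_like(ordered)
    z[indices] = ordered
    return z

x = torch.tensor([30., 40., 20.])
ordered, indices = torch.sort(x)
final = original_order(torch.tanh(ordered), indices)
print(torch.allclose(final, torch.tanh(x)))  # True
```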

Now, I just have to understand how to avoid the outer loop on B dimension.

EDIT 2 - Get rid of outer loop

def original_order(ordered, indices, batch_size):
    # produce a vector to shift indices by the length of the vector
    # times the batch position
    add = torch.linspace(0, batch_size-1, batch_size) * indices.size(1)

    indices = indices + add.long().view(-1, 1)

    # reduce tensor to a single dimension.
    # Now the indices take into account the new length
    long_ordered = ordered.view(-1)
    long_indices = indices.view(-1)

    # we are in the previous case with a one-dimensional vector
    z = torch.zeros_like(long_ordered).float()
    z[long_indices] = long_ordered

    # reshape to get back to the correct dimension
    return z.view(batch_size, -1)
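A self-contained check of the batched version (the test input x below is my own, not from the original post):

```python
import torch

def original_order(ordered, indices, batch_size):
    # shift each row's indices by row_index * row_length so they
    # address the correct slots in the flattened tensor
    add = torch.linspace(0, batch_size - 1, batch_size) * indices.size(1)
    indices = indices + add.long().view(-1, 1)
    long_ordered = ordered.view(-1)
    long_indices = indices.view(-1)
    z = torch.zeros_like(long_ordered).float()
    z[long_indices] = long_ordered
    return z.view(batch_size, -1)

x = torch.tensor([[30., 40., 20.],
                  [ 3.,  1.,  2.]])
ordered, indices = torch.sort(x)
final = original_order(torch.tanh(ordered), indices, x.size(0))
print(torch.allclose(final, torch.tanh(x)))  # True
```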

qmk*_*qmk 5

def original_order(ordered, indices):
    return ordered.gather(1, indices.argsort(1))

Example

original = torch.tensor([
    [20, 22, 24, 21],
    [12, 14, 10, 11],
    [34, 31, 30, 32]])
sorted, index = original.sort()
unsorted = sorted.gather(1, index.argsort(1))
assert(torch.all(original == unsorted))
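Applied to the question's original use case (my own snippet, using a 2-D tensor since gather here works along dim 1), this one-liner replaces both loops:

```python
import torch

x = torch.tensor([[30., 40., 20.]])
ordered, indices = torch.sort(x)
ordered = torch.tanh(ordered)                 # any elementwise op
final = ordered.gather(1, indices.argsort(1)) # unsort each row
print(torch.allclose(final, torch.tanh(x)))  # True
```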

Why it works

For simplicity, imagine t = [30, 10, 20], omitting tensor notation.

t.sort() gives us the sorted tensor s = [10, 20, 30], along with the sorting index i = [1, 2, 0] for free. i is effectively the output of t.argsort().

i tells us how to go from t to s: "to sort t into s, take element 1, then 2, then 0 from t". Argsorting i gives us another sorting index j = [2, 0, 1], which tells us how to go from i to the canonical sequence of natural numbers [0, 1, 2], effectively undoing the sort. Another way to see it is that j tells us how to go from s back to t: "to unsort s into t, take element 2, then 0, then 1 from s". Argsorting a sorting index gives us its "unsorting" index, and vice versa.

Now that we have the inverse index, we feed it to torch.gather() with the correct dim, and that unsorts the tensor.
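The argument above can be reproduced verbatim in a few lines (my own snippet, using the same t, s, i, j names as the explanation):

```python
import torch

t = torch.tensor([30, 10, 20])
s, i = t.sort()       # s = [10, 20, 30], i = [1, 2, 0]
j = i.argsort()       # j = [2, 0, 1], the "unsorting" index
print(s[j].tolist())  # [30, 10, 20] -- back to t
print(i[j].tolist())  # [0, 1, 2]    -- the canonical sequence
```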

Sources

torch.gather torch.argsort

I couldn't find this exact solution while researching this problem, so I believe this is an original answer.