After sorting a tensor with torch.sort and applying some operations and other modifications to the sorted tensor, I want to recover the original tensor order, so that the tensor is no longer sorted. This is best explained with an example:
x = torch.tensor([30., 40., 20.])
ordered, indices = torch.sort(x)
# ordered is [20., 30., 40.]
# indices is [2, 0, 1]
ordered = torch.tanh(ordered) # it doesn't matter what the operation is
final = original_order(ordered, indices)
# final must be equal to torch.tanh(x)
I implemented the function this way:
def original_order(ordered, indices):
    z = torch.empty_like(ordered)
    for i in range(ordered.size(0)):
        z[indices[i]] = ordered[i]
    return z
Is there a better way to do this? In particular, is it possible to avoid the loop and compute the operation more efficiently?
In my case I have a tensor of size torch.Size([B, N]) and I sort each of the B rows with a single call to torch.sort. So I have to call original_order B times in another loop, as sketched below.
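For concreteness, a minimal sketch of that setup, reusing the loop-based original_order above (B, N, and the tanh call are arbitrary placeholders):

import torch

B, N = 4, 5
x = torch.randn(B, N)
ordered, indices = torch.sort(x, dim=1)  # one call sorts every row
ordered = torch.tanh(ordered)            # any elementwise op

# undo the sort row by row with the helper above
final = torch.empty_like(ordered)
for b in range(B):
    final[b] = original_order(ordered[b], indices[b])

assert torch.allclose(final, torch.tanh(x))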
Any more PyTorch-ic ideas?
EDIT 1 - Get rid of inner loop
I solved part of the problem by simply indexing z with indices in this way:
def original_order(ordered, indices):
    z = torch.empty_like(ordered)
    z[indices] = ordered
    return z
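A quick check with the example from the top, assuming the version just defined:

x = torch.tensor([30., 40., 20.])
ordered, indices = torch.sort(x)
assert torch.equal(original_order(ordered, indices), x)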
Now I just have to understand how to avoid the outer loop over the B dimension.
EDIT 2 - Get rid of outer loop
def original_order(ordered, indices, batch_size):
    # produce a vector that shifts the indices by the length
    # of a row times the batch position
    add = torch.linspace(0, batch_size-1, batch_size) * indices.size(1)
    indices = indices + add.long().view(-1, 1)
    # flatten to a single dimension;
    # the indices now take the new length into account
    long_ordered = ordered.view(-1)
    long_indices = indices.view(-1)
    # we are back in the previous one-dimensional case
    z = torch.zeros_like(long_ordered).float()
    z[long_indices] = long_ordered
    # reshape to get back to the original dimensions
    return z.view(batch_size, -1)
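A small round-trip check of this batched version (the input values here are arbitrary):

x = torch.tensor([[30., 40., 20.],
                  [10., 30., 20.]])
ordered, indices = torch.sort(x, dim=1)
restored = original_order(ordered, indices, batch_size=x.size(0))
assert torch.equal(restored, x)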
def original_order(ordered, indices):
    return ordered.gather(1, indices.argsort(1))
original = torch.tensor([
    [20, 22, 24, 21],
    [12, 14, 10, 11],
    [34, 31, 30, 32]])
sorted, index = original.sort()
unsorted = sorted.gather(1, index.argsort(1))
assert(torch.all(original == unsorted))
For simplicity, imagine t = [30, 10, 20], dropping the tensor notation.
t.sort() gives us the sorted tensor s = [10, 20, 30], and also, for free, the sorting index i = [1, 2, 0]. i is actually the output of t.argsort().
i tells us how to go from t to s: "to sort t into s, take element 1, then element 2, then element 0 of t." Argsorting i gives us another sorting index j = [2, 0, 1], which tells us how to go from i to the canonical sequence of natural numbers [0, 1, 2], in effect reversing the sort. Another way to see it is that j tells us how to go from s back to t: "to sort s into t, take element 2, then element 0, then element 1 of s." Argsorting a sorting index gives us its reverse index, and vice versa.
Now that we have the reverse index, we feed it into torch.gather() with the correct dim, and that un-sorts the tensor.
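As a one-dimensional sketch of this round trip, with the numbers from the example above:

import torch

t = torch.tensor([30, 10, 20])
s, i = t.sort()                        # s = [10, 20, 30], i = [1, 2, 0]
j = i.argsort()                        # j = [2, 0, 1], the reverse index
assert torch.equal(s[j], t)            # plain indexing undoes the sort
assert torch.equal(s.gather(0, j), t)  # gather along dim 0 does the same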
I couldn't find this exact solution while researching this problem, so I think this is an original answer.