torch.unique() 中的参数“dim”如何工作?

Sea*_*Lee 4 pytorch

我试图提取矩阵每一行中的唯一值并将它们返回到同一个矩阵中(重复值设置为 0)例如,我想转换

torch.Tensor(([1, 2, 3, 4, 3, 3, 4],
              [1, 6, 3, 5, 3, 5, 4]])
Run Code Online (Sandbox Code Playgroud)

torch.Tensor(([1, 2, 3, 4, 0, 0, 0],
              [1, 6, 3, 5, 0, 0, 4]])
Run Code Online (Sandbox Code Playgroud)

或者

torch.Tensor(([1, 2, 3, 4, 0, 0, 0],
              [1, 6, 3, 5, 4, 0, 0]])
Run Code Online (Sandbox Code Playgroud)

即行中的顺序并不重要。我尝试过使用pytorch.unique(),并且在文档中提到可以使用参数指定采用唯一值的维度dim。然而,它似乎不适用于这种情况。

我试过了:

output= torch.unique(torch.Tensor([[4,2,52,2,2],[5,2,6,6,5]]), dim = 1)

output
Run Code Online (Sandbox Code Playgroud)

这使

tensor([[ 2.,  2.,  2.,  4., 52.],
        [ 2.,  5.,  6.,  5.,  6.]])
Run Code Online (Sandbox Code Playgroud)

有人对此有特别的解决办法吗?如果可能的话,我会尽量避免 for 循环。

You*_*ang 6

第一次使用torch.unique时我很困惑。经过一些实验后,我终于弄清楚了这个dim论点是如何运作的。torch.unique的文档说:

counts (Tensor):(可选)如果 return_counts 为 True,则会有一个额外的返回张量(与 output 或 output.size(dim) 形状相同,如果指定了 dim),表示每个唯一值或张量的出现次数。

例如,如果您的输入张量是大小为 的 3D 张量n x m x kdim=2unique将分别作用于k大小为 的矩阵n x m。换句话说,它将把除暗淡之外的所有维度视为2单个张量。

  • 我相信这是正确的答案,而接受的答案是错误的。 (3认同)

Rex*_*Low 5

人们必须承认,unique如果没有给出适当的示例和解释,该函数有时会非常令人困惑。

dim参数指定要应用到矩阵张量的哪个维度。

例如,在二维矩阵中,dim=0将让运算垂直执行,其中dim=1意味着水平执行。

例如,让我们考虑一个 4x4 矩阵dim=1。正如您从下面的代码中看到的,该unique操作是逐行应用的。

11您注意到第一行和最后一行中重复出现了该数字。Numpy 和 Torch 这样做是为了保留最终矩阵的形状。

但是,如果您没有指定任何维度,torch 会自动展平您的矩阵,然后应用unique到它,您将得到一个包含唯一数据的一维数组。

import torch

m = torch.Tensor([
    [11, 11, 12,11], 
    [13, 11, 12,11], 
    [16, 11, 12, 11],  
    [11, 11, 12, 11]
])

output, indices = torch.unique(m, sorted=True, return_inverse=True, dim=1)
print("Ori \n{}".format(m.numpy()))
print("Sorted \n{}".format(output.numpy()))
print("Indices \n{}".format(indices.numpy()))

# without specifying dimension
output, indices = torch.unique(m, sorted=True, return_inverse=True)
print("Sorted (no dim) \n{}".format(output.numpy()))
Run Code Online (Sandbox Code Playgroud)

结果(暗淡=1)

Ori
[[11. 11. 12. 11.]
 [13. 11. 12. 11.]
 [16. 11. 12. 11.]
 [11. 11. 12. 11.]]
Sorted
[[11. 11. 12.]
 [11. 13. 12.]
 [11. 16. 12.]
 [11. 11. 12.]]
Indices
[1 0 2 0]
Run Code Online (Sandbox Code Playgroud)

结果(无维度)

Sorted (no dim)
[11. 12. 13. 16.]
Run Code Online (Sandbox Code Playgroud)