use*_*788 6 python multidimensional-array deep-learning pytorch tensor
我试图索引多维张量中沿最后一个维度的最大元素。例如说我有一个张量
A = torch.randn((5, 2, 3))
_, idx = torch.max(A, dim=2)
Run Code Online (Sandbox Code Playgroud)
idx在这里存储最大索引,看起来可能像
>>>> A
tensor([[[ 1.0503, 0.4448, 1.8663],
[ 0.8627, 0.0685, 1.4241]],
[[ 1.2924, 0.2456, 0.1764],
[ 1.3777, 0.9401, 1.4637]],
[[ 0.5235, 0.4550, 0.2476],
[ 0.7823, 0.3004, 0.7792]],
[[ 1.9384, 0.3291, 0.7914],
[ 0.5211, 0.1320, 0.6330]],
[[ 0.3292, 0.9086, 0.0078],
[ 1.3612, 0.0610, 0.4023]]])
>>>> idx
tensor([[ 2, 2],
[ 0, 2],
[ 0, 0],
[ 0, 2],
[ 1, 0]])
Run Code Online (Sandbox Code Playgroud)
我希望能够访问这些索引并基于它们分配给另一个张量。意思是我想做
B = torch.new_zeros(A.size())
B[idx] = A[idx]
Run Code Online (Sandbox Code Playgroud)
其中B在所有地方都是0,除了A在最后一个维度上最大。那是B应该存储
>>>>B
tensor([[[ 0, 0, 1.8663],
[ 0, 0, 1.4241]],
[[ 1.2924, 0, 0],
[ 0, 0, 1.4637]],
[[ 0.5235, 0, 0],
[ 0.7823, 0, 0]],
[[ 1.9384, 0, 0],
[ 0, 0, 0.6330]],
[[ 0, 0.9086, 0],
[ 1.3612, 0, 0]]])
Run Code Online (Sandbox Code Playgroud)
事实证明,这比我预期的要困难得多,因为idx无法正确索引数组A。到目前为止,我一直无法找到使用idx索引A的向量化解决方案。
有一个好的矢量化方法来做到这一点吗?
您可以使用torch.meshgrid创建索引元组:
>>> index_tuple = torch.meshgrid([torch.arange(x) for x in A.size()[:-1]]) + (idx,)
>>> B = torch.zeros_like(A)
>>> B[index_tuple] = A[index_tuple]
Run Code Online (Sandbox Code Playgroud)
请注意,您还可以meshgrid通过以下方式进行模仿(针对 3D 的特定情况):
>>> index_tuple = (
... torch.arange(A.size(0))[:, None],
... torch.arange(A.size(1))[None, :],
... idx
... )
Run Code Online (Sandbox Code Playgroud)
更多解释:
我们将有这样的索引:
In [173]: idx
Out[173]:
tensor([[2, 1],
[2, 0],
[2, 1],
[2, 2],
[2, 2]])
Run Code Online (Sandbox Code Playgroud)
由此,我们想要得到三个索引(因为我们的张量是 3D,所以我们需要三个数字来检索每个元素)。基本上我们想在前两个维度构建一个网格,如下所示。(这就是我们使用网格的原因)。
In [174]: A[0, 0, 2], A[0, 1, 1]
Out[174]: (tensor(0.6288), tensor(-0.3070))
In [175]: A[1, 0, 2], A[1, 1, 0]
Out[175]: (tensor(1.7085), tensor(0.7818))
In [176]: A[2, 0, 2], A[2, 1, 1]
Out[176]: (tensor(0.4823), tensor(1.1199))
In [177]: A[3, 0, 2], A[3, 1, 2]
Out[177]: (tensor(1.6903), tensor(1.0800))
In [178]: A[4, 0, 2], A[4, 1, 2]
Out[178]: (tensor(0.9138), tensor(0.1779))
Run Code Online (Sandbox Code Playgroud)
在上面的 5 行中,索引中的前两个数字基本上是我们使用 meshgrid 构建的网格,第三个数字来自idx.
即前两个数字形成一个网格。
(0, 0) (0, 1)
(1, 0) (1, 1)
(2, 0) (2, 1)
(3, 0) (3, 1)
(4, 0) (4, 1)
Run Code Online (Sandbox Code Playgroud)
一个丑陋的解决办法是创建一个二进制掩码idx并用它来索引数组。基本代码如下所示:
import torch
torch.manual_seed(0)
A = torch.randn((5, 2, 3))
_, idx = torch.max(A, dim=2)
mask = torch.arange(A.size(2)).reshape(1, 1, -1) == idx.unsqueeze(2)
B = torch.zeros_like(A)
B[mask] = A[mask]
print(A)
print(B)
Run Code Online (Sandbox Code Playgroud)
诀窍是torch.arange(A.size(2))枚举 中 的可能值,idx并且mask在它们等于 的地方不为零idx。评论:
torch.max,则可以改用torch.argmax。torch.nn.functional.max_pool3dsize 的内核进行重新发明(1, 1, 3)。torch.where按此处所示使用。我希望有人能提出一个更干净的解决方案(避免数组的中间分配mask),可能会使用torch.index_select,但我现在无法让它工作。