如何在给定轴中的某个索引之后将 pytorch 张量的所有元素设置为零,其中索引由另一个 pytorch 张量给出?

Nag*_*S N 9 python vectorization pytorch

我有两个 PyTorch 张量

mask = torch.ones(1024, 64, dtype=torch.float32)
indices = torch.randint(0, 64, (1024, ))
Run Code Online (Sandbox Code Playgroud)

对于i中的每一行,我想将 的元素mask指定的索引后面的所有元素设置为零。例如,如果 的第一个元素是,那么我想设置。不使用for循环是否可以实现这一点?iindicesindices50mask[0, 50:]=0

for循环的解决方案:

for i in range(mask.shape[0]):
    mask[i, indices[i]:] = 0
Run Code Online (Sandbox Code Playgroud)

小智 6

您可以首先生成大小为 (1024x64) 的张量,其中每行都有从 0 到 63 排列的数字。然后使用重塑为 (1024x1) 的索引应用逻辑运算

mask = torch.ones(1024, 64, dtype=torch.float32)
indices = torch.randint(0, 64, (1024, 1))    # Note the dimensions

mask[torch.arange(0, 64, dtype=torch.float32).repeat(1024,1) >= indices] = 0
Run Code Online (Sandbox Code Playgroud)

  • 惊人的!我能够进一步简单地将其简化为“mask = (torch.arange(0, 64, dtype=torch.float32).repeat(1024,1) <索引).float()” (2认同)