Nag*_*S N 9 python vectorization pytorch
我有两个 PyTorch 张量
mask = torch.ones(1024, 64, dtype=torch.float32)
indices = torch.randint(0, 64, (1024, ))
Run Code Online (Sandbox Code Playgroud)
对于i中的每一行,我想将 的元素mask指定的索引后面的所有元素设置为零。例如,如果 的第一个元素是,那么我想设置。不使用for循环是否可以实现这一点?iindicesindices50mask[0, 50:]=0
for循环的解决方案:
for i in range(mask.shape[0]):
mask[i, indices[i]:] = 0
Run Code Online (Sandbox Code Playgroud)
小智 6
您可以首先生成大小为 (1024x64) 的张量,其中每行都有从 0 到 63 排列的数字。然后使用重塑为 (1024x1) 的索引应用逻辑运算
mask = torch.ones(1024, 64, dtype=torch.float32)
indices = torch.randint(0, 64, (1024, 1)) # Note the dimensions
mask[torch.arange(0, 64, dtype=torch.float32).repeat(1024,1) >= indices] = 0
Run Code Online (Sandbox Code Playgroud)