Pytorch:将值从一个掩码分配给另一个掩码,由自身掩码

Bos*_*hie 4 python pytorch

我有一个掩码active,可以跟踪在循环过程中仍未终止的批次。它的维度是[batch_full,],它的真实条目显示当前步骤中仍需要使用哪些元素。循环过程生成另一个掩码 ,terminated它具有与掩码中的真实值一样多的元素active。现在,我想从 中取出值~terminated并将它们放回 中active,但要使用正确的索引。基本上我想做:

import torch

active = torch.ones([4,], dtype=torch.bool)
active[:2] = torch.tensor(False)

terminated = torch.tensor([True, False])

active[active] = ~terminated

print(active)  # expected [F, F, F, T]

Run Code Online (Sandbox Code Playgroud)

但是,我收到错误:

RuntimeError:不支持的操作:输入张量和写入张量的某些元素引用单个内存位置。请在执行操作之前克隆()张量。

如何才能有效地进行上述操作呢?

Bos*_*hie 5

有一些解决方案,我还将给出timeit在 2021 款 MacBook Pro 上通过 10k 次重复测量的速度。

最简单的解决方案,耗时 0.260s:

active[active.clone()] = ~terminated
Run Code Online (Sandbox Code Playgroud)

我们可以masked_scatter_对 abt 使用就地操作。2 倍加速(0.136 秒):

active[active.clone()] = ~terminated
Run Code Online (Sandbox Code Playgroud)

异地操作需要 0.161 秒,结果为:

active.masked_scatter_(
        active,
        ~terminated,
    )
Run Code Online (Sandbox Code Playgroud)