我有一个掩码active
,可以跟踪在循环过程中仍未终止的批次。它的维度是[batch_full,]
,它的真实条目显示当前步骤中仍需要使用哪些元素。循环过程生成另一个掩码 ,terminated
它具有与掩码中的真实值一样多的元素active
。现在,我想从 中取出值~terminated
并将它们放回 中active
,但要使用正确的索引。基本上我想做:
import torch
active = torch.ones([4,], dtype=torch.bool)
active[:2] = torch.tensor(False)
terminated = torch.tensor([True, False])
active[active] = ~terminated
print(active) # expected [F, F, F, T]
Run Code Online (Sandbox Code Playgroud)
但是,我收到错误:
RuntimeError:不支持的操作:输入张量和写入张量的某些元素引用单个内存位置。请在执行操作之前克隆()张量。
如何才能有效地进行上述操作呢?
有一些解决方案,我还将给出timeit
在 2021 款 MacBook Pro 上通过 10k 次重复测量的速度。
最简单的解决方案,耗时 0.260s:
active[active.clone()] = ~terminated
Run Code Online (Sandbox Code Playgroud)
我们可以masked_scatter_
对 abt 使用就地操作。2 倍加速(0.136 秒):
active[active.clone()] = ~terminated
Run Code Online (Sandbox Code Playgroud)
异地操作需要 0.161 秒,结果为:
active.masked_scatter_(
active,
~terminated,
)
Run Code Online (Sandbox Code Playgroud)
归档时间: |
|
查看次数: |
864 次 |
最近记录: |