Dar*_*ook 4 python python-3.x pytorch
我不明白这一行:
lprobs[lprobs != lprobs] = torch.tensor(-math.inf).to(lprobs)
Run Code Online (Sandbox Code Playgroud)
没有评论,那么它是一些著名的 Python(或 PyTorch?)习语吗?有人可以解释它的含义,或者展示一种使意图更清晰的不同方式吗?
lprobs
是一个pytorchTensor
,它可以包含任何大小的浮点类型(我怀疑这段代码是否旨在支持 int 或复杂类型)。据我所知,Tensor 类不会覆盖该__ne__
函数。
phi*_*ler 10
它是花哨的索引与布尔掩码的组合,以及一个“技巧”(虽然是设计的目的)来检查NaN
: x != x
hold iff x
is NaN
(对于浮点数,就是这样)。
他们也可以写
lprobs[torch.isnan(lprobs)] = torch.tensor(-math.inf).to(lprobs)
Run Code Online (Sandbox Code Playgroud)
或者,可能更惯用,使用torch.nan_to_num
(但要注意后者对无穷大也有特殊行为)。
上述的非更新变体将是
torch.where(torch.isnan(lprobs), torch.tensor(-math.inf), lprobs)
Run Code Online (Sandbox Code Playgroud)