x[x!=x] 是什么意思?

Dar*_*ook 4 python python-3.x pytorch

我不明白这一行

lprobs[lprobs != lprobs] = torch.tensor(-math.inf).to(lprobs)
Run Code Online (Sandbox Code Playgroud)

没有评论,那么它是一些著名的 Python(或 PyTorch?)习语吗?有人可以解释它的含义,或者展示一种使意图更清晰的不同方式吗?

lprobs是一个pytorchTensor,它可以包含任何大小的浮点类型(我怀疑这段代码是否旨在支持 int 或复杂类型)。据我所知,Tensor 类不会覆盖该__ne__函数。

phi*_*ler 10

它是花哨的索引与布尔掩码的组合,以及一个“技巧”(虽然是设计的目的)来检查NaN: x != xhold iff xis NaN(对于浮点数,就是这样)。

他们也可以写

lprobs[torch.isnan(lprobs)] = torch.tensor(-math.inf).to(lprobs)
Run Code Online (Sandbox Code Playgroud)

或者,可能更惯用,使用torch.nan_to_num(但要注意后者对无穷大也有特殊行为)。

上述的非更新变体将是

torch.where(torch.isnan(lprobs), torch.tensor(-math.inf), lprobs)
Run Code Online (Sandbox Code Playgroud)