Pytorch操作检测NaNs

cle*_*ros 17 pytorch

是否有Pytorch内部程序来检测NaN张量中的s?Tensorflow有tf.is_nantf.check_numerics操作...... Pytorch在某处有类似的东西吗?我在文档中找不到这样的东西......

我正在寻找一个Pytorch内部例程,因为我希望这发生在GPU和CPU上.这不包括基于numpy的解决方案(如np.isnan(sometensor.numpy()).any())......

nem*_*emo 30

您始终可以利用以下事实nan != nan:

>>> x = torch.tensor([1, 2, np.nan])
tensor([  1.,   2., nan.])
>>> x != x
tensor([ 0,  0,  1], dtype=torch.uint8)
Run Code Online (Sandbox Code Playgroud)

使用pytorch 0.4还有torch.isnan:

>>> torch.isnan(x)
tensor([ 0,  0,  1], dtype=torch.uint8)
Run Code Online (Sandbox Code Playgroud)

  • 我可以确认它也可以在GPU上运行.`.any()`然后将它缩减为Python bool.谢谢 :-) (4认同)

Jat*_*aki 19

从PyTorch 0.4.1开始,有一个detect_anomaly上下文管理器,它自动插入等同于assert not torch.isnan(grad).any()向后传播的所有步骤之间的断言.在向后传递期间出现问题时非常有用.


cry*_*ick 16

正如@cleros 在对@nemo 答案的评论中所建议的那样,您可以使用any()运算符将其作为布尔值获取:

torch.isnan(your_tensor).any()
Run Code Online (Sandbox Code Playgroud)


Cha*_*ker 8

如果你想直接在张量上调用它:

import torch

x = torch.randn(5, 4)
print(x.isnan().any())
Run Code Online (Sandbox Code Playgroud)

出去:

import torch
x = torch.randn(5, 4)
print(x.isnan().any())
tensor(False)
Run Code Online (Sandbox Code Playgroud)