是否有Pytorch内部程序来检测NaN张量中的s?Tensorflow有tf.is_nan和tf.check_numerics操作...... Pytorch在某处有类似的东西吗?我在文档中找不到这样的东西......
我正在寻找一个Pytorch内部例程,因为我希望这发生在GPU和CPU上.这不包括基于numpy的解决方案(如np.isnan(sometensor.numpy()).any())......
nem*_*emo 30
您始终可以利用以下事实nan != nan:
>>> x = torch.tensor([1, 2, np.nan])
tensor([ 1., 2., nan.])
>>> x != x
tensor([ 0, 0, 1], dtype=torch.uint8)
Run Code Online (Sandbox Code Playgroud)
使用pytorch 0.4还有torch.isnan:
>>> torch.isnan(x)
tensor([ 0, 0, 1], dtype=torch.uint8)
Run Code Online (Sandbox Code Playgroud)
Jat*_*aki 19
从PyTorch 0.4.1开始,有一个detect_anomaly上下文管理器,它自动插入等同于assert not torch.isnan(grad).any()向后传播的所有步骤之间的断言.在向后传递期间出现问题时非常有用.
cry*_*ick 16
正如@cleros 在对@nemo 答案的评论中所建议的那样,您可以使用any()运算符将其作为布尔值获取:
torch.isnan(your_tensor).any()
Run Code Online (Sandbox Code Playgroud)
如果你想直接在张量上调用它:
import torch
x = torch.randn(5, 4)
print(x.isnan().any())
Run Code Online (Sandbox Code Playgroud)
出去:
import torch
x = torch.randn(5, 4)
print(x.isnan().any())
tensor(False)
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
14932 次 |
| 最近记录: |