检查 PyTorch 张量在 epsilon 内是否相等

Tom*_*ale 10 pytorch

如何检查两个 PyTorch 张量在语义上是否相等?

鉴于浮点错误,我想知道元素是否仅相差一个小的 epsilon 值。

Tom*_*ale 11

在撰写本文时,这是最新稳定版本 (0.4.1) 中未记录的功能,但文档位于master (unstable)分支中。

torch.allclose() 将返回一个布尔值,指示所有元素差异是否相等,允许存在误差。

此外,还有未记录的isclose()

>>> torch.isclose(torch.Tensor([1]), torch.Tensor([1.00000001]))
tensor([1], dtype=torch.uint8)
Run Code Online (Sandbox Code Playgroud)

  • 现在已经稳定了 (2认同)