用pytorch打印张量的精确值(浮点精度)

use*_*050 4 python pytorch

我正在尝试打印torch.FloatTensor之类的

a = torch.FloatTensor(5,5)
print(a)
Run Code Online (Sandbox Code Playgroud)

这样我就可以得到一个像这样的价值

0.0000e+00  0.0000e+00  3.2286e-41  9.4448e+21  4.3346e-38
1.2412e-40  1.2313e+00  1.6751e-37  3.1138e-40  9.4460e+21
2.6801e-36  3.5873e-41  9.4463e+21  4.9653e-35  3.9963e-40
9.4454e+21  2.6801e-36  1.2771e-40  9.4460e+21  1.7153e-34
7.7056e-40  9.0090e+15  4.1877e-38  2.9775e-41  1.5695e-43
Run Code Online (Sandbox Code Playgroud)

但我希望获得更准确的值,如10小数点

0.1234567891 + 01

在python中,我可以得到它

print('{:.10f}'.format(a))
Run Code Online (Sandbox Code Playgroud)

但是在张量的情况下,我得到了这个错误

TypeError: unsupported format string passed to torch.FloatTensor.__format__
Run Code Online (Sandbox Code Playgroud)

如何打印准确的张量值?

use*_*754 13

您可以设置精度选项

torch.set_printoptions(precision=10)
Run Code Online (Sandbox Code Playgroud)

文档页面上有更多格式化选项:http://pytorch.org/docs/master/torch.html#creation-ops它与numpys非常相似.


pro*_*sti 7

附带说明一下,此功能取自 numpy。PyTorch 之所以聪明的原因之一是因为它们从 numpy 中汲取了许多好主意。

但是,在 numpy 中默认精度为 8,在 PyTorch 中默认精度为 4。

在此输入图像描述