使用 int dtype 进行 numpy 数组计算时出错(在需要时无法自动将 dtype 转换为 64 位)

SLh*_*ark 5 python arrays integer numpy numpy-ndarray

当计算的输入是具有 32 位整数数据类型的 numpy 数组时,我遇到了 numpy 计算不正确的问题,但输出包含需要 64 位表示的较大数字。

这是一个最小的工作示例:

arr = np.ones(5, dtype=int) * (2**24 + 300)  # arr.dtype defaults to 'int32'

# Following comment from @hpaulj I changed the first line, which was originally:
# arr = np.zeros(5, dtype=int) 
# arr[:] = 2**24 + 300

single_value_calc = 2**8 * (2**24 + 300)
numpy_calc = 2**8 * arr

print(single_value_calc)
print(numpy_calc[0])

# RESULTS
4295044096
76800
Run Code Online (Sandbox Code Playgroud)

所需的输出是 numpy 数组包含正确的值 4295044096,这需要 64 位来表示它。即我希望 numpy 数组在输出需要时自动从 int32 向上转换为 int64,而不是保持 32 位输出并在超过 2^32 的值后返回到 0。

当然,我可以通过强制 int64 表示来手动修复问题:

numpy_calc2 = 2**8 * arr.astype('int64')
Run Code Online (Sandbox Code Playgroud)

但这对于通用代码来说是不可取的,因为在某些情况下而不是所有情况下,输出只需要 64 位表示(即保存大数)。在我的用例中,性能至关重要,因此每次都强制向上转换成本很高。

这是 numpy 数组的预期行为吗?如果是这样,请问是否有干净,高性能的解决方案?

And*_*eak 3

numpy 中的类型转换和提升相当复杂,有时甚至令人惊讶。Sebastian Berg 最近的这篇非官方文章解释了该主题的一些细微差别(主要集中在标量和 0d 数组)。

引用这个文档:

Python 整数和浮点数

请注意,Python 整数的处理方式与 numpy 完全相同。然而,它们的特殊之处在于它们没有明确关联的数据类型。基于值的逻辑,如此处所述,对于 python 整数和浮点数似乎很有用,允许:

arr = np.arange(10, dtype=np.int8)
arr += 1
# or:
res = arr + 1
res.dtype == np.int8
Run Code Online (Sandbox Code Playgroud)

这确保不会发生向上转型(例如内存使用量更高)。

(强调我的。)

另请参阅Allan Haldane 建议 C 风格类型强制的要点,链接自上一个文档:

目前,当二元运算涉及两个数据类型时,numpy 的原则是“输出数据类型的范围覆盖两个输入数据类型的范围”,而当涉及单个数据类型时,永远不会进行任何强制转换。

(再次强调我的。)

所以我的理解是 numpy 标量和数组的提升规则不同,主要是因为检查数组内的每个元素来确定是否可以安全地进行转换是不可行的。再次来自前一个文档:

基于标量的规则

与无法检查所有值的数组不同,对于标量(和 0 维数组),会检查值。

这意味着您可以np.int64从一开始就使用它以确保安全(如果您在 Linux 上,那么dtype=int实际上会自行执行此操作),或者在可疑操作之前检查数组的最大值并确定是否必须提升根据具体情况自行确定 dtype。我知道如果您正在进行大量计算,这可能不可行,但考虑到 numpy 当前的类型提升规则,我不认为有办法解决这个问题。