numpy中使用单精度浮点数的不便

Hyu*_*ang 1 python floating-point precision numpy

在numpy中使用单精度(float32)编写代码时,太难写了。

首先,单精度浮子的寿命太长。我们必须按如下方式输入所有变量。

a = np.float32(5)
Run Code Online (Sandbox Code Playgroud)

但其他一些语言使用更简单的表示。

a = 5.f
Run Code Online (Sandbox Code Playgroud)

其次,艺术手术也不方便。

b = np.int32(5)+np.float32(5)
Run Code Online (Sandbox Code Playgroud)

我期望的类型b是 isnumpy.float32但它是numpy.float64

当然,

b = np.add(np.int32(5), np.float32(5), dtype=np.float32)
Run Code Online (Sandbox Code Playgroud)

返回我想要的。但要替换所有操作就太长了。

有没有更简单的方法在 numpy 中使用单精度?

MSe*_*ert 6

问题在于,当您在操作中使用不同类型时,NumPy 会提升类型。float32float32当其他数字操作数的数据类型为:

  • float32或更少
  • int16或更少
  • uint16或更少

如果另一个操作数具有另一个数据类型,则结果将为float64(或者complex如果另一个操作数是复数)。+上面列出的数据类型不是最常见的,因此几乎任何使用标准运算符, -, /, , ...的操作(特别是当另一个是 Python 整数/浮点时)*都会将您的float32值提升为float64.

不幸的是,您无能为力来避免这种情况。在很多情况下,NumPy 这样做是可以的,因为:

  • 大多数架构可以像处理单精度浮点数一样快地处理双精度。Python 中的算术运算在 Python 类型上运行速度很快,但在其他类型上运行速度较慢。
import numpy as np
a32 = np.float32(1)
a64 = np.float64(1)
a = 1.
%timeit [a32 + a32 for _ in range(20000)]  # 100 loops, best of 3: 4.58 ms per loop
%timeit [a64 + a64 for _ in range(20000)]  # 100 loops, best of 3: 4.83 ms per loop
%timeit [a + a for _ in range(20000)]      # 100 loops, best of 3: 2.72 ms per loop
Run Code Online (Sandbox Code Playgroud)
  • Python 类型的开销非常大,标量双精度浮点数的内存开销几乎可以忽略不计。
import sys
import numpy as np
    
sys.getsizeof(np.float32(1))  # 28
sys.getsizeof(np.float64(1))  # 32
sys.getsizeof(1.)             # 24  # that's also a double on my computer!
Run Code Online (Sandbox Code Playgroud)

然而,如果您有巨大的数组并且会遇到内存问题,或者如果您与其他需要单精度浮点的库(机器学习、GPU 等)交互,那么使用单精度浮点是有意义的。

但正如上面提到的,你几乎总是会反对强制规则,这可以防止你遇到意想不到的问题。

这个例子int32 + float32实际上是一个很好的例子!您期望结果是float32- 但有一个问题:您不能将 every 表示int32float32

np.iinfo(np.int32(1))             # iinfo(min=-2147483648, max=2147483647, dtype=int32)
int(np.float32(2147483647))       # 2147483648
np.int32(np.float32(2147483647))  # -2147483648
Run Code Online (Sandbox Code Playgroud)

是的,只需将值转换为单精度浮点数并将其转换回整数,就可以改变它的值。这就是 NumPy 使用双精度的原因,这样您就不会得到意外的结果!这就是为什么你需要强制 NumPy做一些可能错误的事情(从一般用户的角度来看)。


由于(据我所知)没有方法可以限制 Numpy 的类型提升,因此您必须发明自己的方法。

例如,您可以创建一个包装 NumPy 数组的类,并使用特殊方法来实现运算符的 dtype-d 函数:

import numpy as np

class Arr32:
    def __init__(self, arr):
        self.arr = arr
        
    def __add__(self, other):
        other_arr = other
        if isinstance(other, Arr32):
            other_arr = other.arr
        return self.__class__(np.add(self.arr, other_arr, dtype=np.float32))
        
    def __sub__(self, other):
        other_arr = other
        if isinstance(other, Arr32):
            other_arr = other.arr
        return self.__class__(np.subtract(self.arr, other_arr, dtype=np.float32))
        
    def __mul__(self, other):
        other_arr = other
        if isinstance(other, Arr32):
            other_arr = other.arr
        return self.__class__(np.multiply(self.arr, other_arr, dtype=np.float32))
        
    def __truediv__(self, other):
        other_arr = other
        if isinstance(other, Arr32):
            other_arr = other.arr
        return self.__class__(np.divide(self.arr, other_arr, dtype=np.float32))
Run Code Online (Sandbox Code Playgroud)

但这仅实现了 NumPy 功能的一小部分,并且很快就会产生大量可能被遗忘的代码和边缘情况。现在可能有更聪明的方法使用__array_ufunc____array_function__,但我自己没有使用过这些,所以我无法评论工作量或适用性。

所以我的首选解决方案是为所需的函数创建辅助函数:

import numpy as np

def arr32(a):
    return np.float32(a)

def add32(a1, a2):
    return np.add(a1, a2, dtype=np.float32)

def sub32(a1, a2):
    return np.subtract(a1, a2, dtype=np.float32)

def mul32(a1, a2):
    return np.multiply(a1, a2, dtype=np.float32)

def div32(a1, a2):
    return np.divide(a1, a2, dtype=np.float32)
Run Code Online (Sandbox Code Playgroud)

或者仅使用就地操作,因为这些操作不会提升类型:

>>> import numpy as np

>>> arr = np.float32([1,2,3])
>>> arr += 2
>>> arr *= 3
>>> arr
array([ 9., 12., 15.], dtype=float32)
Run Code Online (Sandbox Code Playgroud)

  • “大多数架构可以像单精度浮点一样快速地处理双精度”,如果您正在考虑执行单个操作的时间,我认为这是正确的,但缓存局部性和 SIMD 指令的存在可以处理两倍于它们所能处理的 float32立即处理 float64 意味着在实践中,使用“float32”数组进行操作可能会更快。尝试对一百万个随机浮点数的 1d NumPy 数组“x”使用“%timeit x + x”。我使用“float64”每个循环得到 788 us,使用“float32”每个循环得到 288 us。当然,您的结果可能会有所不同。 (2认同)