numpy rng 线程安全吗?

evo*_*nna 5 python random multithreading numpy thread-safety

我实现了一个使用 numpy 随机生成器来模拟某些过程的函数。这是此类函数的一个最小示例:

def thread_func(cnt, gen):
    s = 0.0
    for _ in range(cnt):
        s += gen.integers(6)
    return s
Run Code Online (Sandbox Code Playgroud)

现在我编写了一个使用python的starmap来调用thread_func的函数。如果我这样写(将相同的 rng 引用传递给所有进程):

from multiprocessing import Pool
import numpy as np    
def evaluate(total_cnt, thread_cnt):
        gen = np.random.default_rng()
        cnt_per_thread = total_cnt // thread_cnt
        with Pool(thread_cnt) as p:
            vals = p.starmap(thread_func, [(cnt_per_thread,gen) for _ in range(thread_cnt)])
        return vals
Run Code Online (Sandbox Code Playgroud)

的结果evaluate(100000, 5)是一个由 5 个相同值组成的数组,例如:

[49870.0, 49870.0, 49870.0, 49870.0, 49870.0]
Run Code Online (Sandbox Code Playgroud)

但是,如果我将不同的 rng 传递给所有进程,例如通过执行以下操作:

vals = p.starmap(thread_func, [(cnt_per_thread,np.random.default_rng()) for _ in range(thread_cnt)])
Run Code Online (Sandbox Code Playgroud)

我得到了预期的结果(5个不同的值),例如:

[49880.0, 49474.0, 50232.0, 50038.0, 50191.0]
Run Code Online (Sandbox Code Playgroud)

为什么会出现这种情况?

Jér*_*ard 2

TL;DR:正如 @MichaelSzczesny 所指出的,主要问题是您使用的进程在具有相同初始状态的同一 RNG 对象的副本上进行操作。


随机数生成器 (RNG) 对象使用称为种子的整数进行初始化,当使用迭代操作(例如(seed * huge_number) % another_huge_number)生成新数字时,种子会被修改。

对于多个线程使用同一个 RNG 对象并不是一个好主意,对其的操作本质上是顺序的。在最好的情况下,如果两个线程以受保护的方式访问它(例如使用临界区),则结果取决于线程的顺序。此外,性能也会受到影响,因为这样做会导致称为缓存行弹跳的效果,从而减慢访问同一对象的线程的执行速度。在最坏的情况下,RNG 对象不受保护,这会导致竞争条件。这样的问题会导致多个线程的种子可能相同,因此结果(应该是随机的)。

CPython 使用称为全局解释器锁 (GIL) 的巨型互斥体来保护对 Python 对象的访问。它可以防止多个线程同时执行 Python 字节码。目标是保护解释器而不是对象状态。Numpy 的许多函数都释放了 GIL,因此代码可以并行扩展。问题是,如果您从同一线程使用它们,则会导致竞争条件。您有责任使用锁来保护 Numpy 对象

就您而言,我无法使用线程重现问题,但可以使用进程重现问题。因此,我认为您在示例中使用了流程。对于流程,您应该使用:

from multiprocessing import Pool
Run Code Online (Sandbox Code Playgroud)

对于线程,您应该使用:

from multiprocessing.pool import ThreadPool as Pool
Run Code Online (Sandbox Code Playgroud)

进程的行为与线程不同,因为它们不对共享对象进行操作(至少默认情况下不会)。相反,进程对对象副本进行操作。由于 RNG 对象的初始状态在所有进程中都相同,因此进程会产生相同的输出

简而言之,请每个线程使用一种不同的 RNG。典型的解决方案是使用它们自己的 RNG 对象创建 N 个线程,然后与它们通信以发送一些工作(例如使用队列)。这称为线程池。另一种选择可能是使用线程本地存储

请注意,Numpy 文档在“多线程生成”部分中提供了一个示例。