在 numba 并行化中理解这种竞争条件

use*_*964 2 python parallel-processing numba

Numba 文档中有一个关于并行竞争条件的示例

import numba as nb
import numpy as np
@nb.njit(parallel=True)
def prange_wrong_result(x):
    n = x.shape[0]
    y = np.zeros(4)
    for i in nb.prange(n):
        y[:]+= x[i]
    return y
Run Code Online (Sandbox Code Playgroud)

我已经运行了它,它确实输出了异常结果,例如

prange_wrong_result(np.ones(10000))
#array([5264., 5273., 5231., 5234.])
Run Code Online (Sandbox Code Playgroud)

然后我尝试将循环更改为

import numba as nb
import numpy as np
@nb.njit(parallel=True)
def prange_wrong_result(x):
    n = x.shape[0]
    y = np.zeros(4)
    for i in nb.prange(n):
        y+= x[i]
    return y
Run Code Online (Sandbox Code Playgroud)

它输出

prange_wrong_result(np.ones(10000))
#array([10000., 10000., 10000., 10000.])
Run Code Online (Sandbox Code Playgroud)

我已经阅读了一些竞争条件解释。但我还是不明白

  1. 为什么第二个例子没有赛车条件?y[:]=vs 和有什么不一样y=
  2. 为什么第一个例子中四个元素的输出不一样?

MSe*_*ert 7

在您的第一个示例中,您有多个线程/进程共享相同的数组并读取 + 分配给共享数组。在y[:] += x[i]大致相当于:

y[0] += x[i]
y[1] += x[i]
y[2] += x[i]
y[3] += x[i]
Run Code Online (Sandbox Code Playgroud)

事实上,这+=只是读取、加法和赋值操作的语法糖y[0] += x[i],事实上:

_value = y[0]
_value = _value + x[i]
y[0] = _value
Run Code Online (Sandbox Code Playgroud)

循环体由多个线程/进程同时执行,这就是竞争条件出现的地方。 维基百科上关于竞争条件的示例适用于此处:

在此处输入图片说明

这就是返回的数组包含错误值以及每个元素可能不同的原因。因为它只是不确定哪个线程/进程何时运行。所以在某些情况下,一个元素存在竞争条件,有时没有,有时在多个元素上。

然而,numba 开发人员已经在不发生竞争条件的情况下实现了一些受支持的减少。其中之一是y +=。这里重要的是它是变量本身,而不是变量的切片/元素。在这种情况下 numba 做了一些非常聪明的事情。他们为每个线程/进程复制变量的初始值,然后对该副本进行操作。并行循环完成后,他们将复制的值相加。以您的第二个示例为例,假设它使用了 2 个进程,它大致如下所示:

y = np.zeros(4)
y_1 = y.copy()
y_2 = y.copy()
for i in nb.prange(n):
    if is_process_1:
        y_1[:] += x[i]
    if is_process_2:
        y_2[:] += x[i]
y += y_1
y += y_2
Run Code Online (Sandbox Code Playgroud)

由于每个线程都有自己的数组,因此不可能出现竞争条件。为了使 numba 能够推断出这一点,您必须遵循他们的限制。该文档指出 numba 为+=标量和数组 ( y += x[i])创建了无竞争条件的并行代码,但不会为数组元素/切片(y[:] += x[i]y[1] += x[i]) 创建。