为什么每 10 万次迭代打印一次会破坏 numba 性能?

Bas*_*asj 1 python printing jit numba

为什么这段代码print每 100k 次迭代一次(即只打印40 行!)需要 50 秒才能运行:

import numpy as np
from numba import jit

@jit
def doit():
    A = np.random.random(4*1000*1000)
    n = 300
    Q = np.zeros(len(A)-n)
    for i in range(len(Q)):
        Q[i] = np.sum(A[i:i+n] <= A[i+n])
        if i % 100000 == 0:  # print the progress once every 100k iterations
            print("%i %.2f %% already done. " % (i, i * 100.0 / len(A)))

doit()
Run Code Online (Sandbox Code Playgroud)

而如果没有print只需要 2.4 秒

import numpy as np
from numba import jit
@jit
def doit():
    A = np.random.random(4*1000*1000)
    n = 300
    Q = np.zeros(len(A)-n)
    for i in range(len(Q)):
        Q[i] = np.sum(A[i:i+n] <= A[i+n])
doit()
Run Code Online (Sandbox Code Playgroud)

这是一个普遍事实,print真的可以消除 的好处numba吗?

Jos*_*del 5

如果您尝试使用@njitor编译它@jit(nopython=True),您将看到它正在从异常中以对象模式进行编译。这个版本在我的机器上运行大约 1 秒,打印语句如下:

import numpy as np
from numba import jit

@jit(nopython=True)
def doit():
    A = np.random.random(4*1000*1000)
    n = 300
    Q = np.zeros(len(A)-n)
    for i in range(len(Q)):
        Q[i] = np.sum(A[i:i+n] <= A[i+n])
        if i % 100000 == 0:  # print the progress once every 100k iterations
            print(i , "(",  i * 100.0 / len(A), '% already done)')
Run Code Online (Sandbox Code Playgroud)

一般来说,如果您发现 numba 函数性能不佳,那是因为您正在 python 对象模式下进行编译,因此总是 putnopython=True是一个很好的做法,除非您真的想在 python 对象模式下使用它,因为如果它遇到了一些编译器无法编译为机器代码的语法。Numba 确实做了一些循环提升,但这在性能方面更难推理。

看:

http://numba.pydata.org/numba-doc/latest/user/5minguide.html#what-is-nopython-mode

  • 这是“print”函数中不支持的字符串格式。 (2认同)