为什么 Cython 比 Python+numpy 慢?

bas*_*pus 5 python numpy cython

我想实现一些快速的凸分析操作 - 近端运算符等。我是 Cython 的新手,认为这将是完成这项工作的正确工具。mwe_py.py我在纯 Python 和 Cython (及以下)中都有实现mwe_c.pyx。然而,当我比较它们时,Python + Numpy 版本明显快于 Cython 版本。为什么是这样?我尝试过使用内存视图,它应该允许更快的索引/操作;但是,性能差异非常明显!任何有关如何修复mwe_c.pyx以下问题以接近“最佳”Cython 代码的建议将不胜感激。

import pyximport; pyximport.install(language_level=3)

import mwe_c
import mwe_py
import numpy as np
from time import time

n = 100000
nreps = 10000
x = np.random.randn(n)
z = np.random.randn(n)
tau = 1.0

t0 = time()
for _ in range(nreps):
    out = mwe_c.prox_translation(mwe_c.prox_two_norm, x, z, tau)
t1 = time()
print(t1 - t0)


t0 = time()
for _ in range(nreps):
    out = mwe_py.prox_translation(mwe_py.prox_two_norm, x, z, tau)
t1 = time()
print(t1 - t0)
Run Code Online (Sandbox Code Playgroud)

分别给出输出:

10.76103401184082  # (seconds)
5.988733291625977  # (seconds)
Run Code Online (Sandbox Code Playgroud)

下边是mwe_py.py

import numpy.linalg as la

def proj_two_norm(x):
    """projection onto l2 unit ball"""
    X = la.norm(x)
    if X <= 1:
        return x
    return x / X

def prox_two_norm(x, tau):
    """proximal mapping of l2 norm with parameter tau"""
    return x - proj_two_norm(x / tau)

def prox_translation(prox_func, x, z, tau=None):
  """Compute prox(f(. - z))(x) where prox_func(x, tau) is prox(tau * f)(x)."""
  if tau is None:
      tau = 1.0
  return z + prox_func(x - z, tau)
Run Code Online (Sandbox Code Playgroud)

最后是mwe_c.pyx

import numpy as np
cimport numpy as np


cdef double [::1] aasubtract(double [::1] x, double [::1] z):
    cdef unsigned int i, m = len(x), n = len(z);
    assert m == n, f"vectors must have the same length"
    cdef double [::1] out = np.copy(x);
    for i in range(n):
        out[i] -= z[i]
    return out


cdef double [::1] vsdivide(double [::1] x, double tau):
    """Divide an array by a scalar element-wise."""
    cdef:
        unsigned int i, n = len(x);
        double [::1] out = np.copy(x);
    for i in range(n):
        out[i] /= tau
    return out


cdef double two_norm(double [::1] x):
    cdef:
        double out = 0.0;
        unsigned int i, n=len(x);
    for i in range(n):
        out = out + x[i]**2
    out = out **.5
    return out


cdef double [::1] proj_two_norm(double [::1] x):
    """project x onto the unit two ball."""
    cdef double x_norm = two_norm(x);
    cdef unsigned int i, n = len(x);
    cdef double [::1] p = np.copy(x);
    if x_norm <= 1:
        return p
    for i in range(n):
        p[i] = p[i] / x_norm
    return p


cpdef double [::1] prox_two_norm(double [::1] x, double tau):
    """double [::1] prox_two_norm(double [::1] x, double tau)"""
    cdef unsigned int i, n = len(x);
    cdef double [::1] out = x.copy(), Px = x.copy();
    Px = proj_two_norm(vsdivide(Px, tau));
    for i in range(n):
        out[i] = out[i] - Px[i]
    return out


cpdef prox_translation(
    prox_func,
    double [::1] x,
    double [::1] z,
    double tau=1.0
):
    cdef:
        unsigned int i, n = len(x);
        double [::1] out = prox_func(aasubtract(x, z), tau);
    for i in range(n):
        out[i] += z[i];
    return out
Run Code Online (Sandbox Code Playgroud)

Jér*_*ard 6

主要问题是将优化的 Numpy 代码与优化程度较低的 Cython 代码进行比较。事实上,Numpy 使用SIMD 指令(例如 x86-64 处理器上的 SSE 和 AVX/AVX2),能够连续计算多个项目。Cython-O2默认情况下使用默认优化级别,该级别不启用任何自动矢量化策略,从而导致标量代码变慢(除非您使用最新版本的 GCC)。您可以使用-O3告诉大多数编译器(例如旧的 GCC 和 Clang)启用自动矢量化。请注意,这不足以生成非常快的代码。事实上,出于兼容性考虑,编译器仅在 x86-64 处理器上使用旧版 SIMD 指令。-mavx-mavx2启用 AVX/AVX-2 指令集,以便生成更快的代码(假设您的计算机支持它)(否则它会崩溃)。-mfma也可能有帮助。-march=native也可用于选择目标平台上可用的最佳指令集。请注意,Numpy 在运行时(部分)执行此检查(感谢 GCC 特定的 C 功能)。

第二个主要问题是,这out = out + x[i]**2会导致编译器在不违反 IEEE-754 标准的情况下无法优化循环携带的依赖链。事实上,需要执行很长的加法链,并且处理器执行该加法的速度不能比用当前代码串行执行每个加法指令更快。问题是两个浮点数相加会产生相当大的延迟(在相当现代的 x86-64 处理器上通常为 3 到 4 个周期)。这意味着处理器无法流水线指令。事实上,现代处理器通常可以并行执行两个加法(每个核心),但电流循环阻止了这一点。最后,这个循环完全受延迟限制。您可以通过手动展开循环来解决此问题。

使用-ffast-math可以帮助编译器进行此类优化,但代价是违反 IEEE-754 标准。如果您使用此选项,则无需使用 NaN 数等特殊值或某些运算。有关更多信息,请阅读gcc 的 ffast-math 实际上做了什么?

此外,请注意,数组副本很昂贵,并且我不确定是否需要所有副本。您可以创建一个新的空数组并填充它,而不是对数组的副本进行操作。这会更快,特别是对于大数组。

最后,分裂速度很慢。请考虑乘以倒数。由于 IEEE-754 标准,编译器无法执行此优化,但您可以轻松执行此优化。话虽这么说,您需要确保这在您的情况下没问题,因为它可能会稍微改变结果。使用-ffast-math也应该自动解决这个问题。

请注意,Numpy 的许多开发人员都知道编译器和处理器的工作原理,因此他们已经进行了手动优化,以生成快速代码(就像我多次所做的那样)。在处理大型数组时,除非合并循环或使用多线程,否则很难击败 Numpy。事实上,与计算单元相比,现在的 RAM 相当慢,并且 Numpy 创建了许多临时数组。Cython 可用于避免创建大多数临时数组。