为什么 PyTorch C++ 扩展比其等效的 numba 版本慢得多?

blu*_*e10 10 c++ numba torch pytorch

我一直在尝试各种选项来加速 PyTorch 中的一些 for 循环逻辑。执行此操作的两个明显选项是使用numba编写自定义 C++ 扩展

\n

作为一个例子,我从数字信号处理中选择了“可变长度延迟线”。使用简单的 Python for 循环可以简单但低效地编写此代码:

\n
def delay_line(samples, delays):\n    """\n    :param samples: Float tensor of shape (N,)\n    :param delays: Int tensor of shape (N,)\n    \n    The goal is basically to mix each `samples[i]` with the delayed sample\n    specified by a per-sample `delays[i]`.\n    """\n    for i in range(len(samples)):\n        delay = int(delays[i].item())\n        index_delayed = i - delay\n        if index_delayed < 0:\n            index_delayed = 0\n\n        samples[i] = 0.5 * (samples[i] + samples[index_delayed])\n
Run Code Online (Sandbox Code Playgroud)\n

知道 for 循环在 Python 中的执行情况有多糟糕,我希望通过在 C++ 中实现相同的循环可以获得明显更好的性能。按照教程,我想出了从 Python 到 C++ 的直译:

\n
void delay_line(torch::Tensor samples, torch::Tensor delays) {\n\n  int64_t input_size = samples.size(-1);\n\n  for (int64_t i = 0; i < input_size; ++i) {\n    int64_t delay = delays[i].item<int64_t>();\n    int64_t index_delayed = i - delay;\n    if (index_delayed < 0) {\n      index_delayed = 0;\n    }\n\n    samples[i] = 0.5 * (samples[i] + samples[index_delayed]);\n  }\n}\n
Run Code Online (Sandbox Code Playgroud)\n

我还采用了 Python 函数并将其包装到各种 jit 装饰器中,以获得该函数的 numba 和 torchscript 版本(有关 numba 包装的详细信息,请参阅我的其他答案)。然后,我对所有版本执行了基准测试,这还取决于张量是驻留在 CPU 还是 GPU 上。结果相当令人惊讶:

\n
\xe2\x95\xad\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\xac\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\xac\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x95\xae\n\xe2\x94\x82 Method       \xe2\x94\x82 Device   \xe2\x94\x82   Median time [ms] \xe2\x94\x82\n\xe2\x94\x9c\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\xbc\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\xbc\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\xa4\n\xe2\x94\x82 plain_python \xe2\x94\x82 CPU      \xe2\x94\x82             13.481 \xe2\x94\x82\n\xe2\x94\x82 torchscript  \xe2\x94\x82 CPU      \xe2\x94\x82              6.318 \xe2\x94\x82\n\xe2\x94\x82 numba        \xe2\x94\x82 CPU      \xe2\x94\x82              0.016 \xe2\x94\x82\n\xe2\x94\x82 cpp          \xe2\x94\x82 CPU      \xe2\x94\x82              9.056 \xe2\x94\x82\n\xe2\x94\x82 plain_python \xe2\x94\x82 GPU      \xe2\x94\x82             45.412 \xe2\x94\x82\n\xe2\x94\x82 torchscript  \xe2\x94\x82 GPU      \xe2\x94\x82             47.809 \xe2\x94\x82\n\xe2\x94\x82 numba        \xe2\x94\x82 GPU      \xe2\x94\x82              0.236 \xe2\x94\x82\n\xe2\x94\x82 cpp          \xe2\x94\x82 GPU      \xe2\x94\x82             31.145 \xe2\x94\x82\n\xe2\x95\xb0\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\xb4\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\xb4\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x95\xaf\n
Run Code Online (Sandbox Code Playgroud)\n

注:样本缓冲区大小固定为1024;结果是 100 次执行的中位数,以忽略初始 jit 开销中的工件;输入数据创建并将其移动到设备不包括在测量范围内;完整的基准测试脚本要点

\n

最显着的结果是:C++ 变体似乎出奇地慢。numba 快两个数量级的事实表明问题确实可以更快地解决。事实上,C++ 变体仍然非常接近众所周知的缓慢的 Python for 循环,这可能表明有些事情不太正确。

\n

我想知道什么可以解释 C++ 扩展的糟糕性能。第一个想到的就是缺少优化。不过,我已经确保编译使用了优化。从 切换-O2-O3也没有什么区别。

\n

为了隔离 pybind11 函数调用的开销,我用空函数体替换了 C++ 函数,即不执行任何操作。这将时间减少到 2-3 \xce\xbcs,这意味着时间确实花费在该特定函数体上。

\n

有什么想法为什么我观察到如此差的性能吗?我可以在 C++ 方面做些什么来匹配 numba 实现的性能吗?

\n

额外问题:GPU 版本是否会比 CPU 版本慢很多?

\n