NumPyro vs Pyro:为什么前者快 100 倍,我什么时候应该使用后者?

GuS*_*uku 5 pytorch pyro.ai probabilistic-programming tensorflow-probability numpyro

来自 Pytorch-Pyro 的网站

我们很高兴地宣布发布 NumPyro,这是一个 NumPy 支持的 Pyro,使用 JAX 进行自动微分和 JIT 编译,HMC 和 NUTS 的速度提高了 100 倍以上!

我的问题:

  1. NumPyro(超过 Pyro)的性能提升(有时是340 倍或 2 倍)究竟来自哪里?
  2. 更重要的是,为什么(而不是,在哪里)我会继续使用 Pyro?

额外的:

  1. 与 Tensorflow Probability 相比,我应该如何查看 NumPyro 的性能和特性,以决定在何处使用哪个?

Car*_*uza 6

这是个好问题。我刚刚在Pyro 的专用论坛上问了同样的问题。这是他们的一位核心开发人员的回答:“Pyro 中有很多很酷的东西没有出现在 NumPyro 中,例如,请参阅 Pyro 文档中的贡献代码部分。对我来说,在开发时,调试 PyTorch 代码要容易得多比 Jax 代码(尽管 Jax 团队在最近的版本中付出了很大的努力来帮助调试)。因此,要实现新的推理算法,我在 Pyro 中工作更容易。”

  • Jax 似乎是目前最轻的自动微分库。因此,这些改进大部分可能集中在 HMC/NUTS 上,而不是变分推理上。相反,PyTorch 构建的计算图对于 HMC/NUTS 效率来说可能有点臃肿,但有可能提供改进的 SVI 功能。因此,NumPyro 可能是传统贝叶斯统计的理想选择,而 Pyro 可能是贝叶斯机器学习、贝叶斯神经网络等的理想选择。此外,创新这些模型的研究人员往往拥有深度学习经验,因此 PyTorch 不一定是探索该库的障碍。 (4认同)