flax (google) 和 dm-haiku (deepmind) 之间的主要区别是什么?

dis*_*ort 13 neural-network deep-learning jax flax dm-haiku

亚麻dm-haiku之间的主要区别是什么?

从他们的描述来看:

  • Flax,JAX 的神经网络库
  • Haiku,受 Sonnet 启发的 JAX 神经网络库

问题

我应该选择哪一个基于 jax 的库来实现,比如说DeepSpeech模型(由 CNN 层 + LSTM 层 + FC 组成)和 ctc-loss?


UPD

找到dm-haiku的开发者关于差异的解释:

Flax 包含更多的电池,并配有优化器、混合精度和一些训练循环(我听说这些是解耦的,你可以根据需要使用尽可能多或尽可能少的量)。Haiku 的目标只是解决 NN 模块和状态管理,它将问题的其他部分留给其他库(例如用于优化的 optax)。

Haiku 被设计为 Sonnet(一个 TF NN 库)到 JAX 的端口。因此,如果(像 DeepMind 一样)您有大量可能想要在 JAX 中使用的 Sonnet+TF 代码,并且您希望尽可能轻松地(在任一方向上)迁移该代码,Haiku 是一个更好的选择。

我认为否则这取决于个人喜好。在 Alphabet 中,每个库都有数百名研究人员使用,所以我认为无论哪种方式都不会出错。在 DeepMind,我们对俳句进行了标准化,因为它对我们来说有意义。我建议查看两个库提供的示例代码,看看哪个符合您构建实验的偏好。我想如果您将来改变主意,您会发现将代码从一个库移动到另一个库并不是很复杂。


原来的问题仍然相关。

Rob*_*bin 7

我最近遇到了同样的问题,我更喜欢 Haiku,因为我认为它们的实现(参见Flax Dense()Haiku Linear())更接近原始的 JAX 精神(即链接initpredict函数并跟踪 Pytree 中的参数)让我更容易修改东西。

但如果你不想深入修改,最好的选择是找到一篇关于 CNN + LSTMs with Flax/Haiku 的好博客文章并坚持下去。我的总体看法是,即使我更喜欢更模块化的 Haiku (+ Optax + Rlax + Chex + ...) 构建方式,这两个库也非常接近。

  • 在我看来,JAX/FLAX/Haiku 比 Tensorflow/Keras 更容易。我不知道 pytorch,但无论如何我不会将它与 FLAX 进行比较。对我来说,第一个选择是 TensorFlow、PyTorch 或 JAX,然后,如果您选择 JAX,第二个选择是 Flax/Trax/Haiku/...(或者 Keras/Theano...如果您选择 TensorFlow)。但这可能是一个过时的观点! (3认同)