Jax/Flax: (very) slow RNN forward pass compared to pyTorch?

Sim*_*n B 5 python performance recurrent-neural-network pytorch jax

I recently implemented a two-layer GRU network in Jax and was disappointed by its performance (it was unusable).


So I ran some speed comparisons against PyTorch.

Minimal working example

Here is my minimal working example. The outputs were produced on Google Colab with a GPU runtime. Notebook in Colab

```python
import flax.linen as jnn
import jax
import torch
import torch.nn as tnn
import numpy as np
import jax.numpy as jnp

def keyGen(seed):
    key1 = jax.random.PRNGKey(seed)
    while True:
        key1, key2 = jax.random.split(key1)
        yield key2
key = keyGen(1)

hidden_size = 200
seq_length = 1000
in_features = 6
out_features = 4
batch_size = 8

class RNN_jax(jnn.Module):

    @jnn.compact
    def __call__(self, x, carry_gru1, carry_gru2):
        # one time step: two stacked GRU cells followed by a dense head
        carry_gru1, x = jnn.GRUCell()(carry_gru1, x)
        carry_gru2, x = jnn.GRUCell()(carry_gru2, x)
        x = jnn.Dense(4)(x)
        x = x/jnp.linalg.norm(x)
        return x, carry_gru1, carry_gru2

class RNN_torch(tnn.Module):
    def __init__(self, batch_size, hidden_size, in_features, out_features):
        super().__init__()

        self.gru = tnn.GRU(
            input_size=in_features,
            hidden_size=hidden_size,
            num_layers=2
            )

        self.dense = tnn.Linear(hidden_size, out_features)

        self.init_carry = torch.zeros((2, batch_size, hidden_size))

    def forward(self, X):
        # tnn.GRU consumes the whole (seq_length, batch, features) sequence in one call
        X, final_carry = self.gru(X, self.init_carry)
        X = self.dense(X)
        return X/X.norm(dim=-1).unsqueeze(-1).repeat((1, 1, 4))

rnn_jax = RNN_jax()
rnn_torch = RNN_torch(batch_size, hidden_size, in_features, out_features)

Xj = jax.random.normal(next(key), (seq_length, batch_size, in_features))
Yj = jax.random.normal(next(key), (seq_length, batch_size, out_features))
Xt = torch.from_numpy(np.array(Xj))
Yt = torch.from_numpy(np.array(Yj))

initial_carry_gru1 = jnp.zeros((batch_size, hidden_size))
initial_carry_gru2 = jnp.zeros((batch_size, hidden_size))

params = rnn_jax.init(next(key), Xj[0], initial_carry_gru1, initial_carry_gru2)

def forward(params, X):

    carry_gru1, carry_gru2 = initial_carry_gru1, initial_carry_gru2

    Yhat = []
    for x in X: # x.shape = (batch_size, in_features)
        yhat, carry_gru1, carry_gru2 = rnn_jax.apply(params, x, carry_gru1, carry_gru2)
        Yhat.append(yhat) # yhat.shape = (batch_size, out_features)

    # NOTE: the return below is commented out, so forward() returns None
    #return jnp.concatenate(Y, axis=0)

jitted_forward = jax.jit(forward)
```

Results

```python
# uncompiled jax version
%time forward(params, Xj)
```

CPU times: user 7min 17s, sys: 8.18 s, total: 7min 25s Wall time: 7min 17s

```python
# time for compiling
%time jitted_forward(params, Xj)
```

CPU times: user 8min 9s, sys: 4.46 s, total: 8min 13s Wall time: 8min 12s

```python
# compiled jax version
%timeit jitted_forward(params, Xj)
```

The slowest run took 204.20 times longer than the fastest. This could mean that an intermediate result is being cached. 10000 loops, best of 5: 115 µs per loop

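```python
# torch version
# NOTE: `%timeit lambda: ...` times only the creation of the lambda object;
# rnn_torch(Xt) itself is never called by this statement
%timeit lambda: rnn_torch(Xt)
```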

10000000 loops, best of 5: 65.7 ns per loop

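For comparisons like the ones above, JAX timings are only meaningful if they wait for the result: JAX dispatches work asynchronously, so a call can return before the device has finished computing. A minimal timing sketch, assuming a hypothetical helper named `bench`, that reports the first call (tracing plus XLA compilation for a jitted function) separately from later calls. On the PyTorch GPU side, the analogous synchronization step is `torch.cuda.synchronize()`.

```python
import time

import jax
import jax.numpy as jnp

def bench(fn, *args, n_runs=10):
    """Hypothetical helper: (seconds for first call, mean seconds per later call)."""
    # First call: for a jitted fn this includes tracing and XLA compilation.
    t0 = time.perf_counter()
    jax.block_until_ready(fn(*args))
    first = time.perf_counter() - t0

    # Later calls: block on each result, because JAX dispatches work
    # asynchronously and a call can return before the computation finishes.
    t0 = time.perf_counter()
    for _ in range(n_runs):
        jax.block_until_ready(fn(*args))
    per_call = (time.perf_counter() - t0) / n_runs
    return first, per_call

# Usage on a small jitted function.
f = jax.jit(lambda x: jnp.tanh(x @ x.T).sum())
x = jnp.ones((256, 256))
first, per_call = bench(f, x)
print(f"first call: {first:.4f} s, later calls: {per_call * 1e6:.1f} us each")
```

Keeping the first call separate makes the compile-time problem discussed in this question visible on its own instead of folding it into the per-call average.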
Question

Why is my Jax implementation so slow? What am I doing wrong?


Also, why does compilation take so long? The sequence isn't that long...


Thanks :)


jak*_*vdp 1

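The reason your JAX code compiles slowly is that JAX unrolls loops during JIT compilation. So as far as XLA compilation is concerned, your function is actually very large: you call `rnn_jax.apply()` 1000 times, and compilation time tends to be roughly quadratic in the number of statements.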
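The unrolling can be observed directly by counting the primitive operations JAX traces. A minimal sketch, assuming a toy `step` function as a hypothetical stand-in for `rnn_jax.apply` (it is not a GRU): with a Python `for` loop the traced program grows with the sequence length, whereas `jax.lax.scan` (a JAX primitive the answer does not itself mention) traces the loop body only once.

```python
import jax
import jax.numpy as jnp

def step(h, x):
    # Hypothetical toy recurrence standing in for one call to rnn_jax.apply.
    h = jnp.tanh(h + x)
    return h, h

def unrolled(xs):
    # Python for-loop: every iteration is traced, so the program grows with len(xs).
    h = jnp.zeros(xs.shape[1:])
    ys = []
    for x in xs:
        h, y = step(h, x)
        ys.append(y)
    return jnp.stack(ys)

def scanned(xs):
    # lax.scan: `step` is traced once and compiled as a single loop body.
    h0 = jnp.zeros(xs.shape[1:])
    _, ys = jax.lax.scan(step, h0, xs)
    return ys

def n_eqns(f, xs):
    # Number of top-level primitive operations in the traced program.
    return len(jax.make_jaxpr(f)(xs).jaxpr.eqns)

for T in (10, 100):
    xs = jnp.ones((T, 8, 4))
    print(f"T={T}: unrolled={n_eqns(unrolled, xs)} eqns, scanned={n_eqns(scanned, xs)} eqns")
```

Since compilation cost grows with program size, the unrolled version gets more expensive to compile as the sequence gets longer, while the scanned version stays the same size.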

By contrast, your pytorch function doesn't use a Python loop, so under the hood it relies on vectorized operations that run much faster.

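Any time you use a `for` loop over data in Python, your code will be slow: this is true whether you're using JAX, torch, numpy, pandas, etc. I'd suggest finding an approach to the problem in JAX that relies on vectorized operations rather than on slow Python loops.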
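A recurrence cannot be vectorized along the time axis, but one common way to follow this advice is to keep each step batched and hand the time loop to `jax.lax.scan`, so XLA compiles one loop body instead of 1000 copies. The sketch below swaps Flax's `GRUCell` for a hand-written, PyTorch-style GRU cell to avoid depending on a particular Flax version; the parameter layout (`Wi`, `Wh`, `bi`, `bh`), the initialization, and the helper names `init_gru` and `gru_cell` are hypothetical. Each output row is normalized on its own, as in the torch version of the question.

```python
import jax
import jax.numpy as jnp

def init_gru(key, in_features, hidden):
    # Hypothetical layout: reset (r), update (z) and candidate (n) blocks
    # stored side by side in one matrix per input / hidden path.
    k1, k2 = jax.random.split(key)
    scale = hidden ** -0.5
    return {
        "Wi": jax.random.normal(k1, (in_features, 3 * hidden)) * scale,
        "Wh": jax.random.normal(k2, (hidden, 3 * hidden)) * scale,
        "bi": jnp.zeros(3 * hidden),
        "bh": jnp.zeros(3 * hidden),
    }

def gru_cell(p, h, x):
    # PyTorch-style GRU update for a whole batch: x (batch, in), h (batch, hidden).
    i_r, i_z, i_n = jnp.split(x @ p["Wi"] + p["bi"], 3, axis=-1)
    h_r, h_z, h_n = jnp.split(h @ p["Wh"] + p["bh"], 3, axis=-1)
    r = jax.nn.sigmoid(i_r + h_r)
    z = jax.nn.sigmoid(i_z + h_z)
    n = jnp.tanh(i_n + r * h_n)
    return (1.0 - z) * n + z * h

@jax.jit
def forward(params, X):
    # X: (seq_length, batch_size, in_features), time-major as in the question.
    batch, hidden = X.shape[1], params["gru1"]["Wh"].shape[0]
    carry0 = (jnp.zeros((batch, hidden)), jnp.zeros((batch, hidden)))

    def step(carry, x):
        # One time step of both GRU layers plus the dense head.
        h1, h2 = carry
        h1 = gru_cell(params["gru1"], h1, x)
        h2 = gru_cell(params["gru2"], h2, h1)
        y = h2 @ params["dense"]["W"] + params["dense"]["b"]
        y = y / jnp.linalg.norm(y, axis=-1, keepdims=True)  # per-row normalization
        return (h1, h2), y

    # scan traces `step` once and loops over the leading (time) axis of X.
    _, Y = jax.lax.scan(step, carry0, X)
    return Y  # (seq_length, batch_size, out_features)

# Usage with the sizes from the question.
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
hidden_size, in_features, out_features = 200, 6, 4
params = {
    "gru1": init_gru(k1, in_features, hidden_size),
    "gru2": init_gru(k2, hidden_size, hidden_size),
    "dense": {"W": jax.random.normal(k3, (hidden_size, out_features)) * 0.1,
              "b": jnp.zeros(out_features)},
}
X = jax.random.normal(k4, (1000, 8, in_features))
Y = forward(params, X).block_until_ready()
print(Y.shape)  # (1000, 8, 4)
```

Because the loop body is compiled once, compilation time is expected to stay roughly flat as `seq_length` grows, unlike the unrolled Python loop.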

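  • Oh wow. I think I get it. Any additional dimension in the X you pass to rnn_jax.apply(X) is somehow reduced as if it were the sequence dimension, just like in pyTorch. I don't see how I could have learned that from the documentation. If you're curious, I may redo the speed test this afternoon and update the results. (2 upvotes)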