Sim*_*n B 5 python performance recurrent-neural-network pytorch jax
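I recently implemented a two-layer GRU network in JAX and was disappointed by its performance (it is unusable).

Here is my minimal working example. The output was produced on Google Colab with a GPU runtime. Notebook in Colab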

    import flax.linen as jnn
    import jax
    import torch
    import torch.nn as tnn
    import numpy as np
    import jax.numpy as jnp

    def keyGen(seed):
        key1 = jax.random.PRNGKey(seed)
        while True:
            key1, key2 = jax.random.split(key1)
            yield key2
    key = keyGen(1)

    hidden_size = 200
    seq_length = 1000
    in_features = 6
    out_features = 4
    batch_size = 8

    class RNN_jax(jnn.Module):

        @jnn.compact
        def __call__(self, x, carry_gru1, carry_gru2):
            carry_gru1, x = jnn.GRUCell()(carry_gru1, x)
            carry_gru2, x = jnn.GRUCell()(carry_gru2, x)
            x = jnn.Dense(4)(x)
            x = x/jnp.linalg.norm(x)
            return x, carry_gru1, carry_gru2

    class RNN_torch(tnn.Module):
        def __init__(self, batch_size, hidden_size, in_features, out_features):
            super().__init__()

            self.gru = tnn.GRU(
                input_size=in_features,
                hidden_size=hidden_size,
                num_layers=2
            )

            self.dense = tnn.Linear(hidden_size, out_features)

            self.init_carry = torch.zeros((2, batch_size, hidden_size))

        def forward(self, X):
            X, final_carry = self.gru(X, self.init_carry)
            X = self.dense(X)
            return X/X.norm(dim=-1).unsqueeze(-1).repeat((1, 1, 4))

    rnn_jax = RNN_jax()
    rnn_torch = RNN_torch(batch_size, hidden_size, in_features, out_features)

    Xj = jax.random.normal(next(key), (seq_length, batch_size, in_features))
    Yj = jax.random.normal(next(key), (seq_length, batch_size, out_features))
    Xt = torch.from_numpy(np.array(Xj))
    Yt = torch.from_numpy(np.array(Yj))

    initial_carry_gru1 = jnp.zeros((batch_size, hidden_size))
    initial_carry_gru2 = jnp.zeros((batch_size, hidden_size))

    params = rnn_jax.init(next(key), Xj[0], initial_carry_gru1, initial_carry_gru2)

    def forward(params, X):

        carry_gru1, carry_gru2 = initial_carry_gru1, initial_carry_gru2

        Yhat = []
        for x in X: # x.shape = (batch_size, in_features)
            yhat, carry_gru1, carry_gru2 = rnn_jax.apply(params, x, carry_gru1, carry_gru2)
            Yhat.append(yhat) # y.shape = (batch_size, out_features)

        #return jnp.concatenate(Y, axis=0)

    jitted_forward = jax.jit(forward)

    # uncompiled jax version
    %time forward(params, Xj)

    CPU times: user 7min 17s, sys: 8.18 s, total: 7min 25s
    Wall time: 7min 17s

    # time for compiling
    %time jitted_forward(params, Xj)

    CPU times: user 8min 9s, sys: 4.46 s, total: 8min 13s
    Wall time: 8min 12s

    # compiled jax version
    %timeit jitted_forward(params, Xj)

    The slowest run took 204.20 times longer than the fastest. This could mean that an intermediate result is being cached.
    10000 loops, best of 5: 115 µs per loop

    # torch version
    %timeit lambda: rnn_torch(Xt)

    10000000 loops, best of 5: 65.7 ns per loop

Why is my JAX implementation so slow? What am I doing wrong?

Also, why does compilation take so long? The sequence is not that long.

Thanks :)
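
The JAX code compiles slowly because JAX unrolls Python loops during JIT compilation. As far as XLA is concerned, your function is therefore very large: you call `rnn_jax.apply()` 1000 times, and compilation time tends to scale roughly quadratically with the number of statements.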
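
The unrolling described above can be observed directly by counting the equations in the traced program. The following is a minimal sketch, not the code from the question: `step` is a hypothetical stand-in for one recurrent cell update, and `num_eqns` is a helper introduced here for illustration. It compares a Python `for` loop with `jax.lax.scan`, which traces the loop body only once.

```python
import jax
import jax.numpy as jnp

def step(carry, x):
    # Hypothetical stand-in for one recurrent cell update (not a real GRU).
    new_carry = jnp.tanh(carry + x)
    return new_carry, new_carry

def loop_forward(xs):
    # Python for loop: each iteration is traced separately (unrolled).
    carry = jnp.zeros(xs.shape[1:])
    for x in xs:
        carry, _ = step(carry, x)
    return carry

def scan_forward(xs):
    # jax.lax.scan: the step function is traced once, whatever the length.
    carry = jnp.zeros(xs.shape[1:])
    carry, _ = jax.lax.scan(step, carry, xs)
    return carry

def num_eqns(fn, seq_length):
    # Number of top-level equations in the traced program (jaxpr).
    xs = jnp.ones((seq_length, 3))
    return len(jax.make_jaxpr(fn)(xs).jaxpr.eqns)

lengths = (10, 20, 40)
loop_sizes = [num_eqns(loop_forward, n) for n in lengths]
scan_sizes = [num_eqns(scan_forward, n) for n in lengths]
print("python loop:", loop_sizes)  # grows with the sequence length
print("lax.scan:   ", scan_sizes)  # stays constant
```

The traced program for the Python loop grows with the sequence length, which is what XLA then has to compile. The scan version hands XLA a program of fixed size.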

By contrast, your PyTorch function does not use a Python loop, so under the hood it relies on vectorized operations that run much faster.
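
Any time you use a Python `for` loop over data, your code will be slow: this is true whether you are using JAX, torch, numpy, pandas, etc. I would suggest finding an approach to the problem in JAX that relies on vectorized operations rather than on slow Python loops. For a recurrence, `jax.lax.scan` is the usual way to avoid unrolling the loop.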
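
One way to follow this advice is to express the time loop with `jax.lax.scan`. The sketch below is an assumption-laden illustration rather than a drop-in replacement for the question's model: it uses a hand-written, simplified GRU cell in plain JAX instead of `flax.linen.GRUCell`, the helpers `init_gru`, `gru_cell`, and `init_params` and the parameter layout are hypothetical, and the output is normalized per row (as in the question's PyTorch model) rather than over the whole array.

```python
import jax
import jax.numpy as jnp

def init_gru(key, in_features, hidden_size):
    # Hypothetical layout: one input matrix and one hidden matrix for all 3 gates.
    k1, k2 = jax.random.split(key)
    s = hidden_size ** -0.5
    return {
        "W": jax.random.uniform(k1, (in_features, 3 * hidden_size), minval=-s, maxval=s),
        "U": jax.random.uniform(k2, (hidden_size, 3 * hidden_size), minval=-s, maxval=s),
        "b": jnp.zeros((3 * hidden_size,)),
    }

def gru_cell(p, h, x):
    # Simplified GRU update: update gate z, reset gate r, candidate state n.
    xz, xr, xn = jnp.split(x @ p["W"] + p["b"], 3, axis=-1)
    hz, hr, hn = jnp.split(h @ p["U"], 3, axis=-1)
    z = jax.nn.sigmoid(xz + hz)
    r = jax.nn.sigmoid(xr + hr)
    n = jnp.tanh(xn + r * hn)
    return (1.0 - z) * n + z * h

def init_params(key, in_features, hidden_size, out_features):
    k1, k2, k3 = jax.random.split(key, 3)
    return {
        "gru1": init_gru(k1, in_features, hidden_size),
        "gru2": init_gru(k2, hidden_size, hidden_size),
        "dense": 0.1 * jax.random.normal(k3, (hidden_size, out_features)),
    }

@jax.jit
def forward(params, X):
    # X.shape = (seq_length, batch_size, in_features)
    batch_size = X.shape[1]
    hidden_size = params["gru1"]["U"].shape[0]
    h0 = jnp.zeros((batch_size, hidden_size))

    def step(carry, x):
        # One time step; traced once by lax.scan instead of seq_length times.
        h1, h2 = carry
        h1 = gru_cell(params["gru1"], h1, x)
        h2 = gru_cell(params["gru2"], h2, h1)
        y = h2 @ params["dense"]
        y = y / jnp.linalg.norm(y, axis=-1, keepdims=True)
        return (h1, h2), y

    _, Yhat = jax.lax.scan(step, (h0, h0), X)
    return Yhat  # shape (seq_length, batch_size, out_features)

params = init_params(jax.random.PRNGKey(0), in_features=6, hidden_size=200, out_features=4)
X = jax.random.normal(jax.random.PRNGKey(1), (1000, 8, 6))
Yhat = forward(params, X)
print(Yhat.shape)  # (1000, 8, 4)
```

Because the step function is traced only once, the size of the compiled program no longer depends on the sequence length. When timing such a function, calling `.block_until_ready()` on the result ensures the measurement includes the actual computation, since JAX dispatches work asynchronously.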