Sim*_*n B 7 python nan recurrent-neural-network jax
大家晚上好,
我花了过去 6 个小时尝试调试 Jax 中看似随机出现的 NaN 值。我已经缩小范围,NaN 最初源于损失函数或其梯度。
此处提供了重现错误的最小笔记本https://colab.research.google.com/drive/1uXa-igMm9QBOOl8ZNdK1OkwxRFlLqvZD?usp=sharing
作为 Jax 的一个用例,这也可能很有趣。当只有有限数量的陀螺仪/加速度计测量可用时,我使用 Jax 来解决方向估计任务。在这里,四元数运算的有效实现是很好的。
训练循环一开始很好,但最终会出现分歧
Step 0| Loss: 4.550444602966309 | Time: 13.910547971725464s
Step 1| Loss: 4.110116481781006 | Time: 5.478027105331421s
Step 2| Loss: 3.7159230709075928 | Time: 5.476970911026001s
Step 3| Loss: 3.491917371749878 | Time: 5.474078416824341s
Step 4| Loss: 3.232130765914917 | Time: 5.433410406112671s
Step 5| Loss: 3.095140218734741 | Time: 5.433837413787842s
Step 6| Loss: 2.9580295085906982 | Time: 5.429029941558838s
Step 7| Loss: nan | Time: 5.427825689315796s
Step 8| Loss: nan | Time: 5.463077545166016s
Step 9| Loss: nan | Time: 5.479652643203735s
Run Code Online (Sandbox Code Playgroud)
这可以通过梯度发散来追溯到,如以下代码片段所示
(loss, _), grads = loss_fn(params, X[0], y[0], rnn.reset_carry(bs=2))
grads["params"]["Dense_0"]["bias"] # shape=(bs, out_features)
DeviceArray([[-0.38666773, nan, -1.0433975 , nan],
[ 0.623061 , -0.20950513, 0.8459796 , -0.42356613]], dtype=float32)
Run Code Online (Sandbox Code Playgroud)
启用 nan-debugging 并没有真正帮助,因为它最终只会导致带有许多隐藏痕迹的巨大堆栈跟踪..
from jax.config import config
config.update("jax_debug_nans", True)
Run Code Online (Sandbox Code Playgroud)
任何帮助将非常感激!谢谢 :)
一些方法(在主要文档中得到了很好的记录)可能有效:
float64可以解决问题。更多信息在这里:jax.config.update("jax_enable_x64", True)div令牌表示的分区:from jax import make_jaxpr
# If grad_fn(x) gives you trouble, you can inspect the computation as follows:
grad_fn = jit(value_and_grad(my_forward_prop, argnums=0))
make_jaxpr(grad_fn)(x)
Run Code Online (Sandbox Code Playgroud)
请注意,社区非常活跃,并且已经并正在添加一些支持来诊断NaNs:
希望这可以帮助!
安德烈斯
| 归档时间: |
|
| 查看次数: |
4021 次 |
| 最近记录: |