小编Sim*_*n B的帖子

Jax - 调试 NaN 值

大家晚上好,

我花了过去 6 个小时尝试调试 Jax 中看似随机出现的 NaN 值。我已经缩小范围,NaN 最初源于损失函数或其梯度。

此处提供了重现错误的最小笔记本https://colab.research.google.com/drive/1uXa-igMm9QBOOl8ZNdK1OkwxRFlLqvZD?usp=sharing

作为 Jax 的一个用例,这也可能很有趣。当只有有限数量的陀螺仪/加速度计测量可用时,我使用 Jax 来解决方向估计任务。在这里,四元数运算的有效实现是很好的。

训练循环一开始很好,但最终会出现分歧

Step 0| Loss: 4.550444602966309 | Time: 13.910547971725464s
Step 1| Loss: 4.110116481781006 | Time: 5.478027105331421s
Step 2| Loss: 3.7159230709075928 | Time: 5.476970911026001s
Step 3| Loss: 3.491917371749878 | Time: 5.474078416824341s
Step 4| Loss: 3.232130765914917 | Time: 5.433410406112671s
Step 5| Loss: 3.095140218734741 | Time: 5.433837413787842s
Step 6| Loss: 2.9580295085906982 | Time: 5.429029941558838s
Step 7| Loss: nan | Time: 5.427825689315796s
Step 8| Loss: …
Run Code Online (Sandbox Code Playgroud)

python nan recurrent-neural-network jax

7
推荐指数
1
解决办法
4021
查看次数

根据列添加缺失的行

我给出了以下 df

df = pd.DataFrame(data = {'day': [1, 1, 1, 2, 2, 3], 'pos': 2*[1, 14, 18], 'value': 2*[1, 2, 3]}    
df
Run Code Online (Sandbox Code Playgroud)
    day pos value
0   1   1   1
1   1   14  2
2   1   18  3
3   2   1   1
4   2   14  2
5   3   18  3
Run Code Online (Sandbox Code Playgroud)

我想填写行,以便每天都有列“pos”的所有可能值

想要的结果:

    day pos value
0   1   1   1.0
1   1   14  2.0
2   1   18  3.0
3   2   1   1.0
4   2   14  2.0
5   2   18  NaN
6   3 …
Run Code Online (Sandbox Code Playgroud)

python missing-data pandas

5
推荐指数
1
解决办法
52
查看次数

与 pyTorch 相比,Jax/Flax(非常)慢的 RNN 前向传递?

我最近在 Jax 中实现了一个两层 GRU 网络,对其性能感到失望(无法使用)。

\n

所以,我尝试与 Pytorch 进行一些速度比较。

\n
最小工作示例
\n

这是我的最小工作示例,输出是在 Google Colab 上使用 GPU 运行时创建的。Colab 中的笔记本

\n
import flax.linen as jnn \nimport jax\nimport torch\nimport torch.nn as tnn\nimport numpy as np \nimport jax.numpy as jnp\n\ndef keyGen(seed):\n    key1 = jax.random.PRNGKey(seed)\n    while True:\n        key1, key2 = jax.random.split(key1)\n        yield key2\nkey = keyGen(1)\n\nhidden_size=200\nseq_length = 1000\nin_features = 6\nout_features = 4\nbatch_size = 8\n\nclass RNN_jax(jnn.Module):\n\n    @jnn.compact\n    def __call__(self, x, carry_gru1, carry_gru2):\n        carry_gru1, x = jnn.GRUCell()(carry_gru1, x)\n        carry_gru2, x = jnn.GRUCell()(carry_gru2, x)\n        x = jnn.Dense(4)(x)\n …
Run Code Online (Sandbox Code Playgroud)

python performance recurrent-neural-network pytorch jax

5
推荐指数
1
解决办法
1578
查看次数