Pytorch 中的截断时间反向传播 (BPTT)

u2g*_*les 8 truncated backpropagation pytorch

在 pytorch 中,我通过启动反向传播(通过时间)来训练 RNN/GRU/LSTM 网络:

loss.backward()
Run Code Online (Sandbox Code Playgroud)

当序列很长时,我想进行截断的时间反向传播,而不是使用整个序列的正常时间反向传播。

但我在 Pytorch API 中找不到任何参数或函数来设置截断的 BPTT。我错过了吗?我应该在 Pytorch 中自己编写代码吗?

ang*_*ang 1

这是一个例子:

for t in range(T):
   y = lstm(y)
   if T-t == k:
      out.detach()
out.backward()
Run Code Online (Sandbox Code Playgroud)

因此,在此示例中,k是用于控制要展开的时间步长的参数。