了解 PyTorch LSTM 的输入形状

adj*_*oun 9 python lstm pytorch tensor

这似乎是 PyTorch 中 LSTM 最常见的问题之一,但我仍然无法弄清楚 PyTorch LSTM 的输入形状应该是什么。

即使遵循了几个帖子(123)并尝试了解决方案,它似乎也不起作用。

背景:我已经对一批大小为 12 的文本序列(可变长度)进行了编码,并且使用pad_packed_sequence功能对序列进行了填充和打包。MAX_LEN对于每个序列是 384,序列中的每个标记(或单词)的维度为 768。因此,我的批处理张量可能具有以下形状之一:[12, 384, 768][384, 12, 768]

该批次将是我对 PyTorch rnn 模块(此处为 lstm)的输入。

根据用于PyTorch文档LSTMs,其输入尺寸是(seq_len, batch, input_size)我的理解如下。
seq_len- 每个输入流中的时间步数(特征向量长度)。
batch- 每批输入序列的大小。
input_size- 每个输入标记或时间步长的维度。

lstm = nn.LSTM(input_size=?, hidden_size=?, batch_first=True)

这里的确切值input_sizehidden_size值应该是什么?

Mic*_*ngo 15

您已经解释了输入的结构,但是您还没有在输入维度和 LSTM 的预期输入维度之间建立联系。

让我们分解您的输入(为维度分配名称):

  • batch_size: 12
  • seq_len: 384
  • input_size/ num_features: 768

这意味着input_sizeLSTM 的 需要是 768。

hidden_size不依赖于你的投入,而是应该LSTM多少功能创建,然后将其用于隐藏状态以及输出,因为这是最后的隐藏状态。您必须决定要为 LSTM 使用多少特征。

最后,对于输入形状,设置batch_first=True要求输入具有形状[batch_size, seq_len, input_size],在您的情况下为[12, 384, 768].

import torch
import torch.nn as nn

# Size: [batch_size, seq_len, input_size]
input = torch.randn(12, 384, 768)

lstm = nn.LSTM(input_size=768, hidden_size=512, batch_first=True)

output, _ = lstm(input)
output.size()  # => torch.Size([12, 384, 512])
Run Code Online (Sandbox Code Playgroud)