adj*_*oun 9 python lstm pytorch tensor
这似乎是 PyTorch 中 LSTM 最常见的问题之一,但我仍然无法弄清楚 PyTorch LSTM 的输入形状应该是什么。
即使遵循了几个帖子(1、2、3)并尝试了解决方案,它似乎也不起作用。
背景:我已经对一批大小为 12 的文本序列(可变长度)进行了编码,并且使用pad_packed_sequence
功能对序列进行了填充和打包。MAX_LEN
对于每个序列是 384,序列中的每个标记(或单词)的维度为 768。因此,我的批处理张量可能具有以下形状之一:[12, 384, 768]
或[384, 12, 768]
。
该批次将是我对 PyTorch rnn 模块(此处为 lstm)的输入。
根据用于PyTorch文档LSTMs,其输入尺寸是(seq_len, batch, input_size)
我的理解如下。
seq_len
- 每个输入流中的时间步数(特征向量长度)。
batch
- 每批输入序列的大小。
input_size
- 每个输入标记或时间步长的维度。
lstm = nn.LSTM(input_size=?, hidden_size=?, batch_first=True)
这里的确切值input_size
和hidden_size
值应该是什么?
Mic*_*ngo 15
您已经解释了输入的结构,但是您还没有在输入维度和 LSTM 的预期输入维度之间建立联系。
让我们分解您的输入(为维度分配名称):
batch_size
: 12seq_len
: 384input_size
/ num_features
: 768这意味着input_size
LSTM 的 需要是 768。
在hidden_size
不依赖于你的投入,而是应该LSTM多少功能创建,然后将其用于隐藏状态以及输出,因为这是最后的隐藏状态。您必须决定要为 LSTM 使用多少特征。
最后,对于输入形状,设置batch_first=True
要求输入具有形状[batch_size, seq_len, input_size]
,在您的情况下为[12, 384, 768]
.
import torch
import torch.nn as nn
# Size: [batch_size, seq_len, input_size]
input = torch.randn(12, 384, 768)
lstm = nn.LSTM(input_size=768, hidden_size=512, batch_first=True)
output, _ = lstm(input)
output.size() # => torch.Size([12, 384, 512])
Run Code Online (Sandbox Code Playgroud)
归档时间: |
|
查看次数: |
6422 次 |
最近记录: |