use*_*232 5 deep-learning lstm recurrent-neural-network pytorch
我对在 Pytorch 中使用 LSTM 很陌生,我正在尝试创建一个模型,该模型获得一个大小为 42 的张量和一个 62 的序列。(所以每个 62 张量 a 大小为 42)。这意味着我在一个序列中有 62 个张量。每个张量的大小为 42。(形状为 [62,42]。调用此输入张量。
我想用这个来预测一个 1 的张量,序列为 8(所以大小为 1 张量和 8 个序列)。这意味着在大小为 1 的序列中有 8 个张量。称这个标签张量。
这些张量之间的连接是这样的:输入张量由列组成:A1 A2 A3 ...... A42 而标签张量如果更像:A3
我想展示的是,如果需要,标签张量可以在所有地方用零填充,而不是 A3 的值,因此它可以达到 42 的长度。
我怎样才能做到这一点?因为从我从 Pytorch 文档中阅读的内容来看,我只能以相同的比率进行预测(1 点预测为 1),而我想从 42 的张量中进行预测,序列为 62,张量为 1,序列为 8。是可行吗?我是否需要将预测的张量从 1 填充到大小为 42?谢谢!
例如,一个好的解决方案是使用 seq2seq
如果我正确理解你的问题,给定一个长度为 62 的序列,你想预测一个长度为 8 的序列,从某种意义上说,你的输出顺序很重要,如果你正在进行一些时间序列预测,就是这种情况)。在这种情况下,使用 seq2seq 模型将是一个不错的选择,这是此链接的教程。全球范围内,您需要实现一个编码器和一个解码器,下面是这种实现的示例:
class EncoderRNN(nn.Module):
def __init__(self, input_dim=42, hidden_dim=100):
super(EncoderRNN, self).__init__()
self.hidden_size = hidden_size
self.lstm = nn.LSTM(input_dim, hidden_dim)
def forward(self, input, hidden):
output, hidden = self.lstm(input, hidden)
return output, hidden
def initHidden(self):
return torch.zeros(1, 1, self.hidden_size, device=device)
class DecoderRNN(nn.Module):
def __init__(self, hidden_dim, output_dim):
super(DecoderRNN, self).__init__()
self.hidden_dim = hidden_dim
self.lstm = nn.LSTM(hidden_dim, hidden_dim)
self.out = nn.Linear(hidden_dim, output_dim)
self.softmax = nn.LogSoftmax(dim=1)
def forward(self, input, hidden):
output, hidden = self.lstm(input, hidden)
output = self.softmax(self.out(output[0]))
return output, hidden
def initHidden(self):
return torch.zeros(1, 1, self.hidden_size, device=device)
Run Code Online (Sandbox Code Playgroud)
如果 8 个输出的顺序并不重要,那么您可以简单地在 LSTM 层之后添加一个具有 8 个单元的线性层。在这种情况下您可以直接使用此代码
class Net(nn.Module):
def __init__(self, hidden_dim=100, input_dim=42, output_size=8):
super(Net, self).__init__()
self.hidden_dim = hidden_dim
self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True)
# The linear layer that maps from hidden state space to tag space
self.fc = nn.Linear(hidden_dim, output_size_size)
def forward(self, seq):
lstm_out, _ = self.lstm(seq)
output = self.fc(lstm_out)
return output
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
442 次 |
| 最近记录: |