AttributeError: 'tuple' object has no attribute 'dim' when feeding input to a PyTorch LSTM network

Jor*_*Dik 5 python tuples torch lstm pytorch

I am trying to run the following code:

import matplotlib.pylab as plt
import numpy as np
import torch
import torch.nn as nn

class LSTM(nn.Module):
    def __init__(self, input_shape, n_actions):
        super(LSTM, self).__init__()

        self.lstm = nn.LSTM(input_shape, 12)
        self.hidden2tag = nn.Linear(12, n_actions)

    def forward(self, x):
        out = self.lstm(x)
        out = self.hidden2tag(out)
        return out


state = [(1,2,3,4,5),(2,3,4,5,6),(3,4,5,6,7),(4,5,6,7,8),(5,6,7,8,9),(6,7,8,9,0)]

device = torch.device("cuda")
net = LSTM(5, 3).to(device)

state_v = torch.FloatTensor(state).to(device)

q_vals_v = net(state_v.view(1, state_v.shape[0], state_v.shape[1]))
_, action = int(torch.max(q_vals_v, dim=1).item())

It fails with this error:

AttributeError: 'tuple' object has no attribute 'dim'

Does anyone know how to fix this? (That is, how to stop the tensor from being a tuple so that it can be fed through the LSTM network.)

blu*_*nox 7

The PyTorch LSTM returns a tuple.
You get this error because your linear layer self.hidden2tag cannot handle that tuple.

So change:

out = self.lstm(x)

to:

out, states = self.lstm(x)

This fixes the error by unpacking the tuple, so that out is just your output tensor.
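As a minimal sketch of that change, here is the model from the question with only the forward method adjusted. It runs on the CPU with a random input tensor, which is an assumption made so the snippet works without a GPU; the layer sizes are the ones from the question.

```python
import torch
import torch.nn as nn

class LSTM(nn.Module):
    def __init__(self, input_shape, n_actions):
        super().__init__()
        self.lstm = nn.LSTM(input_shape, 12)
        self.hidden2tag = nn.Linear(12, n_actions)

    def forward(self, x):
        # nn.LSTM returns (output, (h_n, c_n)); unpack it so that only
        # the output tensor is passed on to the linear layer.
        out, states = self.lstm(x)
        return self.hidden2tag(out)

net = LSTM(5, 3)
x = torch.randn(1, 6, 5)  # same shape as the view() call in the question
q_vals = net(x)
print(q_vals.shape)  # torch.Size([1, 6, 3])
```

The linear layer is applied to the last dimension, so the output keeps the first two dimensions of the input and ends in n_actions.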

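out then holds the hidden states for every time step, while states is another tuple containing the last hidden state and the last cell state.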
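To make the two return values concrete, the sketch below prints their shapes for a single-layer, unidirectional LSTM. It assumes the default batch_first=False layout, where the input is (seq_len, batch, input_size), and uses a sequence of 6 steps with a batch of 1 purely for illustration.

```python
import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=5, hidden_size=12)

# Illustrative input: 6 time steps, batch of 1, 5 features per step.
x = torch.randn(6, 1, 5)

out, (h_n, c_n) = lstm(x)

print(out.shape)  # torch.Size([6, 1, 12]) - hidden state at every time step
print(h_n.shape)  # torch.Size([1, 1, 12]) - hidden state after the last step
print(c_n.shape)  # torch.Size([1, 1, 12]) - cell state after the last step

# For one layer and one direction, the last slice of out is the final hidden state.
print(torch.allclose(out[-1], h_n[0]))  # True
```

Note that with this default layout the first dimension is read as time. The view(1, 6, 5) in the question would therefore be treated as one time step with a batch of 6, unless batch_first=True is passed to nn.LSTM.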

You can also take a look at the documentation: https://pytorch.org/docs/stable/nn.html#torch.nn.LSTM

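You will get another error in the last line, because max() also returns a tuple when called with dim. But that should be easy to fix, and it is a different error :)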
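One possible way to handle that second tuple is sketched below. The q_vals tensor is a made-up two-dimensional example with one row of three action values, not the actual network output from the question.

```python
import torch

# Hypothetical Q-values for a single state: shape (1, n_actions).
q_vals = torch.tensor([[0.1, 0.7, 0.2]])

# torch.max with dim returns (values, indices); unpack first, then convert.
max_vals, max_idx = torch.max(q_vals, dim=1)
action = int(max_idx.item())

print(action)  # 1
```

In the question the network output is three-dimensional, so a single time step or batch entry would have to be selected before .item() can be called. Which one to select depends on the intended use and is not stated in the original post.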