类型错误:view() 最多接受 2 个参数(给出 3 个)

Din*_*yen 4 torch lstm pytorch

我尝试在 pytorch 中使用 view() 但我无法输入 3 个参数。我不知道为什么它一直出现这个错误?谁能帮我这个?

    def forward(self, input):
        lstm_out, self.hidden = self.lstm(input.view(len(input), self.batch_size, -1))
Run Code Online (Sandbox Code Playgroud)

Ser*_*nko 9

看起来你input是一个 numpy 数组,而不是火炬张量。您需要先转换它,例如input = torch.Tensor(input).