小编Zha*_*ren的帖子

如何在pytorch中使用LSTM进行分类?

我的代码如下:

class Mymodel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_layers, batch_size):
        super(Discriminator, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_layers = num_layers
        self.batch_size = batch_size

        self.lstm = nn.LSTM(input_size, hidden_size)
        self.proj = nn.Linear(hidden_size, output_size)
        self.hidden = self.init_hidden()


    def init_hidden(self):
        return (Variable(torch.zeros(self.num_layers, self.batch_size, self.hidden_size)),
                Variable(torch.zeros(self.num_layers, self.batch_size, self.hidden_size)))

    def forward(self, x):
        lstm_out, self.hidden = self.lstm(x, self.hidden)
        output = self.proj(lstm_out)
        result = F.sigmoid(output)
        return result
Run Code Online (Sandbox Code Playgroud)

我想使用LSTM将句子分类为好(1)或坏(0)。使用此代码,我得到的结果是time_step * batch_size * 1,而不是0或1。如何编辑代码以获得分类结果?

pytorch

3
推荐指数
2
解决办法
8970
查看次数

标签 统计

pytorch ×1