How do you use PyTorch PackedSequence in code?

mik*_*305 5 machine-learning torch recurrent-neural-network pytorch

Can someone provide complete working code (not a snippet, but something that actually runs on a variable-length recurrent neural network) showing how you would use the PackedSequence method in PyTorch?

There does not seem to be any such example in the documentation, on GitHub, or anywhere else on the Internet.

https://github.com/pytorch/pytorch/releases/tag/v0.1.10

chi*_*gjn 6

Not the prettiest piece of code, but this is what I gathered for my personal use after going through the PyTorch forums and docs. There are certainly better ways to handle the sort-and-restore part, but I chose to put it inside the network itself:

import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable

class Encoder(nn.Module):
    def __init__(self, vocab_size, embedding_size, embedding_vectors=None, tune_embeddings=True, use_gru=True,
                 hidden_size=128, num_layers=1, bidirectional=True, dropout=0.6):
        super(Encoder, self).__init__()
        self.embed = nn.Embedding(vocab_size, embedding_size, padding_idx=0)
        if embedding_vectors is not None:
            assert embedding_vectors.shape[0] == vocab_size and embedding_vectors.shape[1] == embedding_size
            self.embed.weight = nn.Parameter(torch.FloatTensor(embedding_vectors))
        # Set requires_grad after any weight replacement; assigning a new
        # Parameter above would otherwise silently re-enable gradients.
        self.embed.weight.requires_grad = tune_embeddings
        cell = nn.GRU if use_gru else nn.LSTM
        self.rnn = cell(input_size=embedding_size, hidden_size=hidden_size, num_layers=num_layers,
                        batch_first=True, bidirectional=bidirectional, dropout=dropout)

    def forward(self, x, x_lengths):
        # pack_padded_sequence requires the batch sorted by length, longest first.
        sorted_seq_lens, original_ordering = torch.sort(torch.LongTensor(x_lengths), dim=0, descending=True)
        ex = self.embed(x[original_ordering])
        pack = torch.nn.utils.rnn.pack_padded_sequence(ex, sorted_seq_lens.tolist(), batch_first=True)
        out, _ = self.rnn(pack)
        # Unpack back to a padded (batch, max_len, num_directions * hidden_size) tensor.
        unpacked, unpacked_len = torch.nn.utils.rnn.pad_packed_sequence(out, batch_first=True)
        # Gather the output at each sequence's last valid timestep (length - 1).
        indices = Variable(torch.LongTensor(np.array(unpacked_len) - 1).view(-1, 1)
                                                                       .expand(unpacked.size(0), unpacked.size(2))
                                                                       .unsqueeze(1))
        last_encoded_states = unpacked.gather(dim=1, index=indices).squeeze(dim=1)
        # Scatter the rows back into the original (pre-sort) batch order.
        scatter_indices = Variable(original_ordering.view(-1, 1).expand_as(last_encoded_states))
        encoded_reordered = last_encoded_states.clone().scatter_(dim=0, index=scatter_indices, src=last_encoded_states)
        return encoded_reordered
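For completeness, here is a minimal usage sketch, assuming the Encoder class and imports above; the vocabulary size, token ids, and lengths are made up for illustration. The input is a zero-padded, batch_first LongTensor of token ids plus a plain Python list of the true lengths, given in the original batch order:

encoder = Encoder(vocab_size=100, embedding_size=50, hidden_size=128)

# Three sequences of true lengths 2, 4 and 1, zero-padded to length 4.
# The batch is deliberately NOT sorted by length; forward() sorts it internally.
x = Variable(torch.LongTensor([[9, 2,  0, 0],
                               [4, 7, 12, 3],
                               [5, 0,  0, 0]]))
x_lengths = [2, 4, 1]

encoded = encoder(x, x_lengths)
print(encoded.size())  # torch.Size([3, 256]): 2 * hidden_size, since the RNN is bidirectional

Because both the sorting and the scatter back to the original order happen inside forward(), the caller never has to deal with the length-sorted ordering that pack_padded_sequence demands.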

  • Just realized I answered a question that is 6 months old =( (8 upvotes)