Batch-wise beam search in PyTorch

Har*_*nan 6 python nlp beam-search deep-learning pytorch

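I am trying to implement a beam search decoding strategy in a text generation model. This is the function I use to decode the output probabilities.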

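import torch


def beam_search_decoder(data, k):
    # data: (seq_len, vocab_size) tensor of probabilities for one sample
    # each entry is [token index sequence, negative log-likelihood]
    sequences = [[list(), 0.0]]
    # walk over each step in sequence
    for row in data:
        all_candidates = list()
        # extend every kept sequence by every vocabulary entry
        for i in range(len(sequences)):
            seq, score = sequences[i]
            for j in range(len(row)):
                candidate = [seq + [j], score - torch.log(row[j])]
                all_candidates.append(candidate)
        # sort candidates by score (lower negative log-likelihood is better)
        ordered = sorted(all_candidates, key=lambda tup: tup[1])
        # keep the k best
        sequences = ordered[:k]
    return sequences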

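As you can see, this function is implemented for a batch size of 1. Adding another loop over the batch would make the algorithm O(n^4), and it is already slow as it is. Is there any way to improve the speed of this function? My model's output usually has size (32, 150, 9907), in the format (batch_size, max_len, vocab_size).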
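To put the cost in perspective, the sketch below counts how many candidate extensions a per-sample, loop-based decoder would score for the shape mentioned above. The beam width `k = 3` is an assumption (the question does not state it), and `count_candidates` is a hypothetical helper, not part of any library.

```python
def count_candidates(batch_size, max_len, vocab_size, k):
    """Count candidates scored by a per-sample, loop-based beam search."""
    # Step 0 expands the single empty sequence; every later step expands
    # k kept sequences (assumes k <= vocab_size, so the beam is full
    # after the first step).
    per_sample = vocab_size + (max_len - 1) * k * vocab_size
    return batch_size * per_sample


# k=3 is an assumed beam width; the shape is the one given in the question.
total = count_candidates(batch_size=32, max_len=150, vocab_size=9907, k=3)
print(f"{total:,} candidates")  # 142,026,752 candidates
```

Each of those candidates costs a Python-level `torch.log` call and a list append, and every step also sorts the full candidate list, which is why moving the work into batched tensor operations is worth trying.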

小智 9

Below is my implementation, which may be a little faster than the for-loop implementation.

import torch


def beam_search_decoder(post, k):
    """Beam Search Decoder

    Parameters:

        post(Tensor): the posterior of network.
        k(int): beam size of decoder.

    Outputs:

        indices(Tensor): a beam of index sequence.
        log_prob(Tensor): a beam of log likelihood of sequence.

    Shape:

        post: (batch_size, seq_length, vocab_size).
        indices: (batch_size, beam_size, seq_length).
        log_prob: (batch_size, beam_size).

    Examples:

        >>> post = torch.softmax(torch.randn([32, 20, 1000]), -1)
        >>> indices, log_prob = beam_search_decoder(post, 3)

    """

    batch_size, seq_length, vocab_size = post.shape
    log_post = post.log()
    # Start each beam from the k most likely first tokens.
    log_prob, indices = log_post[:, 0, :].topk(k, sorted=True)
    indices = indices.unsqueeze(-1)
    for i in range(1, seq_length):
        # Score every (beam, token) extension: (batch_size, k, vocab_size).
        log_prob = log_prob.unsqueeze(-1) + log_post[:, i, :].unsqueeze(1).repeat(1, k, 1)
        log_prob, index = log_prob.view(batch_size, -1).topk(k, sorted=True)
        # topk runs over the flattened (beam, vocab) axis, so split each flat
        # index into the parent beam and the token id.
        beam = torch.div(index, vocab_size, rounding_mode="floor")
        token = index % vocab_size
        # Reorder the stored prefixes so they follow the surviving beams.
        indices = indices.gather(1, beam.unsqueeze(-1).expand(-1, -1, indices.size(-1)))
        indices = torch.cat([indices, token.unsqueeze(-1)], dim=-1)
    return indices, log_prob
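One way to sanity-check a batched decoder of this kind is to compare it against exhaustive enumeration on a tiny problem. The sketch below is self-contained and makes some assumptions: `batched_beam_search` and `brute_force_topk` are hypothetical names, the flat `topk` index is split into a parent beam and a token id, and the per-step posteriors are taken from a fixed `(batch_size, seq_length, vocab_size)` tensor, so they do not depend on previously chosen tokens. Under that last assumption the beam result should match the exact top-k, barring ties.

```python
import itertools

import torch


def batched_beam_search(post, k):
    """Hypothetical helper: beam search over a fixed posterior tensor.

    post: (batch_size, seq_length, vocab_size) probabilities.
    Returns token ids (batch_size, k, seq_length) and log-probs (batch_size, k).
    """
    batch_size, seq_length, vocab_size = post.shape
    log_post = post.log()
    log_prob, tokens = log_post[:, 0, :].topk(k, sorted=True)
    indices = tokens.unsqueeze(-1)
    for i in range(1, seq_length):
        # (batch, k, 1) + (batch, 1, vocab) broadcasts to (batch, k, vocab)
        scores = log_prob.unsqueeze(-1) + log_post[:, i, :].unsqueeze(1)
        log_prob, flat = scores.view(batch_size, -1).topk(k, sorted=True)
        # Split the flat index into parent beam and token id.
        beam = torch.div(flat, vocab_size, rounding_mode="floor")
        token = flat % vocab_size
        # Carry over the prefix of each surviving beam, then append the token.
        prev = indices.gather(1, beam.unsqueeze(-1).expand(-1, -1, indices.size(-1)))
        indices = torch.cat([prev, token.unsqueeze(-1)], dim=-1)
    return indices, log_prob


def brute_force_topk(post, k):
    """Score every possible sequence; only feasible for tiny problems."""
    _, seq_length, vocab_size = post.shape
    log_post = post.log()
    # All vocab_size ** seq_length sequences, shape (num_seqs, seq_length).
    all_seqs = torch.tensor(
        list(itertools.product(range(vocab_size), repeat=seq_length))
    )
    steps = torch.arange(seq_length)
    # Pick log_post[b, t, all_seqs[n, t]] and sum over t: (batch, num_seqs).
    scores = log_post[:, steps, all_seqs].sum(-1)
    best_scores, best = scores.topk(k, sorted=True)
    return all_seqs[best], best_scores


torch.manual_seed(0)
post = torch.softmax(torch.randn(2, 3, 4), dim=-1)  # tiny assumed shape
indices, log_prob = batched_beam_search(post, k=3)
ref_indices, ref_log_prob = brute_force_topk(post, k=3)
print("scores match:", torch.allclose(log_prob, ref_log_prob, atol=1e-5))
print("sequences match:", torch.equal(indices, ref_indices))
```

If the model's next-step distribution depended on the tokens already chosen (as in autoregressive decoding), the posteriors would have to be recomputed per beam at each step, and this exactness check would no longer apply.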