Pytorch 相当于“tf.reverse_sequence”?

Jin*_*ich 3 python tensorflow pytorch

我想在填充序列上进行反向 LSTM,这需要在没有填充的情况下反转输入序列。

对于这样的批次(其中_代表填充):

a b c _ _ _
d e f g _ _
h i j k l m
Run Code Online (Sandbox Code Playgroud)

如果想得到:

c b a _ _ _
g f e d _ _
m l k j i h
Run Code Online (Sandbox Code Playgroud)

TensorFlow 有一个函数tf.reverse_sequence,它获取输入张量和批次中序列的长度并返回反向批次。在 Pytorch 中是否有一种简单的方法可以做到这一点?

den*_*ger 5

不幸的是,尽管已经提出了要求,但还没有直接的等价物。

我还查看了整个PackedSequence对象,但它没有.flip()定义任何操作。假设您已经有必要的数据来提供长度,正如您所建议的,您可以使用此函数来实现它:

def flipBatch(data, lengths):
    assert data.shape[0] == len(lengths), "Dimension Mismatch!"
    for i in range(data.shape[0]):
        data[i,:lengths[i]] = data[i,:lengths[i]].flip(dims=[0])

    return data
Run Code Online (Sandbox Code Playgroud)

不幸的是,这仅在您的序列是二维的(带有batch_size x sequence)时才有效,但您可以轻松地扩展它以满足您的特定输入要求。这已经或多或少涵盖了上面链接中的提案,但我将其更新为今天的标准。