小编Nep*_*ner的帖子

如何理解transformer中的masked multi-head attention

我目前正在研究transformer的代码,但我无法理解解码器的屏蔽多头。论文上说是为了不让你看到生成词,但是我无法理解生成词后的词如果没有生成,怎么能看到呢?

我尝试阅读变压器的代码(链接:https : //github.com/Kyubyong/transformer)。代码实现掩码如下所示。它使用下三角矩阵来屏蔽,我不明白为什么。

padding_num = -2 ** 32 + 1
diag_vals = tf.ones_like(inputs[0, :, :])  # (T_q, T_k)
tril = tf.linalg.LinearOperatorLowerTriangular(diag_vals).to_dense()  # (T_q, T_k)
masks = tf.tile(tf.expand_dims(tril, 0), [tf.shape(inputs)[0], 1, 1])  # (N, T_q, T_k)
paddings = tf.ones_like(masks) * padding_num
outputs = tf.where(tf.equal(masks, 0), paddings, inputs)
Run Code Online (Sandbox Code Playgroud)

transformer-model deep-learning tensorflow attention-model

16
推荐指数
1
解决办法
4431
查看次数