如何理解transformer中的masked multi-head attention

Nep*_*ner 16 transformer-model deep-learning tensorflow attention-model

我目前正在研究transformer的代码,但我无法理解解码器的屏蔽多头。论文上说是为了不让你看到生成词,但是我无法理解生成词后的词如果没有生成,怎么能看到呢?

我尝试阅读变压器的代码(链接:https : //github.com/Kyubyong/transformer)。代码实现掩码如下所示。它使用下三角矩阵来屏蔽,我不明白为什么。

padding_num = -2 ** 32 + 1
diag_vals = tf.ones_like(inputs[0, :, :])  # (T_q, T_k)
tril = tf.linalg.LinearOperatorLowerTriangular(diag_vals).to_dense()  # (T_q, T_k)
masks = tf.tile(tf.expand_dims(tril, 0), [tf.shape(inputs)[0], 1, 1])  # (N, T_q, T_k)
paddings = tf.ones_like(masks) * padding_num
outputs = tf.where(tf.equal(masks, 0), paddings, inputs)
Run Code Online (Sandbox Code Playgroud)

art*_*oby 20

阅读Transformer 论文后,我也有同样的问题。我在互联网上没有找到该问题的完整和详细的答案,因此我将尝试解释我对 Masked Multi-Head Attention 的理解。

简短的回答是 - 我们需要屏蔽以使训练平行。并行化很好,因为它可以让模型训练得更快。

这是一个解释这个想法的例子。假设我们训练将“我爱你”翻译成德语。编码器在并行模式下工作 - 它可以在恒定的步数(即步数不取决于输入序列的长度)内生成输入序列(“我爱你”)的向量表示。

假设编码器产生数字11, 12, 13作为输入序列的向量表示。实际上,这些向量会更长,但为简单起见,我们使用较短的向量。同样为简单起见,我们忽略了服务令牌,例如 - 序列的开头, - 序列的结尾等。

在训练期间我们知道翻译应该是“Ich liebe dich”(我们总是知道训练期间的预期输出)。假设“Ich liebe dich”词的预期向量表示为21, 22, 23

如果我们在序列模式下训练解码器,它看起来就像是循环神经网络的训练。将执行以下顺序步骤:

  • 顺序操作#1。输入:11, 12, 13
    • 试图预测21
    • 预测的输出不会完全是21,假设它会是21.1
  • 顺序操作#2。输入:11, 12, 13,也21.1作为之前的输出。
    • 试图预测22
    • 预测的输出不会完全是22,假设它会是22.3
  • 顺序操作#3。Input 11, 12, 13,也22.3作为之前的输出。
    • 试图预测23
    • 预测的输出不会完全是23,假设它会是23.5

这意味着我们需要进行 3 个顺序操作(在一般情况下 - 每个输入的顺序操作)。此外,我们将在每次下一次迭代中累积错误。此外,我们不使用注意力,因为我们只查看单个先前的输出。

由于我们实际上知道预期的输出,我们可以调整过程并使其并行。无需等待上一步的输出。

  • 并行操作#A。输入:11, 12, 13
    • 试图预测21
  • 并行操作#B。输入:11, 12, 13,还有21
    • 试图预测22
  • 并行操作#C。输入:11, 12, 13,还有21, 22
    • 试图预测23

该算法可以并行执行,并且不会累积错误。并且该算法使用注意力(即查看所有先前的输入),因此在进行预测时要考虑更多关于上下文的信息。

这就是我们需要屏蔽的地方。训练算法知道整个预期输出 ( 21, 22, 23)。它为每个并行操作隐藏(屏蔽)这个已知输出序列的一部分。

  • 当它执行#A - 它隐藏(屏蔽)整个输出。
  • 当它执行#B - 它隐藏第二和第三个输出。
  • 当它执行 #C 时 - 它隐藏了第三个输出。

屏蔽本身的实现如下(来自原始论文):

我们通过屏蔽(设置为 ??) softmax 输入中与非法连接相对应的所有值来在缩放点积注意力中实现这一点

注意:在推理(非训练)期间,解码器在顺序(非并行)模式下工作,因为它最初不知道输出序列。但它与 RNN 方法不同,因为 Transformer 推理仍然使用自注意力并查看所有先前的输出(但不仅仅是前一个)。

注 2:我在一些材料中看到,对于非翻译应用程序可以不同地使用遮罩。例如,对于语言建模,掩码可用于从输入句子中隐藏一些单词,模型将在训练期间尝试使用其他非掩码单词(即学习理解上下文)来预测它们。

  • 我推荐这篇[文章](https://towardsdatascience.com/illusterated-guide-to-transformers-step-by-step-explanation-f74876522bc0),你的解释和文章很有帮助。 (5认同)