src_mask 和 src_key_padding_mask 的区别

Leo*_*Leo 6 transformer-model pytorch

我在理解变压器方面遇到了困难。一切都在一点一点地变得清晰,但让我头疼的一件事是 src_mask 和 src_key_padding_mask 之间的区别是什么,它在编码器层和解码器层的前向函数中作为参数传递。

https://pytorch.org/docs/master/_modules/torch/nn/modules/transformer.html#Transformer

Cha*_*ker 37

src_mask 和 src_key_padding_mask 之间的区别

\n

一般来说,要注意张量_mask与的使用之间的差异_key_padding_mask。\n在变压器内部,当注意力集中时,我们通常会得到一个平方中间张量,其中包含所有大小的比较\n [Tx, Tx](对于编码器的输入),[Ty, Ty](对于移位输出 - 解码器的输入之一)\n和[Ty, Tx](对于内存掩码 - 编码器/内存的输出和解码器/移位输出的输入之间的注意力)。

\n

所以我们知道这是转换器中每个掩码的用途\n(注意pytorch文档中的符号如下,其中Tx=S is the source sequence length\n(例如输入批次的最大值),\n Ty=T is the target sequence length(例如目标长度的最大值), \n B=N is the batch size,\n D=E is the feature number):

\n
    \n
  1. src_mask [Tx, Tx] = [S, S]\xe2\x80\x93 src 序列的附加掩码(可选)。\n这在执行atten_src + src_mask. 我不确定示例输入 - 请参阅 tgt_mask 示例\n但典型用途是添加,-inf以便可以根据需要以这种方式屏蔽 src_attention。\n如果提供了 ByteTensor,则不允许使用非零位置参加,零位置不变。\n如果提供了 BoolTensor,则 True 的位置不允许参加,而 False 值将保持不变。\n如果提供了 FloatTensor,它将添加到注意力权重中。

    \n
  2. \n
  3. tgt_mask [Ty, Ty] = [T, T]\xe2\x80\x93 tgt 序列的附加掩码(可选)。\n这在执行 时应用atten_tgt + tgt_mask。一个示例使用是对角线,以避免解码器作弊。\n因此 tgt 右移,第一个标记是嵌入 SOS/BOS 的序列标记的开始,因此第一个\n条目为零,而其余条目为零。具体示例见附录。\n如果提供了 ByteTensor,则非零位置不允许出现,零位置不变。\n如果提供了 BoolTensor,则 True 值不允许出现,False 值则不允许出现将保持不变。\n如果提供了 FloatTensor,它将添加到注意力权重中。

    \n
  4. \n
  5. memory_mask [Ty, Tx] = [T, S]\xe2\x80\x93 编码器输出的附加掩码(可选)。\n这在执行 时应用。\natten_memory + memory_mask不确定示例用途,但如前所述,添加将-inf一些注意力权重设置为零。\n如果 ByteTensor提供时,非零位置不允许出现,零位置不变。\n如果提供了 BoolTensor,则 True 位置不允许出现,而 False 值不变。\n如果提供了 FloatTensor,它将被添加到注意力权重中。

    \n
  6. \n
  7. src_key_padding_mask [B, Tx] = [N, S]\xe2\x80\x93 每批 src 键的 ByteTensor 掩码(可选)。\n由于您的 src 通常具有不同长度的序列,因此通常会删除您在末尾附加的\n填充向量。\n为此您指定批次中每个示例的每个序列的长度。\n请参阅附录中的具体示例。\n如果提供了 ByteTensor,则不允许出现非零位置,而零位置将保持不变。\n如果提供了 BoolTensor, True 的位置不允许参加,而 False 值将保持不变。\n如果提供了 FloatTensor,它将添加到注意力权重中。

    \n
  8. \n
  9. tgt_key_padding_mask [B, Ty] = [N, t]\xe2\x80\x93 每批 tgt 键的 ByteTensor 掩码(可选)。\n与之前相同。\n请参阅附录中的具体示例。\n如果提供了 ByteTensor,则不允许出现非零位置,而零位置将保持不变。\n如果提供了 BoolTensor,则不允许参与 True 值,而 False 值将保持不变。\n如果提供了 FloatTensor,它将添加到注意力权重中。

    \n
  10. \n
  11. memory_key_padding_mask [B, Tx] = [N, S]\xe2\x80\x93 每批内存键的 ByteTensor 掩码(可选)。\n与之前相同。\n请参阅附录中的具体示例。\n如果提供了 ByteTensor,则不允许出现非零位置,而零位置将保持不变。\n如果提供了 BoolTensor,则不允许参与 True 值,而 False 值将保持不变。\n如果提供了 FloatTensor,它将添加到注意力权重中。

    \n
  12. \n
\n

附录

\n

pytorch 教程的示例(https://pytorch.org/tutorials/beginner/translation_transformer.html):

\n

1 src_mask 示例

\n
    src_mask = torch.zeros((src_seq_len, src_seq_len), device=DEVICE).type(torch.bool)\n
Run Code Online (Sandbox Code Playgroud)\n

返回大小为布尔值的张量[Tx, Tx]

\n
tensor([[False, False, False,  ..., False, False, False],\n         ...,\n        [False, False, False,  ..., False, False, False]])\n
Run Code Online (Sandbox Code Playgroud)\n

2 tgt_mask示例

\n
    mask = (torch.triu(torch.ones((sz, sz), device=DEVICE)) == 1)\n    mask = mask.transpose(0, 1).float()\n    mask = mask.masked_fill(mask == 0, float(\'-inf\'))\n    mask = mask.masked_fill(mask == 1, float(0.0))\n
Run Code Online (Sandbox Code Playgroud)\n

生成解码器输入的右移输出的对角线。

\n
tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf,\n         -inf, -inf, -inf],\n        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf,\n         -inf, -inf, -inf],\n        [0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf,\n         -inf, -inf, -inf],\n         ...,\n        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n         0., 0., -inf],\n        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n         0., 0., 0.]])\n
Run Code Online (Sandbox Code Playgroud)\n

通常,右移输出的开头有 BOS/SOS,教程只需在前面附加 BOS/SOS,然后用 修剪最后一个元素即可获得右移tgt_input = tgt[:-1, :]

\n

3 _填充

\n

填充只是为了掩盖末尾的填充。\nsrc 填充通常与内存填充相同。\ntgt 有它自己的序列,因此它有自己的填充。\n示例:

\n
    src_padding_mask = (src == PAD_IDX).transpose(0, 1)\n    tgt_padding_mask = (tgt == PAD_IDX).transpose(0, 1)\n    memory_padding_mask = src_padding_mask\n
Run Code Online (Sandbox Code Playgroud)\n

输出:

\n
tensor([[False, False, False,  ...,  True,  True,  True],\n        ...,\n        [False, False, False,  ...,  True,  True,  True]])\n
Run Code Online (Sandbox Code Playgroud)\n

请注意, aFalse表示那里没有填充令牌(因此是的,在变压器前向传递中使用该值),而 aTrue表示存在填充令牌(因此将其屏蔽,以便变压器前向传递不会受到影响)。

\n
\n

答案有点分散,但我发现只有这 3 个参考资料有用\n(诚实地说,单独的层文档/东西并不是很有用):

\n\n


San*_*jay 9

举一个小例子,考虑我想构建一个顺序推荐器,即给定用户在时间“t”之前购买的商品,预测“t+1”时的下一个商品

u1 - [i1, i2, i7]
u2 - [i2, i5]
u3 - [i6, i7, i1, i2]
Run Code Online (Sandbox Code Playgroud)

对于此任务,我可以使用转换器,通过在左侧填充 0 来使序列长度相等。

u1 - [0,  i1, i2, i7]
u2 - [0,  0,  i2, i5]
u3 - [i6, i7, i1, i2]
Run Code Online (Sandbox Code Playgroud)

我将使用 key_padding_mask 告诉 PyTorch 0 的 shd 被忽略。现在,考虑u3给定[i6]我想要预测的用户[i7],以及稍后给定[i6, i7]我想要预测的[i1]用户,即我想要因果注意力,这样注意力就不会窥视未来的元素。为此,我将使用 attn_mask。因此对于用户u3attn_mask 来说就像

[[True, False, False, False],
 [True, True , False, False],
 [True, True , True , False]
 [True, True , True , True ]]
Run Code Online (Sandbox Code Playgroud)


Was*_*mad 7

我必须说 PyTorch 实现有点令人困惑,因为它包含太多掩码参数。但我可以阐明您所指的两个掩模参数。该机制中同时使用了src_mask和。根据MultiheadAttention的文档:src_key_padding_maskMultiheadAttention

\n
\n

key_padding_mask \xe2\x80\x93 如果提供,则注意力将忽略键中指定的填充元素。

\n

attn_mask \xe2\x80\x93 防止关注某些位置的 2D 或 3D 掩码。

\n
\n

从论文中你知道,Attention就是你所需要的,MultiheadAttention在Encoder和Decoder中都使用了。然而,在Decoder中,MultiheadAttention有两种类型。一种是所谓的Masked MultiheadAttention,另一种是常规的MultiheadAttention。为了适应这两种技术,PyTorch 在其 MultiheadAttention 实现中使用了上述两个参数。

\n

所以,长话短说——

\n
    \n
  • attn_maskkey_padding_mask用于编码器MultiheadAttention和解码器Masked MultiheadAttention
  • \n
  • memory_mask 正如此处MultiheadAttention指出的,在解码器机制中使用。
  • \n
\n

研究MultiheadAttention的实现可能会对您有所帮助。

\n

正如您从此处此处看到的,首先src_mask用于阻止特定位置参加,然后key_padding_mask用于阻止参加填充令牌。

\n

笔记。答案根据@michael-jungo\'s 评论更新。

\n

  • *长话短说*下的两点是不正确的。首先,“attn_mask”和“key_padding_mask”用于自注意力(enc-enc 和 dec-dec)以及编码器-解码器注意力(enc-dec)。其次,PyTorch [在解码器中不使用`src_mask`,而是使用`memory_mask`](https://github.com/pytorch/pytorch/blob/ec5d579929b2c56418aacaec0874b92937d095a4/torch/nn/modules/transformer.py# L124-L127)(它们通常相同,但在 API 中是分开的)。`src_mask` 和 `src_key_padding_mask` 属于编码器的 self-attention。最后一句话总结得很好 (6认同)