Leo*_*Leo 6 transformer-model pytorch
我在理解变压器方面遇到了困难。一切都在一点一点地变得清晰,但让我头疼的一件事是 src_mask 和 src_key_padding_mask 之间的区别是什么,它在编码器层和解码器层的前向函数中作为参数传递。
https://pytorch.org/docs/master/_modules/torch/nn/modules/transformer.html#Transformer
Cha*_*ker 37
一般来说,要注意张量_mask与的使用之间的差异_key_padding_mask。\n在变压器内部,当注意力集中时,我们通常会得到一个平方中间张量,其中包含所有大小的比较\n [Tx, Tx](对于编码器的输入),[Ty, Ty](对于移位输出 - 解码器的输入之一)\n和[Ty, Tx](对于内存掩码 - 编码器/内存的输出和解码器/移位输出的输入之间的注意力)。
所以我们知道这是转换器中每个掩码的用途\n(注意pytorch文档中的符号如下,其中Tx=S is the source sequence length\n(例如输入批次的最大值),\n Ty=T is the target sequence length(例如目标长度的最大值), \n B=N is the batch size,\n D=E is the feature number):
src_mask [Tx, Tx] = [S, S]\xe2\x80\x93 src 序列的附加掩码(可选)。\n这在执行atten_src + src_mask. 我不确定示例输入 - 请参阅 tgt_mask 示例\n但典型用途是添加,-inf以便可以根据需要以这种方式屏蔽 src_attention。\n如果提供了 ByteTensor,则不允许使用非零位置参加,零位置不变。\n如果提供了 BoolTensor,则 True 的位置不允许参加,而 False 值将保持不变。\n如果提供了 FloatTensor,它将添加到注意力权重中。
tgt_mask [Ty, Ty] = [T, T]\xe2\x80\x93 tgt 序列的附加掩码(可选)。\n这在执行 时应用atten_tgt + tgt_mask。一个示例使用是对角线,以避免解码器作弊。\n因此 tgt 右移,第一个标记是嵌入 SOS/BOS 的序列标记的开始,因此第一个\n条目为零,而其余条目为零。具体示例见附录。\n如果提供了 ByteTensor,则非零位置不允许出现,零位置不变。\n如果提供了 BoolTensor,则 True 值不允许出现,False 值则不允许出现将保持不变。\n如果提供了 FloatTensor,它将添加到注意力权重中。
memory_mask [Ty, Tx] = [T, S]\xe2\x80\x93 编码器输出的附加掩码(可选)。\n这在执行 时应用。\natten_memory + memory_mask不确定示例用途,但如前所述,添加将-inf一些注意力权重设置为零。\n如果 ByteTensor提供时,非零位置不允许出现,零位置不变。\n如果提供了 BoolTensor,则 True 位置不允许出现,而 False 值不变。\n如果提供了 FloatTensor,它将被添加到注意力权重中。
src_key_padding_mask [B, Tx] = [N, S]\xe2\x80\x93 每批 src 键的 ByteTensor 掩码(可选)。\n由于您的 src 通常具有不同长度的序列,因此通常会删除您在末尾附加的\n填充向量。\n为此您指定批次中每个示例的每个序列的长度。\n请参阅附录中的具体示例。\n如果提供了 ByteTensor,则不允许出现非零位置,而零位置将保持不变。\n如果提供了 BoolTensor, True 的位置不允许参加,而 False 值将保持不变。\n如果提供了 FloatTensor,它将添加到注意力权重中。
tgt_key_padding_mask [B, Ty] = [N, t]\xe2\x80\x93 每批 tgt 键的 ByteTensor 掩码(可选)。\n与之前相同。\n请参阅附录中的具体示例。\n如果提供了 ByteTensor,则不允许出现非零位置,而零位置将保持不变。\n如果提供了 BoolTensor,则不允许参与 True 值,而 False 值将保持不变。\n如果提供了 FloatTensor,它将添加到注意力权重中。
memory_key_padding_mask [B, Tx] = [N, S]\xe2\x80\x93 每批内存键的 ByteTensor 掩码(可选)。\n与之前相同。\n请参阅附录中的具体示例。\n如果提供了 ByteTensor,则不允许出现非零位置,而零位置将保持不变。\n如果提供了 BoolTensor,则不允许参与 True 值,而 False 值将保持不变。\n如果提供了 FloatTensor,它将添加到注意力权重中。
pytorch 教程的示例(https://pytorch.org/tutorials/beginner/translation_transformer.html):
\n src_mask = torch.zeros((src_seq_len, src_seq_len), device=DEVICE).type(torch.bool)\nRun Code Online (Sandbox Code Playgroud)\n返回大小为布尔值的张量[Tx, Tx]:
tensor([[False, False, False, ..., False, False, False],\n ...,\n [False, False, False, ..., False, False, False]])\nRun Code Online (Sandbox Code Playgroud)\n mask = (torch.triu(torch.ones((sz, sz), device=DEVICE)) == 1)\n mask = mask.transpose(0, 1).float()\n mask = mask.masked_fill(mask == 0, float(\'-inf\'))\n mask = mask.masked_fill(mask == 1, float(0.0))\nRun Code Online (Sandbox Code Playgroud)\n生成解码器输入的右移输出的对角线。
\ntensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf,\n -inf, -inf, -inf],\n [0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf,\n -inf, -inf, -inf],\n [0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf,\n -inf, -inf, -inf],\n ...,\n [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n 0., 0., -inf],\n [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n 0., 0., 0.]])\nRun Code Online (Sandbox Code Playgroud)\n通常,右移输出的开头有 BOS/SOS,教程只需在前面附加 BOS/SOS,然后用 修剪最后一个元素即可获得右移tgt_input = tgt[:-1, :]。
填充只是为了掩盖末尾的填充。\nsrc 填充通常与内存填充相同。\ntgt 有它自己的序列,因此它有自己的填充。\n示例:
\n src_padding_mask = (src == PAD_IDX).transpose(0, 1)\n tgt_padding_mask = (tgt == PAD_IDX).transpose(0, 1)\n memory_padding_mask = src_padding_mask\nRun Code Online (Sandbox Code Playgroud)\n输出:
\ntensor([[False, False, False, ..., True, True, True],\n ...,\n [False, False, False, ..., True, True, True]])\nRun Code Online (Sandbox Code Playgroud)\n请注意, aFalse表示那里没有填充令牌(因此是的,在变压器前向传递中使用该值),而 aTrue表示存在填充令牌(因此将其屏蔽,以便变压器前向传递不会受到影响)。
答案有点分散,但我发现只有这 3 个参考资料有用\n(诚实地说,单独的层文档/东西并不是很有用):
\n举一个小例子,考虑我想构建一个顺序推荐器,即给定用户在时间“t”之前购买的商品,预测“t+1”时的下一个商品
u1 - [i1, i2, i7]
u2 - [i2, i5]
u3 - [i6, i7, i1, i2]
Run Code Online (Sandbox Code Playgroud)
对于此任务,我可以使用转换器,通过在左侧填充 0 来使序列长度相等。
u1 - [0, i1, i2, i7]
u2 - [0, 0, i2, i5]
u3 - [i6, i7, i1, i2]
Run Code Online (Sandbox Code Playgroud)
我将使用 key_padding_mask 告诉 PyTorch 0 的 shd 被忽略。现在,考虑u3给定[i6]我想要预测的用户[i7],以及稍后给定[i6, i7]我想要预测的[i1]用户,即我想要因果注意力,这样注意力就不会窥视未来的元素。为此,我将使用 attn_mask。因此对于用户u3attn_mask 来说就像
[[True, False, False, False],
[True, True , False, False],
[True, True , True , False]
[True, True , True , True ]]
Run Code Online (Sandbox Code Playgroud)
我必须说 PyTorch 实现有点令人困惑,因为它包含太多掩码参数。但我可以阐明您所指的两个掩模参数。该机制中同时使用了src_mask和。根据MultiheadAttention的文档:src_key_padding_maskMultiheadAttention
\n\nkey_padding_mask \xe2\x80\x93 如果提供,则注意力将忽略键中指定的填充元素。
\nattn_mask \xe2\x80\x93 防止关注某些位置的 2D 或 3D 掩码。
\n
从论文中你知道,Attention就是你所需要的,MultiheadAttention在Encoder和Decoder中都使用了。然而,在Decoder中,MultiheadAttention有两种类型。一种是所谓的Masked MultiheadAttention,另一种是常规的MultiheadAttention。为了适应这两种技术,PyTorch 在其 MultiheadAttention 实现中使用了上述两个参数。
所以,长话短说——
\nattn_mask并key_padding_mask用于编码器MultiheadAttention和解码器Masked MultiheadAttention。memory_mask 正如此处MultiheadAttention指出的,在解码器机制中使用。研究MultiheadAttention的实现可能会对您有所帮助。
\n正如您从此处和此处看到的,首先src_mask用于阻止特定位置参加,然后key_padding_mask用于阻止参加填充令牌。
笔记。答案根据@michael-jungo\'s 评论更新。
\n| 归档时间: |
|
| 查看次数: |
3581 次 |
| 最近记录: |