MultiHeadAttnetion 中的 att_mask 和 key_padding_mask 有什么区别

one*_*one 8 python transformer-model deep-learning attention-model pytorch

pytorchatt_maskkey_padding_maskin的区别MultiHeadAttnetion是什么:

key_padding_mask – 如果提供,键中指定的填充元素将被注意力忽略。当给定一个二元掩码并且值为 True 时,注意力层上的相应值将被忽略。当给定字节掩码且值为非零时,将忽略注意力层上的相应值

attn_mask – 2D 或 3D 掩码,防止对某些位置的注意。将为所有批次广播 2D 掩码,而 3D 掩码允许为每个批次的条目指定不同的掩码。

提前致谢。

Jin*_*ich 10

key_padding_mask用于屏蔽掉被填充的位置,即,输入序列的结束之后。这始终特定于输入批次,并取决于批次中的序列与最长的序列相比有多长。它是一个形状为批量大小×输入长度的二维张量。

另一方面,attn_mask说明哪些键值对是有效的。在 Transformer 解码器中,三角形掩码用于模拟推理时间并防止关注“未来”位置。这是att_mask通常使用的。如果是二维张量,则形状为input length × input length。您还可以拥有一个特定于批次中每个项目的掩码。在这种情况下,您可以使用形状为(batch size × num head) × input length × input length的 3D 张量。(因此,理论上,您可以key_padding_mask使用 3D进行模拟att_mask。)