我正在实现一个变压器,并且一切正常,包括使用scaled_dot_product_attention
PyTorch 2.0 中的新功能的注意力。然而,我只会进行因果关注,因此使用该is_causal=True
标志来提高效率似乎是有意义的。只要 k、v 和 q 张量具有相同的大小,这也符合我的预期。
但我不确定如何在此模式下将过去的(缓存的)键/值传递给函数。如果 k、v 张量比 q 宽,我需要一个与 k/v 一样宽、与 q 一样高的矩形掩模,并屏蔽掉右上角的三角形。如果我自己构建这样一个掩码并将其传递给函数,一切都很好。我得到的行为类似于典型的因果关注,其中过去的标记被完全关注,而新的标记(有查询)被因果关注。
不过,根据文档,is_causal=True
这相当于使用以下构建的掩码:
attn_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
Run Code Online (Sandbox Code Playgroud)
其中 L 和 S 分别是查询长度和键/值长度。这使得除了左下三角形部分之外的所有部分都被屏蔽,该部分部分涉及过去的标记,而根本不涉及新的标记。这种因果模式是否不适合我的用例,或者我错过了什么?
假设我有以下张量:
q = torch.rand((1, n_heads, 3, head_dim))
k = torch.rand((1, n_heads, 6, head_dim))
v = torch.rand((1, n_heads, 6, head_dim))
attn_output = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)
Run Code Online (Sandbox Code Playgroud)
其中 k 和 v 更宽,因为它们连接到先前推理过程的缓存结果上。scaled_dot_product_attention
应用以下掩码:
[[0, -inf, -inf, -inf, -inf, -inf]
[0, 0, …
Run Code Online (Sandbox Code Playgroud)