小编tur*_*erp的帖子

scaled_dot_product_attention 如何与因果 LM 中的缓存键/值一起使用?

我正在实现一个变压器,并且一切正常,包括使用scaled_dot_product_attentionPyTorch 2.0 中的新功能的注意力。然而,我只会进行因果关注,因此使用该is_causal=True标志来提高效率似乎是有意义的。只要 k、v 和 q 张量具有相同的大小,这也符合我的预期。

但我不确定如何在此模式下将过去的(缓存的)键/值传递给函数。如果 k、v 张量比 q 宽,我需要一个与 k/v 一样宽、与 q 一样高的矩形掩模,并屏蔽掉右上角的三角形。如果我自己构建这样一个掩码并将其传递给函数,一切都很好。我得到的行为类似于典型的因果关注,其中过去的标记被完全关注,而新的标记(有查询)被因果关注。

不过,根据文档is_causal=True这相当于使用以下构建的掩码:

attn_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
Run Code Online (Sandbox Code Playgroud)

其中 L 和 S 分别是查询长度和键/值长度。这使得除了左下三角形部分之外的所有部分都被屏蔽,该部分部分涉及过去的标记,而根本不涉及新的标记。这种因果模式是否不适合我的用例,或者我错过了什么?

假设我有以下张量:

q = torch.rand((1, n_heads, 3, head_dim))
k = torch.rand((1, n_heads, 6, head_dim))
v = torch.rand((1, n_heads, 6, head_dim))

attn_output = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)
Run Code Online (Sandbox Code Playgroud)

其中 k 和 v 更宽,因为它们连接到先前推理过程的缓存结果上。scaled_dot_product_attention应用以下掩码:

[[0, -inf, -inf, -inf, -inf, -inf]
 [0,    0, …
Run Code Online (Sandbox Code Playgroud)

language-model pytorch self-attention

5
推荐指数
0
解决办法
1125
查看次数

标签 统计

language-model ×1

pytorch ×1

self-attention ×1