多头注意力:Q、K、V 线性变换的正确实现

Ger*_*ens 2 nlp neural-network attention-model pytorch bert-language-model

我现在正在 Pytorch 中实现多头自注意力。我查看了几个实现,它们似乎有点错误,或者至少我不确定为什么要这样做。他们通常只会应用一次线性投影

    self.query_projection = nn.Linear(input_dim, output_dim)
    self.key_projection = nn.Linear(input_dim, output_dim)
    self.value_projection = nn.Linear(input_dim, output_dim)
Run Code Online (Sandbox Code Playgroud)

然后他们经常将投影重塑为

    query_heads = query_projected.view(batch_size, query_lenght, head_count, head_dimension).transpose(1,2)
    key_heads = key_projected.view(batch_size, key_len, head_count, head_dimension).transpose(1, 2)  # (batch_size, heads_count, key_len, d_head)
    value_heads = value_projected.view(batch_size, value_len, head_count, head_dimension).transpose(1, 2)  # (batch_size, heads_count, value_len, d_head)

    attention_weights = scaled_dot_product(query_heads, key_heads) 
Run Code Online (Sandbox Code Playgroud)

根据此代码,每个头将处理预计查询的一部分。然而,最初的论文说我们需要为编码器中的每个头有一个不同的线性投影。

这个显示的实现正确吗?

hkc*_*rex 5

它们是等价的。

理论上(以及在论文写作中),更容易将它们视为单独的线性投影。假设你有 8 个头,每个头都有一个M->N投影,那么就有一个8 N by M矩阵。

M->8N但在实现中,通过矩阵进行变换会更快8N by M

可以连接第一个公式中的矩阵以获得第二个公式中的矩阵。