Transformer 模型中自注意力的计算复杂性

New*_*ton 12 nlp artificial-intelligence machine-learning neural-network deep-learning

我最近浏览了Google Research的Transformer论文,其中描述了自注意力层如何完全取代传统的基于 RNN 的序列编码层进行机器翻译。在论文的表 1 中,作者比较了不同序列编码层的计算复杂度,并指出(稍后)当序列长度n小于向量表示的维数时,自注意力层比 RNN 层快d

然而,如果我对计算的理解是正确的,自我注意层的复杂性似乎比声称的要低。让我们X成为自注意力层的输入。然后,X将具有形状,(n, d)因为n每个维度都有词向量(对应于行)d。计算自注意力的输出需要以下步骤(为了简单起见,考虑单头自注意力):

  1. 线性变换 的行X以计算 query Q、 keyK和 valueV矩阵,每个矩阵都具有 shape (n, d)。这是通过后乘以X3 个学习的形状矩阵(d, d)来实现的,计算复杂度为O(n d^2)
  2. 计算层输出,在论文的公式 1 中指定为SoftMax(Q Kt / sqrt(d)) V,其中在每一行上计算 softmax。计算Q Kt具有复杂性O(n^2 d),将结果与后乘V也具有复杂性O(n^2 d)

因此,该层的总复杂度为O(n^2 d + n d^2),比传统的 RNN 层差。在考虑适当的中间表示维度 ( dk, dv) 并最终乘以 head 的数量时,我也为多头注意力获得了相同的结果h

为什么作者在报告总计算复杂度时忽略了计算查询、键和值矩阵的成本?

我知道提议的层可以跨n位置完全并行化,但我相信表 1 无论如何都没有考虑到这一点。

igr*_*nis 7

首先,您的复杂性计算是正确的。那么,混乱的根源是什么?

当原始注意纸第一被引入,它并不需要计算QVK矩阵,作为值由RNNs的隐藏状态直接取出,因此,注意层的复杂性 O(n^2·d)

现在,要了解Table 1包含的内容,请记住大多数人是如何浏览论文的:他们阅读标题、摘要,然后查看图表和表格。只有当结果有趣时,他们才会更彻底地阅读论文。因此,该Attention is all you need论文的主要思想是在 seq2seq 设置中用注意力机制完全替换 RNN 层,因为 RNN 的训练速度非常慢。如果你Table 1在这个上下文中查看,你会发现它比较了 RNN、CNN 和注意力,并突出了这篇论文的动机:使用注意力应该比 RNN 和 CNN 更有利。它应该在 3 个方面具有优势:恒定的计算步骤量、恒定的运算量较低的通常 Google 设置的计算复杂度,其中n ~= 100d ~= 1000. 但正如任何想法一样,它击中了现实的硬墙。实际上,为了让这个伟大的想法发挥作用,他们必须添加位置编码,重新制定注意力并为其添加多个头。结果是 Transformer 架构,虽然其计算复杂度O(n^2·d + n·d^2)仍然比 RNN 快得多(在挂钟时间的意义上),并产生更好的结果。

所以你的问题的答案是作者提到的注意力层Table 1严格来说是注意力机制。这不是 Transformer 的复杂性。他们非常清楚他们模型的复杂性(我引用):

然而,可分离卷积 [6] 将复杂度显着降低到O(k·n·d + n·d^2). k = n然而,即使有,可分离卷积的复杂性也等于自我注意层和逐点前馈层的组合,这是我们在模型中采用的方法。


Sha*_*hai 2

严格来说,当仅考虑自注意力块(图2左,方​​程1)的复杂度时,x的投影qkv不包含在自注意力中。表 1 所示的复杂度仅针对自注意力层的核心,因此为O(n^2 d)