BERT模型的参数个数是如何计算的?

Ech*_*che 7 nlp neural-network bert-language-model

Devlin & Co. 的论文《BERT: Pre-training of Deep Bi Direction Transformers for Language Understanding》针对基础模型大小 110M 参数(即 L=12、H=768、A=12)进行了计算,其中 L = 层数, H = 隐藏大小,A = 自注意力操作的数量。据我所知,神经网络中的参数通常是层之间“权重和偏差”的计数。那么如何根据给定的信息计算出这个值呢?12 768 768*12?

小智 25

Transformer 编码器-解码器架构\nBERT 模型仅包含 Transformer 架构的编码器块。让我们看看 BERT 编码器块的各个元素,以可视化数字权重矩阵以及偏差向量。给定的配置 L = 12 意味着将有 12 层 self-attention,H = 768 意味着单个 token 的嵌入维度将为 768 维,A = 12 意味着一层 self-attention 将有 12 个注意力头。编码器块执行以下操作序列:

\n
    \n
  1. 输入将是作为 S * d 维度矩阵的标记序列。其中 s 是序列长度,d 是嵌入维度。生成的输入序列将是令牌嵌入、令牌类型嵌入以及作为每个令牌的 d 维向量的位置嵌入的总和。在 BERT 模型中,第一组参数是词汇嵌入。BERT 使用具有 30522 个标记的WordPiece[ 2 ] 嵌入。每个令牌有 768 个维度。

    \n
  2. \n
  3. 嵌入层归一化。一个权重矩阵和一个偏置向量。

    \n
  4. \n
  5. 多头自注意力。将有 h 个头,对于每个头将有三个矩阵,分别对应于查询矩阵、键矩阵和值矩阵。这些矩阵的第一个维度将是嵌入维度,第二个维度将是嵌入维度除以注意力头的数量。除此之外,还会有一个矩阵来将注意力头生成的串联值转换为最终的令牌表示。

    \n
  6. \n
  7. 剩余连接和层标准化。一个权重矩阵和一个偏置向量。

    \n
  8. \n
  9. 位置式前馈网络将具有一个隐藏层,该隐藏层对应于两个权重矩阵和两个偏差向量。论文中提到,隐藏层中的单元数量将是嵌入维度的四倍。

    \n
  10. \n
  11. 剩余连接和层标准化。一个权重矩阵和一个偏置向量。

    \n
  12. \n
\n

让我们通过将正确的维度与 BERT 基础模型的权重矩阵和偏差向量相关联来计算参数的实际数量。

\n

嵌入矩阵:

\n
    \n
  • 词嵌入矩阵大小[词汇大小,嵌入维度] = [30522, 768] = 23440896
  • \n
  • 位置嵌入矩阵大小,[最大序列长度,嵌入维度] = [512, 768] = 393216
  • \n
  • 令牌类型嵌入矩阵大小 [2, 768] = 1536
  • \n
  • 嵌入层归一化、权重和偏差 [768] + [768] = 1536
  • \n
  • 嵌入参数总数 = \xe2\x89\x88
  • \n
\n

注意头:

\n
    \n
  • 查询权重矩阵大小 [768, 64] = 49152 且偏差 [768] = 768

    \n
  • \n
  • 键权重矩阵大小 [768, 64] = 49152 且偏差 [768] = 768

    \n
  • \n
  • 值权重矩阵大小 [768, 64] = 49152 且偏差 [768] = 768

    \n
  • \n
  • 12 个头的一层注意力的总参数 = 12\xe2\x88\x97(3 \xe2\x88\x97(49152+768)) = 1797120

    \n
  • \n
  • 头连接后投影的密集权重 [768, 768] = 589824 和偏差 [768] = 768, (589824+768 = 590592)

    \n
  • \n
  • 层归一化权重和偏差 [768], [768] = 1536

    \n
  • \n
  • 位置明智的前馈网络权重矩阵和偏差 [3072, 768] = 2359296,[3072] = 3072 和 [768, 3072 ] = 2359296,[768] = 768,(2359296+3072+ 2359296+768 = 4722432)

    \n
  • \n
  • 层归一化权重和偏差 [768], [768] = 1536

    \n
  • \n
  • 一个完整注意力层的总参数 (1797120 + 590592 + 1536 + 4722432 + 1536 = 7113216 \xe2\x89\x88 7 )

    \n
  • \n
  • 12 层注意力的总参数 ( \xe2\x88\x97 = \xe2\x89\x88 )

    \n
  • \n
\n

BERT Encoder的输出层:

\n
    \n
  • 密集权重矩阵和偏差 [768, 768] = 589824, [768] = 768, (589824 + 768 = 590592)
  • \n
\n

ase 中的总参数 = + + = \xe2\x89\x88

\n