mit*_*iee 3 neural-network python-3.x text-classification tensorflow bert-language-model
每个人!我正在阅读有关 Bert 的文章,并想利用其词嵌入进行文本分类。我遇到了这行代码:
pooled_output, sequence_output = self.bert_layer([input_word_ids, input_mask, segment_ids])
Run Code Online (Sandbox Code Playgroud)
进而:
clf_output = sequence_output[:, 0, :]
out = Dense(1, activation='sigmoid')(clf_output)
Run Code Online (Sandbox Code Playgroud)
但我无法理解合并输出的用途。序列输出不是包含了包括['CLS']的词嵌入在内的所有信息吗?如果是这样,为什么我们要汇总输出?
提前致谢!
序列输出是 BERT 模型最后一层输出的隐藏状态(嵌入)序列。它包括 [CLS] 令牌的嵌入。因此,对于句子“You are on Stackoverflow”,它给出了 5 个嵌入:四个单词中的每一个都有一个嵌入(假设单词“Stackoverflow”被标记为单个标记)以及 [CLS] 标记的嵌入。 池化输出是 [CLS] 标记(来自序列输出)的嵌入,由线性层和 Tanh 激活函数进一步处理。线性层权重是在预训练期间根据下一个句子预测(分类)目标进行训练的。更详细的内容请参考BERT原论文。
| 归档时间: |
|
| 查看次数: |
4854 次 |
| 最近记录: |