使用 T5 的句子嵌入

exe*_*nts 6 python nlp word-embedding pytorch

我想使用最先进的 LM T5 来获取句子嵌入向量。我找到了这个存储库https://github.com/UKPLab/sentence-transformers 据我所知,在 BERT 中我应该将第一个标记作为 [CLS] 标记,它将是句子嵌入。在此存储库中,我在 T5 模型上看到相同的行为:

cls_tokens = output_tokens[:, 0, :]  # CLS token is first token
Run Code Online (Sandbox Code Playgroud)

这种行为正确吗?我从 T5 中获取了编码器并用它编码了两个短语:

"I live in the kindergarden"
"Yes, I live in the kindergarden"
Run Code Online (Sandbox Code Playgroud)

它们之间的余弦相似度仅为“0.2420”。

我只需要了解句子嵌入是如何工作的——我应该训练网络来寻找相似性以获得正确的结果吗?或者我的基础预训练语言模型就足够了?

小智 7

last_hidden_state为了从 T5 获得句子嵌入,您需要从 T5 编码器输出中获取:

model.encoder(input_ids=s, attention_mask=attn, return_dict=True)
pooled_sentence = output.last_hidden_state # shape is [batch_size, seq_len, hidden_size]
# pooled_sentence will represent the embeddings for each word in the sentence
# you need to sum/average the pooled_sentence
pooled_sentence = torch.mean(pooled_sentence, dim=1)
Run Code Online (Sandbox Code Playgroud)

您现在有一个来自 T5 的句子嵌入

  • 为了进一步支持这个想法,在 [Sentence-T5](https://arxiv.org/abs/2108.08877) 中,他们表明平均令牌嵌入是 T5 的一个不错的选择。 (4认同)