All*_*n-J 5 nlp pytorch huggingface-transformers
在 HuggingFace 代码的生成阶段: https://github.com/huggingface/transformers/blob/master/src/transformers/ Generation_utils.py#L88-L100
他们传递了 a decoder_start_token_id,我不确定他们为什么需要这个。在 BART 配置中,decoder_start_token_id实际上是2( https://huggingface.co/facebook/bart-base/blob/main/config.json ),它是句子结束标记</s>。
我尝试了一个简单的例子:
from transformers import *
import torch
model = BartForConditionalGeneration.from_pretrained('facebook/bart-base')
tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
input_ids = torch.LongTensor([[0, 894, 213, 7, 334, 479, 2]])
res = model.generate(input_ids, num_beams=1, max_length=100)
print(res)
preds = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True).strip() for g in res]
print(preds)
Run Code Online (Sandbox Code Playgroud)
我得到的结果:
tensor([[ 2, 0, 894, 213, 7, 334, 479, 2]])
['He go to school.']
Run Code Online (Sandbox Code Playgroud)
虽然它并不影响最终的“标记化解码”结果。2但对我来说,我们生成的第一个令牌实际上是( ) ,这似乎很奇怪</s>。
您可以在编码器-解码器模型的代码中看到,解码器的输入标记相对于原始输入右移(请参阅函数shift_tokens_right)。这意味着第一个猜测的标记始终是 BOS(句子开头)。您可以检查您的示例中是否属于这种情况。
为了让解码器理解这一点,我们必须选择第一个令牌,该令牌始终后面跟着 BOS,那么它会是哪个呢?老板?显然不是,因为它后面必须跟有常规标记。填充令牌?也不是一个好的选择,因为它后面跟着另一个填充令牌或 EOS(句子结尾)。那么,EOS 呢?嗯,这是有道理的,因为它后面永远不会跟随训练集中的任何内容,因此不会出现下一个发生冲突的标记。此外,句子的开头跟在另一个句子的结尾不是很自然吗?