Tes*_*a W 3 nlp tokenize huggingface-transformers
我正在使用 Huggingface Transformers 训练用于令牌分类的 XLM-RoBERTa 模型。我已经微调过的模型的最大标记长度是 166。我在训练数据中截断了较长的序列并填充了较短的序列。现在,在推理/预测期间,我想预测所有标记,即使是长度超过 166 的序列。但是,如果我正确阅读文档,溢出的标记就会被丢弃。那是对的吗?我不完全确定“return_overflowing_tokens”和 stride 参数的作用。它们可以用来将太长的序列分成两个或更多个较短的序列吗?
我已经尝试将文本数据分割成句子以具有更小的块,但其中一些仍然超过最大标记长度。如果溢出的令牌能够自动添加到附加序列中,那将是理想的。
假设您有以下字符串:
from transformers import XLMRobertaTokenizerFast
model_id = "xlm-roberta-large-finetuned-conll03-english"
t = XLMRobertaTokenizerFast.from_pretrained(model_id)
sample = "this is an example and context is important to retrieve meaningful contextualized token embeddings from the self attention mechanism of the transformer"
print(f"this string has {len(t.tokenize(sample))} tokens")
Run Code Online (Sandbox Code Playgroud)
输出:
this string has 32 tokens
Run Code Online (Sandbox Code Playgroud)
标记器max_length会截断文本,因此您的模型永远不会对截断的标记进行分类:
from transformers import XLMRobertaTokenizerFast
model_id = "xlm-roberta-large-finetuned-conll03-english"
t = XLMRobertaTokenizerFast.from_pretrained(model_id)
sample = "this is an example and context is important to retrieve meaningful contextualized token embeddings from the self attention mechanism of the transformer"
print(f"this string has {len(t.tokenize(sample))} tokens")
Run Code Online (Sandbox Code Playgroud)
输出:
10
['<s>', 'this', 'is', 'an', 'example', 'and', 'context', 'is', 'important', '</s>']
Run Code Online (Sandbox Code Playgroud)
要将截断的标记传递给模型,您可以使用return_overflowing_tokens:
this string has 32 tokens
Run Code Online (Sandbox Code Playgroud)
输出:
[10, 10, 10, 10]
<s> this is an example and context is important</s>
<s> to retrieve meaningful contextual</s>
<s>ized token embeddings from</s>
<s> the self attention mechanism of the transformer</s>
Run Code Online (Sandbox Code Playgroud)
您可能会注意到这里有一个问题。您的模型可能会面临为每个句子开头和结尾的标记生成有意义的嵌入(对于您的下游任务)的问题,因为它们由于硬切割方法而缺乏上下文。ized第三个序列的标记就是这个问题的一个很好的例子。
该问题的标准方法是滑动窗口方法,它为当前序列保留前一个序列的一些标记。您可以使用分词器的stride参数来控制滑动窗口:
encoded_max_length = t(sample, max_length=10, truncation=True).input_ids
print(len(encoded_max_length))
print(t.batch_decode(encoded_max_length))
Run Code Online (Sandbox Code Playgroud)
输出:
[10, 10, 10, 10, 10, 9]
<s> this is an example and context is important</s>
<s> context is important to retrieve meaning</s>
<s>trieve meaningful contextualized to</s>
<s>ualized token embeddings</s>
<s>embeddings from the self attention mechanism</s>
<s> self attention mechanism of the transformer</s>
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
1569 次 |
| 最近记录: |