如何使用 T5 模型的输出来替换输入序列中的屏蔽标记

sca*_*der 3 python nlp huggingface-transformers generative-pretrained-transformer

我正在使用 Hugging Face Transformers 库中的 T5 模型,并且我有一个带有屏蔽标记的输入序列,我想将其替换为模型生成的输出。这是代码

from transformers import T5Tokenizer, T5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")

input_data = "The <extra_id_0> walks in <extra_id_1> park"
input_ids = tokenizer(input_data, return_tensors="pt").input_ids

sequence_ids = model.generate(input_ids)
output_sequences = tokenizer.batch_decode(sequence_ids)
output_sequences
Run Code Online (Sandbox Code Playgroud)

此代码产生以下输出:

['<pad><extra_id_0> park offers<extra_id_1> the<extra_id_2> park.</s>']
Run Code Online (Sandbox Code Playgroud)

我想要做的是将输入序列中的屏蔽标记<extra_id_0>和替换<extra_id_1>为模型中相应的输出标记,以便最终输出为:

The park offers walks in the park.
Run Code Online (Sandbox Code Playgroud)

我希望有人可以帮助我编写代码来实现这一目标。

请注意,这是对应关系:

mask in input_data -> answer in output_sequences
<extra_id_0> -> <extra_id_0> park offers (so we extract 'park offers' only)
<extra_id_1> -> <extra_id_1> the  (so we extract 'the' only)
Run Code Online (Sandbox Code Playgroud)

小智 5

t5 模型将以 <extra_id 开头的标记视为潜在的掩码标记。正如文档中所写

\n

“每个哨兵标记代表该句子的唯一掩码标记,应以 <extra_id_0>、<extra_id_1>、\xe2\x80\xa6 开始,直到 <extra_id_99>”

\n

在输出中,您可以将 <extra_id_0> 和 <extra_id_1> 之间的文本视为 mask_0 的输出,将 <extra_id_1> 和 <extra_id_2> 之间的文本视为 mask 1 的输出,依此类推。

\n

要从生成的输出中提取此内容,您可以使用以下代码片段。这会将掩码数量作为输入,并返回一个字符串列表作为输出,其中每个元素表示为相应掩码预测的文本。

\n
def extract_text(text,num_masks=1):\n    list_of_text = []\n    for i in range(num_masks):\n        prev_id = \'<extra_id_\' + str(i) + \'>\'\n        curr_id = \'<extra_id_\' + str(i+1) + \'>\'\n        st_token_index = text.index(prev_id)\n        end_token_index = text.index(curr_id)\n        list_of_text.append(text[st_token_index+12:end_token_index])\n    return list_of_text\n
Run Code Online (Sandbox Code Playgroud)\n

另外,您应该注意,t5 并不是真正的掩码语言建模任务的最佳选择,如此处所讨论。像 BERT 这样的模型专门针对此类任务进行了训练,并且可以直接与 Huggingface 的填充掩模管道一起使用

\n
from transformers import pipeline\nnlp_fill = pipeline(\'fill-mask\')\n
Run Code Online (Sandbox Code Playgroud)\n