I am using DistilBART for abstractive summarization. The generate() method is very straightforward to use. However, it returns the complete, finished summary. What I want instead is to access the logits at each generation step, get the list of candidate next tokens, and choose among them according to my own criteria. Once a token is chosen, continue with the next one, and so on until the EOS token is generated.
I know I can access the logits by doing model(**input).logits[:, -1, :], but here the input would be the whole (encoded) text, so what exactly do these logits correspond to? The first generated token? The last one?
Thanks for your answers!
For future reference, here is how it can be done (note: this is specific to encoder-decoder models such as BART). To answer the question above: the logits at position -1 score the token that follows the current decoder input, so you need to pass decoder_input_ids explicitly at each step, which is exactly what the loops below do.
1. Initialization
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# Load model
tokenizer = AutoTokenizer.from_pretrained("sshleifer/distilbart-xsum-1-1")
model = AutoModelForSeq2SeqLM.from_pretrained("sshleifer/distilbart-xsum-1-1")
text = "..."
# Tokenize text
batch = tokenizer(text, return_tensors="pt")
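For comparison, this is the built-in call that the manual loops below replicate (a minimal sketch; max_length=60 is an arbitrary choice):
# Reference: greedy decoding through the built-in API
summary_ids = model.generate(batch["input_ids"], max_length=60, num_beams=1, do_sample=False)
print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True))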
2. Example 1: summary generation with greedy decoding (no cache)
generated_sequence = torch.tensor([[tokenizer.sep_token_id]])  # initial decoder token (</s> for BART)
# Generation loop
while True:
    with torch.no_grad():
        output = model(input_ids=batch["input_ids"], decoder_input_ids=generated_sequence)
    next_token_logits = output.logits[:, -1, :]
    next_token_scores = next_token_logits.softmax(dim=-1)
    # Take the token with the highest probability
    next_token = next_token_scores.argmax().unsqueeze(0).unsqueeze(0)
    # Append the token to the generated sequence
    generated_sequence = torch.cat((generated_sequence, next_token), dim=1)
    # Stop once the EOS token has been generated
    if generated_sequence.squeeze()[-1] == tokenizer.eos_token_id:
        break
summary = tokenizer.batch_decode(generated_sequence, skip_special_tokens=True)
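Since the whole point is to pick the next token by custom criteria, here is a minimal sketch (k=5 is an arbitrary choice) of how to inspect the top candidates at a given step instead of taking the argmax:
# Inspect the k most likely next tokens and their probabilities
topk = torch.topk(next_token_scores, k=5)
candidates = tokenizer.convert_ids_to_tokens(topk.indices.squeeze().tolist())
print(list(zip(candidates, topk.values.squeeze().tolist())))
# ...then set next_token to whichever candidate your criteria select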
3. Example 2: summary generation with top-k/top-p sampling and temperature (no cache)
from transformers.generation_utils import top_k_top_p_filtering

temperature = 0.7
generated_sequence = torch.tensor([[tokenizer.sep_token_id]])  # initial decoder token
# Generation loop
while True:
    with torch.no_grad():
        output = model(input_ids=batch["input_ids"], decoder_input_ids=generated_sequence)
    logits = output.logits[:, -1, :] / temperature  # apply temperature
    filtered_logits = top_k_top_p_filtering(logits=logits, top_k=4, top_p=0.7)
    probabilities = filtered_logits.softmax(dim=-1)
    # Sample the next token
    next_token = torch.multinomial(probabilities, 1)
    # Append the token to the generated sequence
    generated_sequence = torch.cat((generated_sequence, next_token), dim=1)
    # Stop once the EOS token has been generated
    if generated_sequence.squeeze()[-1] == tokenizer.eos_token_id:
        break
summary = tokenizer.batch_decode(generated_sequence, skip_special_tokens=True)
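Note that top_k_top_p_filtering has moved between modules across transformers releases and may be missing in recent versions. If the import fails, a hand-rolled filter such as the sketch below (my own helper, not a library function) covers the top-k part; top-p would additionally need the usual cumulative-probability cutoff over the sorted logits:
# Fallback: keep only the top_k logits and mask the rest to -inf
def top_k_filter(logits, top_k):
    values, _ = torch.topk(logits, top_k)
    threshold = values[..., -1, None]  # smallest logit that is kept
    return logits.masked_fill(logits < threshold, float("-inf"))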
(Other generation strategies work analogously.)
4. Using the cache
Since the encoder input (i.e., the text to summarize) is always the same, we can cache it to speed up generation considerably.
generated_sequence = torch.tensor([[tokenizer.sep_token_id]])  # initial decoder token
input_ids = batch["input_ids"]
past_key_values = None

# First pass: run the encoder once and get the initial decoder output
with torch.no_grad():
    output = model(
        input_ids=input_ids,
        decoder_input_ids=generated_sequence,
        past_key_values=past_key_values
    )
# Wrap the encoder states in a tuple so the model accepts them as encoder_outputs
encoder_outputs = (output.encoder_last_hidden_state,)

# Generation loop
while True:
    # From here on, reuse the cached attention states
    past_key_values = output.past_key_values
    next_token_logits = output.logits[:, -1, :]
    next_token_scores = next_token_logits.softmax(dim=-1)
    next_token = next_token_scores.argmax().unsqueeze(0).unsqueeze(0)  # greedy decoding
    generated_sequence = torch.cat((generated_sequence, next_token), dim=1)
    # Stop once the EOS token has been generated
    if generated_sequence.squeeze()[-1] == tokenizer.eos_token_id:
        break
    # Feed only the newly generated token; the cache covers the earlier ones
    with torch.no_grad():
        output = model(
            decoder_input_ids=next_token,
            past_key_values=past_key_values,
            encoder_outputs=encoder_outputs
        )
summary = tokenizer.batch_decode(generated_sequence, skip_special_tokens=True)
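An arguably cleaner way to reuse the encoder states, assuming the standard seq2seq API, is to run the encoder explicitly once and pass its output object to every decoder step:
# Run the encoder a single time, up front
encoder = model.get_encoder()
with torch.no_grad():
    encoder_outputs = encoder(input_ids=batch["input_ids"])
# encoder_outputs (a BaseModelOutput) can then be passed as encoder_outputs=...
# in each model(...) call of the generation loop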