小编Min*_*Ngo的帖子

HuggingFace - 当我从检查点加载时 model.generate() 非常慢

我正在尝试使用 Donut 模型(在 HuggingFace 库中提供)使用我的自定义数据集(格式类似于 RVL-CDIP)进行文档分类。当我训练模型并model.generate()在训练循环中运行模型推理(使用方法)进行模型评估时,这是正常的(每张图像的推理大约需要0.2s)。

但是,如果训练后,我使用该方法将模型保存到检查点save_pretrained,然后使用该from_pretrained方法加载检查点,则model.generate()运行速度极慢(6s ~ 7s)。

这是我用于推理的代码(训练循环中的推理代码完全相同):

model = VisionEncoderDecoderModel.from_pretrained(CKPT_PATH, config=config)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)

accs = []
model.eval()
for i, sample in tqdm(enumerate(val_ds), total=len(val_ds)):
    pixel_values = sample["pixel_values"]
    pixel_values = torch.unsqueeze(pixel_values, 0)
    pixel_values = pixel_values.to(device)

    start = time.time()
    task_prompt = "<s_fci>"
    decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids
    decoder_input_ids = decoder_input_ids.to(device)
    print(f"Tokenize time: {time.time() - start:.4f}s")

    start = time.time()
    outputs = model.generate(
        pixel_values,
        decoder_input_ids=decoder_input_ids,
        max_length=model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=processor.tokenizer.pad_token_id, …
Run Code Online (Sandbox Code Playgroud)

python pytorch huggingface-transformers

5
推荐指数
0
解决办法
1458
查看次数

标签 统计

huggingface-transformers ×1

python ×1

pytorch ×1