Does Huggingface's "resume_from_checkpoint" work?

Pen*_*uin 8 pytorch huggingface-transformers huggingface

I currently have my trainer set up as:

    training_args = TrainingArguments(
        output_dir=f"./results_{model_checkpoint}",
        evaluation_strategy="epoch",
        learning_rate=5e-5,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        num_train_epochs=2,
        weight_decay=0.01,
        push_to_hub=True,
        save_total_limit=1,
        resume_from_checkpoint=True,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_qa["train"],
        eval_dataset=tokenized_qa["validation"],
        tokenizer=tokenizer,
        data_collator=DataCollatorForMultipleChoice(tokenizer=tokenizer),
        compute_metrics=compute_metrics
    )

After training, my output_dir contains several files that the trainer saved:

    ['README.md',
     'tokenizer.json',
     'training_args.bin',
     '.git',
     '.gitignore',
     'vocab.txt',
     'config.json',
     'checkpoint-5000',
     'pytorch_model.bin',
     'tokenizer_config.json',
     'special_tokens_map.json',
     '.gitattributes']

From the documentation, it seems that resume_from_checkpoint will continue training the model from the last checkpoint:

resume_from_checkpoint (str or bool, optional) - If a str, local path to a saved checkpoint as saved by a previous instance of Trainer. If a bool and equals True, load the last checkpoint in args.output_dir as saved by a previous instance of Trainer. If present, training will resume from the model/optimizer/scheduler states loaded here.

But when I call trainer.train(), it seems to delete the last checkpoint and start a new one:

    Saving model checkpoint to ./results_distilbert-base-uncased/checkpoint-500
    ...
    Deleting older checkpoint [results_distilbert-base-uncased/checkpoint-5000] due to args.save_total_limit

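Does it really continue training from the last checkpoint (i.e., 5000) and just restart the checkpoint numbering at 0 (saving the first checkpoint after 500 steps as "checkpoint-500"), or does it simply not continue training? I haven't found a way to test this, and the documentation is not clear about it.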


Way*_*uza 4

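Yes, it works! But when you call trainer.train() with no arguments, you are implicitly telling it to overwrite all checkpoints and start from scratch. To resume, call trainer.train(resume_from_checkpoint=True), or set resume_from_checkpoint to a string pointing to the checkpoint path.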
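To check which checkpoint a resume would pick up, and how far that checkpoint had trained, something like the following may help. It is only a sketch: find_last_checkpoint and saved_global_step are hypothetical helper names, written with the standard library alone. It assumes checkpoints live in directories named checkpoint-N inside output_dir, as in the listing from the question, and that each one holds a trainer_state.json with a global_step field. transformers also ships its own helper, get_last_checkpoint in transformers.trainer_utils, which could be used in place of the first function.

```python
import json
import re
from pathlib import Path
from typing import Optional

# Matches directory names such as "checkpoint-5000".
_CHECKPOINT_RE = re.compile(r"^checkpoint-(\d+)$")


def find_last_checkpoint(output_dir: str) -> Optional[str]:
    """Return the path of the highest-numbered checkpoint-N directory, or None."""
    root = Path(output_dir)
    if not root.is_dir():
        return None
    candidates = []
    for child in root.iterdir():
        match = _CHECKPOINT_RE.match(child.name)
        if child.is_dir() and match:
            # Compare step numbers numerically so 5000 ranks above 500.
            candidates.append((int(match.group(1)), child))
    if not candidates:
        return None
    return str(max(candidates, key=lambda item: item[0])[1])


def saved_global_step(checkpoint_dir: str) -> Optional[int]:
    """Read global_step from the checkpoint's trainer_state.json, if present."""
    state_file = Path(checkpoint_dir) / "trainer_state.json"
    if not state_file.is_file():
        return None
    with state_file.open() as handle:
        return json.load(handle).get("global_step")


# Usage with the Trainer from the question (not executed here):
# last = find_last_checkpoint(training_args.output_dir)
# if last is not None:
#     print("resuming from step", saved_global_step(last))
#     trainer.train(resume_from_checkpoint=last)
# else:
#     trainer.train()
```

If the run really resumes, the first checkpoint saved afterwards should carry a step number above the one reported here, rather than restarting at checkpoint-500.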