在 Huggingface Trainer 类中恢复训练时如何避免迭代 Dataloader?

Ram*_*ind 5 transformer-model pytorch huggingface-transformers

我目前正在使用 Huggingface 的 Trainer 类来训练 Distillbert,以解决使用自定义损失函数的回归问题。由于计算/意外错误的短暂性,我正在使用他们的检查点来恢复训练。

我面临的问题是,每次我通过方法model_path中的 Trainer 类从检查点恢复训练Trainer.train()时,我注意到该类会迭代数据加载器,直到达到检查点中保存的迭代计数(请参阅Trainer 类中与问题匹配的行)。

这通常可能不是问题,但由于我的数据加载器的整理功能的性质和数据集的大小,在没有任何训练的情况下迭代这样的持续时间是相当昂贵的,并且会减慢整体训练的速度。

我计划利用一个自定义采样器类,带有一个参数,以从给定位置恢复索引,但这对于给定问题来说似乎也很有效。

我可以尝试节省这些浪费的计算周期的替代方案是什么?

Ram*_*ind 0

看起来 Huggingface 通过使用TrainingArgumentsignore_data_skip中的参数提供了一个解决方案。

尽管您必须小心使用此标志。本质上就好像您从步骤 0 开始一个新纪元。但是您会将优化器/模型状态移动到从恢复点开始的任何状态。