小编Ram*_*ind的帖子

在 Huggingface Trainer 类中恢复训练时如何避免迭代 Dataloader?

我目前正在使用 Huggingface 的 Trainer 类来训练 Distillbert,以解决使用自定义损失函数的回归问题。由于计算/意外错误的短暂性,我正在使用他们的检查点来恢复训练。

我面临的问题是,每次我通过方法model_path中的 Trainer 类从检查点恢复训练Trainer.train()时,我注意到该类会迭代数据加载器,直到达到检查点中保存的迭代计数(请参阅Trainer 类中与问题匹配的行)。

这通常可能不是问题,但由于我的数据加载器的整理功能的性质和数据集的大小,在没有任何训练的情况下迭代这样的持续时间是相当昂贵的,并且会减慢整体训练的速度。

我计划利用一个自定义采样器类,带有一个参数,以从给定位置恢复索引,但这对于给定问题来说似乎也很有效。

我可以尝试节省这些浪费的计算周期的替代方案是什么?

transformer-model pytorch huggingface-transformers

5
推荐指数
1
解决办法
908
查看次数