我目前正在使用 Huggingface 的 Trainer 类来训练 Distillbert,以解决使用自定义损失函数的回归问题。由于计算/意外错误的短暂性,我正在使用他们的检查点来恢复训练。
我面临的问题是,每次我通过方法model_path中的 Trainer 类从检查点恢复训练Trainer.train()时,我注意到该类会迭代数据加载器,直到达到检查点中保存的迭代计数(请参阅Trainer 类中与问题匹配的行)。
这通常可能不是问题,但由于我的数据加载器的整理功能的性质和数据集的大小,在没有任何训练的情况下迭代这样的持续时间是相当昂贵的,并且会减慢整体训练的速度。
我计划利用一个自定义采样器类,带有一个参数,以从给定位置恢复索引,但这对于给定问题来说似乎也很有效。
我可以尝试节省这些浪费的计算周期的替代方案是什么?