我想保存 PyTorch 的torch.utils.data.dataloader.DataLoader实例,以便我可以从上次中断的地方继续训练(保留随机种子、状态和所有内容)。
您需要采样器的自定义实现。可以使用无麻烦的东西:https://gist.github.com/usamec/1b3b4dcbafad2d58faa71a9633eea6a5
您可以保存并恢复,如下所示:
sampler = ResumableRandomSampler(dataset)
loader = torch.utils.data.DataLoader(dataset, batch_size=2, sampler=sampler, pin_memory=True)
for x in loader:
print(x)
break
sampler2 = ResumableRandomSampler(dataset)
torch.save(sampler.get_state(), "test_samp.pth")
sampler2.set_state(torch.load("test_samp.pth"))
loader2 = torch.utils.data.DataLoader(dataset, batch_size=2, sampler=sampler2, pin_memory=True)
for x in loader2:
print(x)
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
9400 次 |
| 最近记录: |