如何保存 PyTorch 的 DataLoader 实例?

edo*_*ost 4 pytorch

我想保存 PyTorch 的torch.utils.data.dataloader.DataLoader实例,以便我可以从上次中断的地方继续训练(保留随机种子、状态和所有内容)。

usa*_*mec 5

您需要采样器的自定义实现。可以使用无麻烦的东西:https://gist.github.com/usamec/1b3b4dcbafad2d58faa71a9633eea6a5

您可以保存并恢复,如下所示:

sampler = ResumableRandomSampler(dataset)
loader = torch.utils.data.DataLoader(dataset, batch_size=2, sampler=sampler, pin_memory=True)

for x in loader:
    print(x)
    break

sampler2 = ResumableRandomSampler(dataset)
torch.save(sampler.get_state(), "test_samp.pth")
sampler2.set_state(torch.load("test_samp.pth"))
loader2 = torch.utils.data.DataLoader(dataset, batch_size=2, sampler=sampler2, pin_memory=True)

for x in loader2:
    print(x)
Run Code Online (Sandbox Code Playgroud)