小编Leo*_*ang的帖子

pytorch - 如何从 DistributedDataParallel 学习中保存和加载模型

我是 Pytorch DstributedDataParallel() 的新手,但我发现大多数教程在训练期间都保存了本地 rank 0模型。这意味着如果我得到 3 台机器,每台机器上都有 4 个 GPU,那么最终我会得到 3 个模型,这些模型可以从每台机器上节省下来。

例如在第 252 行的pytorch ImageNet教程中:

if not args.multiprocessing_distributed or (args.multiprocessing_distributed
                and args.rank % ngpus_per_node == 0):
            save_checkpoint({...})
Run Code Online (Sandbox Code Playgroud)

如果 ,他们会保存模型rank % ngpus_per_node == 0

据我所知,DistributedDataParallel() 会自动做所有减少后端的损失,不需要做任何进一步的工作,每个进程都可以基于此自动同步损失。每个流程上的所有模型只会在流程结束时略有不同。这意味着我们只需要保存一个模型就足够了。

那么为什么我们不只是将模型保存在 上rank == 0,但是rank % ngpus_per_node == 0呢?

如果我有多个模型,我应该使用哪个模型?

如果这是在分布式学习中保存模型的正确方法,我应该合并它们,使用其中之一,还是基于所有三个模型推断结果?

如果我错了,请告诉我。

distribute pytorch

6
推荐指数
1
解决办法
3853
查看次数

标签 统计

distribute ×1

pytorch ×1