Leo*_*ang 6 distribute pytorch
我是 Pytorch DstributedDataParallel() 的新手,但我发现大多数教程在训练期间都保存了本地 rank 0模型。这意味着如果我得到 3 台机器,每台机器上都有 4 个 GPU,那么最终我会得到 3 个模型,这些模型可以从每台机器上节省下来。
例如在第 252 行的pytorch ImageNet教程中:
if not args.multiprocessing_distributed or (args.multiprocessing_distributed
and args.rank % ngpus_per_node == 0):
save_checkpoint({...})
Run Code Online (Sandbox Code Playgroud)
如果 ,他们会保存模型rank % ngpus_per_node == 0
。
据我所知,DistributedDataParallel() 会自动做所有减少后端的损失,不需要做任何进一步的工作,每个进程都可以基于此自动同步损失。每个流程上的所有模型只会在流程结束时略有不同。这意味着我们只需要保存一个模型就足够了。
那么为什么我们不只是将模型保存在 上rank == 0
,但是rank % ngpus_per_node == 0
呢?
如果我有多个模型,我应该使用哪个模型?
如果这是在分布式学习中保存模型的正确方法,我应该合并它们,使用其中之一,还是基于所有三个模型推断结果?
如果我错了,请告诉我。
如果我在任何地方错了,请纠正我
您所指的更改是2018
通过此提交引入的,并描述为:
在多处理模式下,只有一个进程会写入检查点
以前,这些都是在没有任何if
块的情况下保存的,因此每个 GPU 上的每个节点都会保存一个模型,这确实很浪费,并且很可能会在每个节点上多次覆盖保存的模型。
现在,我们谈论的是分布式多处理(可能有许多工人,每个工人可能有多个 GPU)。
args.rank
对于每个进程,因此通过以下行在脚本内修改:
args.rank = args.rank * ngpus_per_node + gpu
Run Code Online (Sandbox Code Playgroud)
其中有以下评论:
对于多进程分布式训练,rank需要是所有进程中的全局rank
因此args.rank
在所有节点的所有 GPU 中是唯一的 ID(或者看起来是这样)。
如果是这样,并且每个节点都有ngpus_per_node
(在此训练代码中,假设每个节点都具有与我收集的相同数量的 GPU),那么模型仅保存在每个节点上的一个(最后一个)GPU 上。在您使用3
机器和4
GPU 的示例中,您将获得3
保存的模型(希望我正确理解此代码,因为它非常复杂)。
如果您每个世界rank==0
只使用一个模型(其中世界将被定义为)将被保存。n_gpus * n_nodes
那么为什么我们不只是将模型保存在 rank == 0 上,而是保存 rank % ngpus_per_node == 0 呢?
我将从你的假设开始,即:
据我所知,DistributedDataParallel() 会自动做所有减少后端的损失,不需要做任何进一步的工作,每个进程都可以基于此自动同步损失。
准确地说,它与损失无关,而是gradient
根据文档(重点是我的)积累和对权重进行修正:
该容器通过在批处理维度中分块将输入拆分到指定的设备,从而并行化给定模块的应用程序 。该模块在每台机器和每台设备上复制,每个这样的副本处理输入的一部分。在向后传递期间,每个节点的梯度被平均。
因此,当使用一些权重创建模型时,它会在所有设备(每个节点的每个 GPU)上复制。现在每个 GPU 获得一部分输入(例如,对于总批次大小等于1024
,4
每个节点都有4
GPU,每个 GPU 将获得64
元素),计算前向传递,损失,通过.backward()
张量方法执行反向传播。现在所有的梯度都通过 all-gather 进行平均,在root
机器上优化参数,并将参数分发到所有节点,因此模块的状态在所有机器上始终相同。
注意:我不确定这种平均是如何发生的(我没有在文档中明确说明),但我假设这些是首先在 GPU 上平均,然后在所有节点上平均,因为我认为这将是最有效的.
现在,为什么要node
在这种情况下为每个保存模型?原则上您只能保存一个(因为所有模块都完全相同),但它有一些缺点:
如果我有多个模型,我应该使用哪个模型?
没关系,因为所有这些都将完全相同,因为通过优化器将相同的校正应用于具有相同初始权重的模型。
您可以使用这些方法来加载保存的.pth
模型:
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
parallel_model = torch.nn.DataParallel(MyModelGoesHere())
parallel_model.load_state_dict(
torch.load("my_saved_model_state_dict.pth", map_location=str(device))
)
# DataParallel has model as an attribute
usable_model = parallel_model.model
Run Code Online (Sandbox Code Playgroud)
归档时间: |
|
查看次数: |
3853 次 |
最近记录: |