在 PyTorch 中使用分布式数据并行 (DDP) 时,训练期间检查点的正确方法是什么?

Cha*_*ker 6 python distributed-computing neural-network deep-learning pytorch

我想要(正确且官方的无错误方式)执行以下操作:

\n
    \n
  1. 从检查点恢复以继续在多个 GPU 上进行训练
  2. \n
  3. 在使用多个 GPU 进行训练期间正确保存检查点
  4. \n
\n

为此,我的猜测如下:

\n
    \n
  1. 为了执行第 1 步,我们让所有进程从文件中加载检查点,然后调用DDP(mdl)每个进程。我假设检查点保存了一个ddp_mdl.module.state_dict().
  2. \n
  3. 要做2,只需检查谁的rank = 0并让其执行torch.save({\'model\': ddp_mdl.module.state_dict()})
  4. \n
\n

大概代码:

\n
def save_ckpt(rank, ddp_model, path):\n    if rank == 0:\n        state = {\'model\': ddp_model.module.state_dict(),\n             \'optimizer\': optimizer.state_dict(),\n            }\n        torch.save(state, path)\n\ndef load_ckpt(path, distributed, map_location=map_location=torch.device(\'cpu\')):\n    # loads to\n    checkpoint = torch.load(path, map_location=map_location)\n    model = Net(...)\n    optimizer = ...\n    model.load_state_dict(checkpoint[\'model\'])\n    optimizer.load_state_dict(checkpoint[\'optimizer\'])\n    if distributed:\n        model = DDP(model, device_ids=[gpu], find_unused_parameters=True)\n    return model\n
Run Code Online (Sandbox Code Playgroud)\n

它是否正确?

\n
\n

我问的原因之一是分布式代码可能会出现微妙的错误。我想确保这不会发生在我身上。当然,我想避免死锁,但如果它发生在我身上,那就很明显了(例如,如果所有进程以某种方式尝试同时打开同一个 ckpt 文件,则可能会发生这种情况。在这种情况下,我会以某种方式确保一次只有一个进程加载它,或者排名 0 只加载它,然后将其发送到其余进程)。

\n

我也问,因为官方文档对我来说没有意义。我将粘贴他们的代码和解释,因为链接有时会失效:

\n
\n

保存和加载检查点\n在训练期间使用 torch.save 和 torch.load 来检查点模块并从检查点恢复时,通常使用 torch.save 和 torch.load 。有关更多详细信息,请参阅保存和加载模型。使用DDP时,一种优化是仅将模型保存在一个进程中,然后将其加载到所有进程中,从而减少写入开销。这是正确的,因为所有过程都从相同的参数开始,并且梯度在向后传递中同步,因此优化器应该保持将参数设置为相同的值。如果您使用此优化,请确保在保存完成之前所有进程都不会开始加载。此外,在加载模块时,您需要提供适当的map_location参数以防止进程进入其他\xe2\x80\x99设备。如果缺少map_location,torch.load将首先将模块加载到CPU,然后将每个参数复制到保存它的位置,这将导致同一台机器上的所有进程使用同一组设备。有关更高级的故障恢复和弹性支持,请参阅 TorchElastic。

\n
\n
def demo_checkpoint(rank, world_size):\n    print(f"Running DDP checkpoint example on rank {rank}.")\n    setup(rank, world_size)\n\n    model = ToyModel().to(rank)\n    ddp_model = DDP(model, device_ids=[rank])\n\n    loss_fn = nn.MSELoss()\n    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)\n\n    CHECKPOINT_PATH = tempfile.gettempdir() + "/model.checkpoint"\n    if rank == 0:\n        # All processes should see same parameters as they all start from same\n        # random parameters and gradients are synchronized in backward passes.\n        # Therefore, saving it in one process is sufficient.\n        torch.save(ddp_model.state_dict(), CHECKPOINT_PATH)\n\n    # Use a barrier() to make sure that process 1 loads the model after process\n    # 0 saves it.\n    dist.barrier()\n    # configure map_location properly\n    map_location = {\'cuda:%d\' % 0: \'cuda:%d\' % rank}\n    ddp_model.load_state_dict(\n        torch.load(CHECKPOINT_PATH, map_location=map_location))\n\n    optimizer.zero_grad()\n    outputs = ddp_model(torch.randn(20, 10))\n    labels = torch.randn(20, 5).to(rank)\n    loss_fn = nn.MSELoss()\n    loss_fn(outputs, labels).backward()\n    optimizer.step()\n\n    # Not necessary to use a dist.barrier() to guard the file deletion below\n    # as the AllReduce ops in the backward pass of DDP already served as\n    # a synchronization.\n\n    if rank == 0:\n        os.remove(CHECKPOINT_PATH)\n\n    cleanup()\n
Run Code Online (Sandbox Code Playgroud)\n
\n

有关的:

\n\n

Aer*_*ysS 4

我正在查看官方的ImageNet 示例,以下是他们的做法。首先,他们以DDP 模式创建模型:

model = ResNet50(...)
model = DDP(model,...)
Run Code Online (Sandbox Code Playgroud)

保存检查点,他们检查它是否是主进程,然后保存state_dict

import torch.distributed as dist

if dist.get_rank() == 0:  # check if main process, a simpler way compared to the link
    torch.save({'state_dict': model.state_dict(), ...},
                '/path/to/checkpoint.pth.tar')
Run Code Online (Sandbox Code Playgroud)

在加载过程中,他们像往常一样加载模型并将其置于 DDP 模式,而不需要检查排名:

checkpoint = torch.load('/path/to/checkpoint.pth.tar')
model = ResNet50(...).load_state_dict(checkpoint['state_dict'])
model = DDP(...)
Run Code Online (Sandbox Code Playgroud)

如果你想加载它但不是在 DDP 模式下,这有点棘手,因为由于某些原因他们用额外的后缀保存它module。正如此处解决的那样,您必须执行以下操作:

state_dict = torch.load(checkpoint['state_dict'])
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:] # remove 'module.' of DataParallel/DistributedDataParallel
    new_state_dict[name] = v

model.load_state_dict(new_state_dict)
Run Code Online (Sandbox Code Playgroud)