Pytorch Python 分布式多重处理:收集/连接不同长度/大小的张量数组

oms*_*gar 2 python distributed concatenation multiprocessing pytorch

如果多个 GPU 级别上有不同长度的张量数组,则默认all_gather方法不起作用,因为它要求长度相同。

例如,如果您有:

if gpu == 0:
    q = torch.tensor([1.5, 2.3], device=torch.device(gpu))
else:
    q = torch.tensor([5.3], device=torch.device(gpu))
Run Code Online (Sandbox Code Playgroud)

如果我需要收集这两个张量数组,如下所示:

all_q = [torch.tensor([1.5, 2.3], torch.tensor[5.3])

默认torch.all_gather不起作用,因为长度2, 1不同。

oms*_*gar 5

由于无法直接使用内置方法进行收集,因此我们需要按照以下步骤编写自定义函数:

  1. 用于dist.all_gather获取所有数组的大小。
  2. 找到最大尺寸。
  3. 使用零/常量将本地数组填充到最大大小。
  4. 用于dist.all_gather获取所有填充数组。
  5. 使用步骤 1 中找到的大小取消填充添加的零/常数。

下面的函数执行此操作:

def all_gather(q, ws, device):
    """
    Gathers tensor arrays of different lengths across multiple gpus
    
    Parameters
    ----------
        q : tensor array
        ws : world size
        device : current gpu device
        
    Returns
    -------
        all_q : list of gathered tensor arrays from all the gpus

    """
    local_size = torch.tensor(q.size(), device=device)
    all_sizes = [torch.zeros_like(local_size) for _ in range(ws)]
    dist.all_gather(all_sizes, local_size)
    max_size = max(all_sizes)

    size_diff = max_size.item() - local_size.item()
    if size_diff:
        padding = torch.zeros(size_diff, device=device, dtype=q.dtype)
        q = torch.cat((q, padding))

    all_qs_padded = [torch.zeros_like(q) for _ in range(ws)]
    dist.all_gather(all_qs_padded, q)
    all_qs = []
    for q, size in zip(all_qs_padded, all_sizes):
        all_qs.append(q[:size])
    return all_qs
Run Code Online (Sandbox Code Playgroud)

一旦我们能够执行上述操作,我们就可以根据torch.cat需要轻松地进一步连接成单个数组:

torch.cat(all_q)
[torch.tensor([1.5, 2.3, 5.3])
Run Code Online (Sandbox Code Playgroud)

改编自:github