How can I build a multi-task DNN with PyTorch, e.g. for more than 100 tasks?

use*_*010 4 regression classification deep-learning pytorch

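Below is sample code that uses PyTorch to build a DNN for two regression tasks. The forward function returns two outputs (x1, x2). What about a network for a large number of regression/classification tasks, e.g. 100 or 1000 outputs? Hard-coding every output (x1, x2, ..., x100) is definitely not a good idea. Is there a simple way to do this? Thanks.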

import torch
from torch import nn
import torch.nn.functional as F

class mynet(nn.Module):
    def __init__(self):
        super(mynet, self).__init__()
        self.lin1 = nn.Linear(5, 10)
        self.lin2 = nn.Linear(10, 3)
        self.lin3 = nn.Linear(10, 4)

    def forward(self, x):
        x = self.lin1(x)
        x1 = self.lin2(x)
        x2 = self.lin3(x)
        return x1, x2

if __name__ == '__main__':
    x = torch.randn(1000, 5)
    y1 = torch.randn(1000, 3)
    y2 = torch.randn(1000, 4)
    model = mynet()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
    for epoch in range(100):
        model.train()
        optimizer.zero_grad()
        out1, out2 = model(x)
        loss = 0.2 * F.mse_loss(out1, y1) + 0.8 * F.mse_loss(out2, y2)
        loss.backward()
        optimizer.step()

Sha*_*hai 5

You can (and should) use nn containers such as nn.ModuleList or nn.ModuleDict to manage an arbitrary number of sub-modules.

For example (using nn.ModuleList):

from torch import nn

class MultiHeadNetwork(nn.Module):
    def __init__(self, list_with_number_of_outputs_of_each_head):
        super(MultiHeadNetwork, self).__init__()
        self.backbone = ...  # build the basic "backbone" on top of which all other heads come
        # all other "heads"
        self.heads = nn.ModuleList([])
        for nout in list_with_number_of_outputs_of_each_head:
            self.heads.append(nn.Sequential(
              nn.Linear(10, nout * 2),
              nn.ReLU(inplace=True),
              nn.Linear(nout * 2, nout)))

    def forward(self, x):
        common_features = self.backbone(x)  # compute the shared features
        outputs = []
        for head in self.heads:
            outputs.append(head(common_features))
        return outputs

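Note that in this example each head is more complex than a single nn.Linear layer.
The number of "heads" (and the number of outputs of each) is determined by the argument list_with_number_of_outputs_of_each_head: its length gives the number of heads, and each entry gives that head's output size.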
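To make the idea concrete, here is a minimal runnable sketch of the same pattern applied to 100 regression tasks. The backbone (one linear layer plus ReLU), the head sizes, the equal per-task loss weights, and the parameter names in_features, hidden, and head_sizes are all assumptions made for illustration; the answer above leaves the backbone unspecified.

```python
import torch
from torch import nn
import torch.nn.functional as F


class MultiHeadNetwork(nn.Module):
    def __init__(self, in_features, hidden, head_sizes):
        super().__init__()
        # Hypothetical backbone: the original answer does not specify one.
        self.backbone = nn.Sequential(nn.Linear(in_features, hidden), nn.ReLU())
        # One small head per task, registered through nn.ModuleList.
        self.heads = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden, nout * 2),
                nn.ReLU(),
                nn.Linear(nout * 2, nout),
            )
            for nout in head_sizes
        ])

    def forward(self, x):
        features = self.backbone(x)  # shared features
        return [head(features) for head in self.heads]  # one output per task


torch.manual_seed(0)
head_sizes = [3, 4] * 50  # 100 tasks with made-up output sizes
x = torch.randn(1000, 5)
targets = [torch.randn(1000, n) for n in head_sizes]
weights = [1.0 / len(head_sizes)] * len(head_sizes)  # equal weights (assumption)

model = MultiHeadNetwork(in_features=5, hidden=10, head_sizes=head_sizes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)

model.train()
for epoch in range(5):
    optimizer.zero_grad()
    outputs = model(x)
    # Weighted sum of per-task losses, replacing the hard-coded two-term loss.
    loss = sum(w * F.mse_loss(out, y)
               for w, out, y in zip(weights, outputs, targets))
    loss.backward()
    optimizer.step()
```

The training loop never names an individual task: outputs, targets, and weights are parallel lists, so changing the number of tasks only means changing head_sizes.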


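Important: it is crucial to store the sub-modules in an nn container rather than a plain Python list or dict. Modules kept in a plain list or dict are not registered with the parent module, so their parameters do not show up in model.parameters() (and are therefore never passed to the optimizer), are not moved by model.to(device), and are not included in state_dict().
For example, see this answer, this question, and this one.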
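The difference can be checked directly by counting registered parameters. The sketch below compares a plain Python list with nn.ModuleList and nn.ModuleDict; the class names and the task keys "task_a" and "task_b" are hypothetical and exist only for this illustration.

```python
import torch
from torch import nn


class PlainListHeads(nn.Module):
    def __init__(self):
        super().__init__()
        # Plain Python list: these layers are NOT registered as sub-modules.
        self.heads = [nn.Linear(10, 3), nn.Linear(10, 4)]


class ModuleListHeads(nn.Module):
    def __init__(self):
        super().__init__()
        # nn.ModuleList registers each layer with the parent module.
        self.heads = nn.ModuleList([nn.Linear(10, 3), nn.Linear(10, 4)])


class ModuleDictHeads(nn.Module):
    def __init__(self):
        super().__init__()
        # nn.ModuleDict does the same, with heads addressed by name.
        self.heads = nn.ModuleDict({
            "task_a": nn.Linear(10, 3),
            "task_b": nn.Linear(10, 4),
        })


# Each nn.Linear contributes two parameter tensors (weight and bias).
n_plain = len(list(PlainListHeads().parameters()))
n_list = len(list(ModuleListHeads().parameters()))
n_dict = len(list(ModuleDictHeads().parameters()))
print(n_plain, n_list, n_dict)  # 0 4 4
```

With the plain list, an optimizer built from model.parameters() would receive nothing for those heads, so they would never be trained.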