PyTorch Lightning: list of submodels is not transferred to the GPU

Jiv*_*van 2 python gpu pytorch pytorch-lightning

With PyTorch Lightning on the CPU, everything works fine. On the GPU, however, I get a RuntimeError: Expected all tensors to be on the same device.

The problem seems to come from the model using lists of submodels that are not moved to the GPU:

class LambdaLayer(LightningModule):
    def __init__(self, fun):
        super(LambdaLayer, self).__init__()
        self.fun = fun

    def forward(self, x):
        return self.fun(x)

class TorchModel(LightningModule):
    def __init__(self):
        super(TorchModel, self).__init__()
        self.cat_layers = [TorchCatEmbedding(cat) for cat in columns_to_embed]
        self.num_layers = [LambdaLayer(lambda x: x[:, idx:idx+1]) for _, idx in numeric_columns]
        self.ffo = TorchFFO(len(self.num_layers) + sum([embed_dim(l) for l in self.cat_layers]), y.shape[1])
        self.softmax = torch.nn.Softmax(dim=1)

model = TorchModel()
trainer = Trainer(gpus=-1)

Before running trainer(model):

>>> model.device
device(type='cpu')

>>> model.ffo.device
device(type='cpu')

>>> model.cat_layers[0].device
device(type='cpu')

After running trainer(model):

>>> model.device
device(type='cuda', index=0) # <---- correct

>>> model.ffo.device
device(type='cuda', index=0) # <---- correct

>>> model.cat_layers[0].device
device(type='cpu') # <---- still showing 'cpu'

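Apparently, PyTorch Lightning cannot transfer a list of submodels to the GPU. How can I get the whole model, including the lists of submodels (cat_layers, num_layers), transferred to the GPU?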

Kon*_*kos 6

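Submodules held in a plain Python list are not registered with the parent module, so they cannot be moved as is. You need to use ModuleList instead, i.e.: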

...
from torch.nn import ModuleList
...

class TorchModel(LightningModule):
    def __init__(self):
        super(TorchModel, self).__init__()
        self.cat_layers = ModuleList([TorchCatEmbedding(cat) for cat in columns_to_embed])
        # Bind idx as a default argument so each lambda keeps its own column index.
        self.num_layers = ModuleList([LambdaLayer(lambda x, idx=idx: x[:, idx:idx+1]) for _, idx in numeric_columns])
        self.ffo = TorchFFO(len(self.num_layers) + sum([embed_dim(l) for l in self.cat_layers]), y.shape[1])
        self.softmax = torch.nn.Softmax(dim=1)
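The registration difference can be checked without a GPU. The following is a minimal sketch, assuming plain `torch.nn.Module` classes stand in for the `LightningModule` subclasses above (the class names `PlainListModel` and `ModuleListModel` are hypothetical). It uses `.double()` as a CPU-only stand-in for a device transfer, since both operations only reach registered submodules.

```python
import torch
from torch import nn


class PlainListModel(nn.Module):
    """Holds its sub-layers in a plain Python list (not registered)."""

    def __init__(self):
        super().__init__()
        self.layers = [nn.Linear(4, 4) for _ in range(2)]
        self.head = nn.Linear(4, 1)


class ModuleListModel(nn.Module):
    """Holds the same sub-layers in an nn.ModuleList (registered)."""

    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(4, 4) for _ in range(2)])
        self.head = nn.Linear(4, 1)


plain = PlainListModel()
registered = ModuleListModel()

# Only registered submodules contribute to parameters().
plain_param_count = len(list(plain.parameters()))            # head only
registered_param_count = len(list(registered.parameters()))  # head + 2 layers

# .double() is applied to registered submodules only, like .to(device).
plain.double()
registered.double()
plain_layer_dtype = plain.layers[0].weight.dtype
registered_layer_dtype = registered.layers[0].weight.dtype

print(plain_param_count, registered_param_count)
print(plain_layer_dtype, registered_layer_dtype)
```

In the plain-list version the layers inside the list keep their original dtype, which mirrors the `cpu` device reported for `model.cat_layers[0]` in the question. The same layers would also be missing from the optimizer if it were built from `model.parameters()`.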

Edit: I'm not sure what the Lightning equivalent is, or whether one exists. See also PyTorch Lightning - LightningModule for ModuleList / ModuleDict?
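Since `LightningModule` subclasses `torch.nn.Module`, the containers from `torch.nn` should work inside it directly. As a rough sketch of the `ModuleDict` variant, assuming submodels keyed by column name (the `EmbeddingBank` class and the column names `city` and `product` are hypothetical, not taken from the question):

```python
import torch
from torch import nn


class EmbeddingBank(nn.Module):
    """One embedding per categorical column, keyed by column name."""

    def __init__(self, cardinalities, dim=3):
        super().__init__()
        # nn.ModuleDict registers every value as a submodule.
        self.embeddings = nn.ModuleDict(
            {name: nn.Embedding(n, dim) for name, n in cardinalities.items()}
        )

    def forward(self, batch):
        # batch maps column name -> LongTensor of category indices.
        parts = [self.embeddings[name](batch[name]) for name in self.embeddings]
        return torch.cat(parts, dim=1)


# Hypothetical columns with 5 and 7 distinct categories.
bank = EmbeddingBank({"city": 5, "product": 7})
batch = {"city": torch.tensor([0, 4]), "product": torch.tensor([6, 1])}
out = bank(batch)

# Every embedding shows up in the module tree, so .to(device) would reach it.
registered_names = [name for name, _ in bank.named_modules() if name]
print(out.shape, registered_names)
```

A plain `dict` of modules would have the same problem as a plain `list`: its values would not appear in `named_modules()` and would stay on the CPU.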