Jiv*_*van 2 python gpu pytorch pytorch-lightning
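Everything works fine when I use PyTorch Lightning on the CPU. On a GPU, however, I get RuntimeError: Expected all tensors to be on the same device.
The problem seems to come from a model that holds lists of sub-models, which are not moved to the GPU: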
import torch
from pytorch_lightning import LightningModule, Trainer

class LambdaLayer(LightningModule):
    def __init__(self, fun):
        super(LambdaLayer, self).__init__()
        self.fun = fun

    def forward(self, x):
        return self.fun(x)

class TorchModel(LightningModule):
    def __init__(self):
        super(TorchModel, self).__init__()
        # Sub-models stored in plain Python lists
        self.cat_layers = [TorchCatEmbedding(cat) for cat in columns_to_embed]
        self.num_layers = [LambdaLayer(lambda x: x[:, idx:idx+1]) for _, idx in numeric_columns]
        self.ffo = TorchFFO(len(self.num_layers) + sum([embed_dim(l) for l in self.cat_layers]), y.shape[1])
        self.softmax = torch.nn.Softmax(dim=1)

model = TorchModel()
trainer = Trainer(gpus=-1)
Before running trainer(model):
>>> model.device
device(type='cpu')
>>> model.ffo.device
device(type='cpu')
>>> model.cat_layers[0].device
device(type='cpu')
After running trainer(model):
>>> model.device
device(type='cuda', index=0) # <---- correct
>>> model.ffo.device
device(type='cuda', index=0) # <---- correct
>>> model.cat_layers[0].device
device(type='cpu') # <---- still showing 'cpu'
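Evidently PyTorch Lightning does not transfer the lists of sub-models to the GPU. How can I move the whole model to the GPU, including the lists of sub-models (cat_layers and num_layers)?
Submodules held in a plain Python list are not registered with the parent module, so they are not converted along with it. Use ModuleList instead, i.e.: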
...
from torch.nn import ModuleList
...
class TorchModel(LightningModule):
    def __init__(self):
        super(TorchModel, self).__init__()
        # ModuleList registers every element as a child module
        self.cat_layers = ModuleList([TorchCatEmbedding(cat) for cat in columns_to_embed])
        # idx=idx binds the column index when each lambda is created;
        # a bare closure would see only the last idx of the comprehension
        self.num_layers = ModuleList([LambdaLayer(lambda x, idx=idx: x[:, idx:idx+1]) for _, idx in numeric_columns])
        self.ffo = TorchFFO(len(self.num_layers) + sum([embed_dim(l) for l in self.cat_layers]), y.shape[1])
        self.softmax = torch.nn.Softmax(dim=1)
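To see the registration difference without a GPU, here is a minimal sketch using plain torch.nn modules. The class names PlainListModel and ModuleListModel and the layer sizes are made up for illustration and are not part of the question's code. A dtype conversion stands in for the device move, on the assumption that both walk the registered children in the same way and that the default dtype is float32.

```python
import torch
from torch import nn


class PlainListModel(nn.Module):
    """Hypothetical model keeping sub-layers in a plain Python list."""

    def __init__(self):
        super().__init__()
        self.layers = [nn.Linear(4, 4) for _ in range(3)]  # not registered
        self.head = nn.Linear(4, 1)                        # registered


class ModuleListModel(nn.Module):
    """Same layout, but the sub-layers live in an nn.ModuleList."""

    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(4, 4) for _ in range(3)])
        self.head = nn.Linear(4, 1)


def count_parameter_tensors(module):
    """Number of parameter tensors the module reports as its own."""
    return len(list(module.parameters()))


plain = PlainListModel()
registered = ModuleListModel()

# Plain list: only `head` is visible (weight + bias).
n_plain = count_parameter_tensors(plain)
# ModuleList: three Linear layers (weight + bias each) plus `head`.
n_registered = count_parameter_tensors(registered)

# Convert both models; only registered children are touched.
plain.to(torch.float64)
registered.to(torch.float64)
plain_layer_dtype = plain.layers[0].weight.dtype
registered_layer_dtype = registered.layers[0].weight.dtype

print(n_plain, n_registered)
print(plain_layer_dtype, registered_layer_dtype)
```

The layers in the plain list keep their original dtype after the conversion, which mirrors cat_layers[0] staying on 'cpu' in the question. They are also missing from parameters(), so an optimizer built from model.parameters() would not train them either.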
Edit: I'm not sure what the Lightning equivalent is, or whether one exists. See also: PyTorch Lightning - LightningModule for ModuleList / ModuleDict?
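For what it's worth, LightningModule subclasses torch.nn.Module, so the torch.nn containers should register their contents inside it just as they do in a plain module. Below is a rough sketch of the ModuleDict variant mentioned in the linked question, written with torch.nn only. The class CatEmbeddings, the column names "color" and "size", and the embedding sizes are all invented for the example.

```python
import torch
from torch import nn


class CatEmbeddings(nn.Module):
    """Hypothetical container: one embedding per categorical column."""

    def __init__(self, cardinalities, embed_dim=3):
        super().__init__()
        # ModuleDict registers each value under its string key.
        self.embeddings = nn.ModuleDict(
            {name: nn.Embedding(n, embed_dim) for name, n in cardinalities.items()}
        )

    def forward(self, batch):
        # Look up each column and concatenate along the feature axis.
        parts = [self.embeddings[name](batch[name]) for name in self.embeddings]
        return torch.cat(parts, dim=1)


model = CatEmbeddings({"color": 5, "size": 3})
batch = {"color": torch.tensor([0, 4]), "size": torch.tensor([2, 1])}
out = model(batch)

# Registered parameters carry the dictionary keys in their names.
registered_names = [name for name, _ in model.named_parameters()]
print(tuple(out.shape), registered_names)
```

Keying the sub-models by column name can be easier to read than indexing into a list when each one belongs to a specific column, and the entries are moved and saved with the parent in the same way.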