Wei*_*han 4 deep-learning pytorch
我想做的是在自定义RNN类中使用DataParallel。
似乎我以错误的方式初始化了hidden_0。
class RNN(nn.Module):
def __init__(self, input_size, hidden_size, output_size, n_layers=1):
super(RNN, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.output_size = output_size
self.n_layers = n_layers
self.encoder = nn.Embedding(input_size, hidden_size)
self.gru = nn.GRU(hidden_size, hidden_size, n_layers,batch_first = True)
self.decoder = nn.Linear(hidden_size, output_size)
self.init_hidden(batch_size)
def forward(self, input):
input = self.encoder(input)
output, self.hidden = self.gru(input,self.hidden)
output = self.decoder(output.contiguous().view(-1,self.hidden_size))
output = output.contiguous().view(batch_size,num_steps,N_CHARACTERS)
#print (output.size())10,50,67
return output
def init_hidden(self,batch_size):
self.hidden = Variable(T.zeros(self.n_layers, batch_size, self.hidden_size).cuda())
Run Code Online (Sandbox Code Playgroud)
我以这种方式称呼网络:
decoder = T.nn.DataParallel(RNN(N_CHARACTERS, HIDDEN_SIZE, N_CHARACTERS), dim=1).cuda()
Run Code Online (Sandbox Code Playgroud)
然后开始训练:
for epoch in range(EPOCH_):
hidden = decoder.init_hidden()
Run Code Online (Sandbox Code Playgroud)
但是我得到了错误,并且我不知道如何解决它……
“ DataParallel”对象没有属性“ init_hidden”
谢谢你的帮助!
我所做的一个解决方法是:
self.model = model
# Since if the model is wrapped by the `DataParallel` class, you won't be able to access its attributes
# unless you write `model.module` which breaks the code compatibility. We use `model_attr_accessor` for attributes
# accessing only.
if isinstance(model, DataParallel):
self.model_attr_accessor = model.module
else:
self.model_attr_accessor = model
Run Code Online (Sandbox Code Playgroud)
这给了我这样的优势:当我这样做时self.model(input)
(即,当它被 包裹时DataParallel
),模型可以分布在我的 GPU 上;当我需要访问它的属性时,我就这样做self.model_attr_accessor.<<WHATEVER>>
。此外,这种设计为我提供了一种更模块化的方式来访问多个函数的属性,而无需if-statements
在所有函数中检查它是否被包装DataParallel
。
另一方面,如果您已经编写model.module.<<WHATEVER>>
并且模型没有被 包装DataParallel
,这将引发一个错误,指出您的模型没有module
属性。
然而,更紧凑的实现是创建一个DataParallel
像这样的自定义:
class _CustomDataParallel(nn.Module):
def __init__(self, model):
super(_CustomDataParallel, self).__init__()
self.model = nn.DataParallel(model).cuda()
def forward(self, *input):
return self.model(*input)
def __getattr__(self, name):
try:
return super().__getattr__(name)
except AttributeError:
return getattr(self.model.module, name)
Run Code Online (Sandbox Code Playgroud)
使用DataParallel
原始模块时,将在module
并行模块的属性中:
for epoch in range(EPOCH_):
hidden = decoder.module.init_hidden()
Run Code Online (Sandbox Code Playgroud)
归档时间: |
|
查看次数: |
1444 次 |
最近记录: |