til*_*151 4 python gpu autoencoder deep-learning pytorch
我正在 PyTorch 中构建变分自动编码器 (VAE),但在编写与设备无关的代码时遇到问题。Autoencoder 是nn.Module编码器和解码器网络的子代,它们也是。网络的所有权重都可以通过调用从一个设备移动到另一个设备net.to(device)。
我遇到的问题是重新参数化技巧:
encoding = mu + noise * sigma
Run Code Online (Sandbox Code Playgroud)
噪声是一个mu与sigma和大小相同的张量,并保存为自动编码器模块的成员变量。它在构造函数中初始化,并在每个训练步骤就地重新采样。我这样做是为了避免每一步构建一个新的噪声张量并将其推送到所需的设备。此外,我想修复评估中的噪音。这是代码:
class VariationalGenerator(nn.Module):
def __init__(self, input_nc, output_nc):
super(VariationalGenerator, self).__init__()
self.input_nc = input_nc
self.output_nc = output_nc
embedding_size = 128
self._train_noise = torch.randn(batch_size, embedding_size)
self._eval_noise = torch.randn(1, embedding_size)
self.noise = self._train_noise
# Create encoder
self.encoder = Encoder(input_nc, embedding_size)
# Create decoder
self.decoder = Decoder(output_nc, embedding_size)
def train(self, mode=True):
super(VariationalGenerator, self).train(mode)
self.noise = self._train_noise
def eval(self):
super(VariationalGenerator, self).eval()
self.noise = self._eval_noise
def forward(self, inputs):
# Calculate parameters of embedding space
mu, log_sigma = self.encoder.forward(inputs)
# Resample noise if training
if self.training:
self.noise.normal_()
# Reparametrize noise to embedding space
inputs = mu + self.noise * torch.exp(0.5 * log_sigma)
# Decode to image
inputs = self.decoder(inputs)
return inputs, mu, log_sigma
Run Code Online (Sandbox Code Playgroud)
当我现在将自动编码器移动到 GPU 时,net.to('cuda:0')由于未移动噪声张量,因此在转发时出现错误。
我不想在构造函数中添加设备参数,因为以后仍然无法将其移动到另一个设备。我还尝试将噪声包装成 ,nn.Parameter以便它受到 的影响net.to(),但这会导致优化器出错,因为噪声被标记为requires_grad=False。
任何人都有移动所有模块的解决方案net.to()?
tilman151 的第二种方法的更好版本可能是覆盖_apply,而不是to。这样net.cuda(),net.float()等都可以正常工作,因为这些都调用_apply而不是to(如源代码所示,它比您想象的要简单):
def _apply(self, fn):
super(VariationalGenerator, self)._apply(fn)
self._train_noise = fn(self._train_noise)
self._eval_noise = fn(self._eval_noise)
return self
Run Code Online (Sandbox Code Playgroud)
经过更多的反复试验,我找到了两种方法:
self._train_noise = torch.randn(batch_size, embedding_size)与self.register_buffer('_train_noise', torch.randn(batch_size, embedding_size)张量被添加到模块作为缓冲的噪声。这net.to(device)也会影响它。此外,张量现在是 state_dict 的一部分。覆盖net.to(device):使用它,噪音不会出现在 state_dict 之外。
def to(device):
new_self = super(VariationalGenerator, self).to(device)
new_self._train_noise = new_self._train_noise.to(device)
new_self._eval_noise = new_self._eval_noise.to(device)
return new_self
Run Code Online (Sandbox Code Playgroud)| 归档时间: |
|
| 查看次数: |
3699 次 |
| 最近记录: |