7 python
我是 PyTorch、GAN 的新手,并且我在 Python 方面没有太多经验(尽管我是 C/C++ 程序员)。
我有一个简单的 DCGAN 教程代码,用于生成假图像。当我使用 “DATASET_NAME = 'MNIST'” 运行代码时,一切正常。但是,当我将数据集更改为 'CIFAR10' 时,程序会产生与 “running_mean” 相关的错误。
\n代码如下
\nimport torch.nn as nn\n\ndef weights_init(module):\n\n if isinstance(module, nn.Conv2d) or isinstance(module, nn.ConvTranspose2d):\n module.weight.detach().normal_(mean=0., std=0.02)\n elif isinstance(module, nn.BatchNorm2d):\n module.weight.detach().normal_(1., 0.02)\n module.bias.detach().zero_()\n else:\n pass\n\nclass View(nn.Module):\n\n def __init__(self, output_shape):\n super(View, self).__init__()\n self.output_shape = output_shape\n\n def forward(self, x):\n return x.view(x.shape[0], *self.output_shape)\n\n\nclass Generator(nn.Module):\n\n def __init__(self, dataset_name):\n super(Generator, self).__init__()\n act = nn.ReLU(inplace=True)\n norm = nn.BatchNorm2d\n\n if dataset_name == 'CIFAR10': # Output shape 3x32x32\n model = [nn.Linear(100, 512 * 4 * 4), View([512, 4, 4]), norm(512), act] # 4x4\n model += [nn.ConvTranspose2d(512, 256, 5, stride=2, padding=2, output_padding=1), norm(256), act] # 8x8\n model += [nn.ConvTranspose2d(256, 128, 5, stride=2, padding=2, output_padding=1), norm(128), act] # 16x16\n model += [nn.ConvTranspose2d(128, 3, 5, stride=2, padding=2, output_padding=1), nn.Tanh()] # 32x32\n\n elif dataset_name == 'LSUN': # Output shape 3x64x64\n model = [nn.Linear(100, 1024 * 4 * 4), View([1024, 4, 4]), norm(1024), act] # 4x4\n model += [nn.ConvTranspose2d(1024, 512, 5, stride=2, padding=2, output_padding=1), norm(512), act] # 8x8\n model += [nn.ConvTranspose2d(512, 256, 5, stride=2, padding=2, output_padding=1), norm(256), act] # 16x16\n model += [nn.ConvTranspose2d(256, 128, 5, stride=2, padding=2, output_padding=1), norm(128), act] # 32x32\n model += [nn.ConvTranspose2d(128, 3, 5, stride=2, padding=2, output_padding=1), nn.Tanh()] # 64x64\n\n elif dataset_name == 'MNIST': # Output shape 1x28x28\n model = [nn.Linear(100, 256 * 4 * 4), View([256, 4, 4]), norm(256), act] # 4x4\n model += [nn.ConvTranspose2d(256, 128, 5, stride=2, padding=2), norm(128), act] # 7x7\n model += [nn.ConvTranspose2d(128, 64, 5, stride=2, padding=2, 
output_padding=1), norm(64), act] # 14x14\n model += [nn.ConvTranspose2d(64, 1, 5, stride=2, padding=2, output_padding=1), nn.Tanh()] # 28x28\n\n else:\n raise NotImplementedError\n\n self.model = nn.Sequential(*model)\n\n def forward(self, x):\n return self.model(x)\n\n\nclass Discriminator(nn.Module):\n\n def __init__(self, dataset_name):\n super(Discriminator, self).__init__()\n act = nn.LeakyReLU(inplace=True, negative_slope=0.2)\n norm = nn.BatchNorm2d\n\n if dataset_name == 'CIFAR10': # Input shape 3x32x32\n model = [nn.Conv2d(3, 128, 5, stride=2, padding=2, bias=False), act] # 16x16\n model += [nn.Conv2d(128, 256, 5, stride=2, padding=2, bias=False), norm(128), act] # 8x8\n model += [nn.Conv2d(256, 512, 5, stride=2, padding=2, bias=False), norm(256), act] # 4x4\n model += [nn.Conv2d(512, 1, 4, stride=2, padding=2, bias=False), nn.Sigmoid()] # 1x1\n\n elif dataset_name == 'LSUN': # Input shape 3x64x64\n model = [nn.Conv2d(3, 128, 5, stride=2, padding=2, bias=False), act] # 128x32x32\n model += [nn.Conv2d(128, 256, 5, stride=2, padding=2, bias=False), norm(128), act] # 256x16x16\n model += [nn.Conv2d(256, 512, 5, stride=2, padding=2, bias=False), norm(256), act] # 512x8x8\n model += [nn.Conv2d(512, 1024, 5, stride=2, padding=2, bias=False), norm(512), act] # 1024x4x4\n model += [nn.Conv2d(1024, 1, 4), nn.Sigmoid()] # 1x1x1\n\n elif dataset_name == 'MNIST': # Input shape 1x28x28\n model = [nn.Conv2d(1, 64, 5, stride=2, padding=2, bias=False), act] # 14x14\n model += [nn.Conv2d(64, 128, 5, stride=2, padding=2, bias=False), norm(128), act] # 7x7\n model += [nn.Conv2d(128, 256, 5, stride=2, padding=2, bias=False), norm(256), act] # 4x4\n model += [nn.Conv2d(256, 1, 4, bias=False), nn.Sigmoid()] # 1x1\n\n else:\n raise NotImplementedError\n\n self.model = nn.Sequential(*model)\n\n def forward(self, x):\n return self.model(x)\n\n\nif __name__ == '__main__':\n\n import os\n from torchvision.transforms import Compose, Normalize, Resize, ToTensor\n from 
torch.utils.data import DataLoader\n #from models import Discriminator, Generator, weights_init\n import torch\n import torch.nn as nn\n import matplotlib.pyplot as plt\n from time import time\n from tqdm import tqdm\n from torchvision.utils import save_image\n os.environ['CUDA_VISIBLE_DEVICES'] = '0'\n \n\n BETA1, BETA2 = 0.5, 0.99\n BATCH_SIZE = 16\n DATASET_NAME = 'CIFAR10'\n DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu:0')\n EPOCHS = 1\n ITER_REPORT = 10\n LATENT_DIM = 100\n LR = 2e-4\n N_D_STEP = 1\n ITER_DISPLAY = 500\n\n IMAGE_DIR = './GAN/checkpoints/'+DATASET_NAME+'/Image'\n MODEL_DIR = './GAN/checkpoints/'+DATASET_NAME+'/Model'\n \n if DATASET_NAME == 'CIFAR10':\n IMAGE_SIZE = 32\n OUT_CHANNEL = 3\n from torchvision.datasets import CIFAR10\n transforms = Compose([ToTensor(), Normalize(mean=[0.5], std=[0.5])])\n dataset = CIFAR10(root='./datasets', train=True, transform=transforms, download=True)\n elif DATASET_NAME == 'LSUN':\n IMAGE_SIZE = 64\n OUT_CHANNEL = 3\n from torchvision.datasets import LSUN\n transforms = Compose([Resize(64), ToTensor(), Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])\n dataset = LSUN(root='./datasets/LSUN', classes=['bedroom_train'], transform=transforms)\n elif DATASET_NAME == 'MNIST':\n IMAGE_SIZE = 28\n OUT_CHANNEL = 1\n from torchvision.datasets import MNIST\n transforms = Compose([ToTensor(), Normalize(mean=[0.5], std=[0.5])])\n dataset = MNIST(root='./datasets', train=True, transform=transforms, download=True)\n else:\n raise NotImplementedError\n\n data_loader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, num_workers=0, shuffle=True)\n\n D = Discriminator(DATASET_NAME).apply(weights_init).to(DEVICE)\n G = Generator(DATASET_NAME).apply(weights_init).to(DEVICE)\n print(D, G)\n criterion = nn.BCELoss()\n\n optim_D = torch.optim.Adam(D.parameters(), lr=LR, betas=(BETA1, BETA2))\n optim_G = torch.optim.Adam(G.parameters(), lr=LR, betas=(BETA1, BETA2))\n\n list_D_loss = list()\n 
list_G_loss = list()\n total_step = 0\n\n st = time()\n for epoch in range(EPOCHS):\n for data in tqdm(data_loader):\n total_step += 1\n real, label = data[0].to(DEVICE), data[1].to(DEVICE)\n z = torch.randn(BATCH_SIZE, LATENT_DIM).to(DEVICE)\n\n fake = G(z)\n\n real_score = D(real)\n fake_score = D(fake.detach())\n\n D_loss = 0.5 * (criterion(fake_score, torch.zeros_like(fake_score).to(DEVICE))\n + criterion(real_score, torch.ones_like(real_score).to(DEVICE)))\n optim_D.zero_grad()\n D_loss.backward()\n optim_D.step()\n list_D_loss.append(D_loss.detach().cpu().item())\n \n if total_step % ITER_DISPLAY == 0:\n #(BatchSize, Channel*ImageSize*ImageSize)-->(BatchSize, Channel, ImageSize, ImageSize)\n fake = fake.view(BATCH_SIZE, OUT_CHANNEL, IMAGE_SIZE, IMAGE_SIZE)\n real = real.view(BATCH_SIZE, OUT_CHANNEL, IMAGE_SIZE, IMAGE_SIZE)\n save_image(fake, IMAGE_DIR + '/{}_fake.png'.format(epoch + 1), nrow=4, normalize=True)\n save_image(real, IMAGE_DIR + '/{}_real.png'.format(epoch + 1), nrow=4, normalize=True)\n \n if total_step % N_D_STEP == 0:\n fake_score = D(fake)\n G_loss = criterion(fake_score, torch.ones_like(fake_score))\n optim_G.zero_grad()\n G_loss.backward()\n optim_G.step()\n list_G_loss.append(G_loss.detach().cpu().item())\n\n if total_step % ITER_REPORT == 0:\n print("Epoch: {}, D_loss: {:.{prec}} G_loss: {:.{prec}}"\n .format(epoch, D_loss.detach().cpu().item(), G_loss.detach().cpu().item(), prec=4))\n\n torch.save(D.state_dict(), '{}_D.pt'.format(DATASET_NAME))\n torch.save(G.state_dict(), '{}_G.pt'.format(DATASET_NAME))\n\n plt.figure()\n plt.plot(range(0, len(list_D_loss)), list_D_loss, linestyle='--', color='r', label='Discriminator loss')\n plt.plot(range(0, len(list_G_loss) * N_D_STEP, N_D_STEP), list_G_loss, linestyle='--', color='g',\n label='Generator loss')\n plt.xlabel('Iteration')\n plt.ylabel('Loss')\n plt.legend()\n plt.savefig('Loss.png')\n\n print(time() - st)\nRun Code Online (Sandbox Code Playgroud)\n该错误似乎来自 Discriminator .forward,如下所示:
\nRuntimeError Traceback (most recent call last)\nin \n71 fake = G(z)\n72\n?> 73 real_score = D(real)\n74 fake_score = D(fake.detach())\n75\n\nC:\\Anaconda3\\lib\\site-packages\\torch\\nn\\modules\\module.py in call(self, *input, **kwargs)\n491 result = self._slow_forward(*input, **kwargs)\n492 else:\n?> 493 result = self.forward(*input, **kwargs)\n494 for hook in self._forward_hooks.values():\n495 hook_result = hook(self, input, result)\n\nin forward(self, x)\n87\n88 def forward(self, x):\n?> 89 return self.model(x)\n\nC:\\Anaconda3\\lib\\site-packages\\torch\\nn\\modules\\module.py in call(self, *input, **kwargs)\n491 result = self._slow_forward(*input, **kwargs)\n492 else:\n?> 493 result = self.forward(*input, **kwargs)\n494 for hook in self._forward_hooks.values():\n495 hook_result = hook(self, input, result)\n\nC:\\Anaconda3\\lib\\site-packages\\torch\\nn\\modules\\container.py in forward(self, input)\n90 def forward(self, input):\n91 for module in self._modules.values():\n?> 92 input = module(input)\n93 return input\n94\n\nC:\\Anaconda3\\lib\\site-packages\\torch\\nn\\modules\\module.py in call(self, *input, **kwargs)\n491 result = self._slow_forward(*input, **kwargs)\n492 else:\n?> 493 result = self.forward(*input, **kwargs)\n494 for hook in self._forward_hooks.values():\n495 hook_result = hook(self, input, result)\n\nC:\\Anaconda3\\lib\\site-packages\\torch\\nn\\modules\\batchnorm.py in forward(self, input)\n81 input, self.running_mean, self.running_var, self.weight, self.bias,\n82 self.training or not self.track_running_stats,\n?> 83 exponential_average_factor, self.eps)\n84\n85 def extra_repr(self):\n\nC:\\Anaconda3\\lib\\site-packages\\torch\\nn\\functional.py in batch_norm(input, running_mean, running_var, weight, bias, training, momentum, eps)\n1695 return torch.batch_norm(\n1696 input, weight, bias, running_mean, running_var,\n-> 1697 training, momentum, eps, torch.backends.cudnn.enabled\n1698 )\n1699\n\nRuntimeError: running_mean should contain 256 
elements not 128
Run Code Online (Sandbox Code Playgroud)
谁能告诉我这个错误是什么?它似乎来自模型中某些内容的大小设置,但我只能猜测。
\n先感谢您。
\n小智 3
线路
model += [nn.Conv2d(128, 256, 5, stride=2, padding=2, bias=False), norm(128), act] # 8x8
Run Code Online (Sandbox Code Playgroud)
这是批量归一化的通道数错误:`BatchNorm2d` 的参数必须等于前一个卷积层的**输出**通道数,此处 `Conv2d(128, 256, ...)` 之后应为 `norm(256)`。同理,下一行 `Conv2d(256, 512, ...)` 之后的 `norm(256)` 应改为 `norm(512)`;LSUN 分支也存在同样的问题(MNIST 分支写对了,所以只有 CIFAR10 报错)。
| 归档时间: |
|
| 查看次数: |
16953 次 |
| 最近记录: |