小编kru*_*bek的帖子

使用本地GPU的Google Colaboratory本地运行时

我正在使用Colaboratory和Pytorch来运行使用不寻常数据集的GAN,该数据集当前存储在本地计算机上。为了访问这些文件,我连接了本地运行时(按照https://research.google.com/colaboratory/local-runtimes.html)。但是,Colaboratory现在现在运行时会使用我自己的 GPU,而以前的运行并没有。我知道这是因为当前运行速度要慢得多,因为它们使用的是我的GTX 1060 6GB而不是Colab的Tesla K80。

我用这个检查了

torch.cuda.get_device_name(0)

当我在本地连接时,它将返回“ GeForce GTX 1060 6G”。即使选择了“编辑”->“笔记本设置”->“硬件加速器”->“ GPU” ,情况也是如此

但是,当我不在本地连接时,而是使用(默认)“连接到托管的运行时”选项时,

torch.cuda.get_device_name(0) 确实返回“ Tesla K80”。

我无法将数据集上传到云端硬盘,因为它是一个大型图像数据集,并且希望继续使用本地运行时。

如何同时使用本地运行时和Colab出色的Tesla K80?任何帮助将非常感激。

python machine-learning deep-learning pytorch google-colaboratory

7
推荐指数
1
解决办法
2035
查看次数

越来越大的正WGAN-GP损失

我正在研究在PyTorch中使用具有梯度损失的Wasserstein GAN,但始终会产生较大的正发电机损失,并随着时间的推移而增加。我大量借用了曹刚的实现,但是我正在使用此实现中使用的鉴别器和生成器损耗,因为Invalid gradient at index 0 - expected shape[] but got [1]如果尝试调用曹刚实现中使用.backward()onemoneargs,我会得到。

我正在使用增强的WikiArt数据集(> 400k 64x64图像)和CIFAR-10进行训练,并且获得了正常的WGAN(具有权重剪切功能)[即,它在25个时期后产生了可通过的图像],尽管事实上D和G损失在torch.mean(D_real)所有时期都徘徊在3 [我用等计算]。但是,在WGAN-GP版本中,发电机损耗在WikiArt和CIFAR-10数据集上都急剧增加,并且完全无法在WikiArt上产生噪声。

这是在CIFAR-10上经过25个纪元后的损失示例: WGAN-GP损失

我没有使用单面标签平滑等技巧,并且使用默认学习率0.001进行训练,使用Adam优化器,并且每次生成器更新时,对鉴别器进行5次训练。为什么会发生这种疯狂的丢失行为,为什么正常的减肥瘦身WGAN在WikiArt上仍然“起作用”,但WGANGP完全失败了?

无论G和D都是DCGAN还是使用改良的DCGAN时,都会发生这种情况,这是创意对抗网络Creative Adversarial Network),它要求D能够对图像进行分类,并且G生成模糊图像。

以下是我当前train方法的相关部分:

self.generator = Can64Generator(self.z_noise, self.channels, self.num_gen_filters).to(self.device)
self.discriminator =WCan64Discriminator(self.channels,self.y_dim, self.num_disc_filters).to(self.device)
style_criterion = nn.CrossEntropyLoss()

self.disc_optimizer = optim.Adam(self.discriminator.parameters(), lr=self.lr, betas=(self.beta1, 0.9))
self.gen_optimizer = optim.Adam(self.generator.parameters(), lr=self.lr, betas=(self.beta1, 0.9))


while i < len(dataloader):
            j = 0
            disc_loss_epoch = []
            gen_loss_epoch = [] …
Run Code Online (Sandbox Code Playgroud)

python machine-learning computer-vision deep-learning pytorch

2
推荐指数
1
解决办法
1585
查看次数