Use*_*non 10 machine-learning python-3.x pytorch mini-batch
我在collate_fn为 PyTorchDataLoader类编写自定义函数时遇到问题。我需要自定义函数,因为我的输入有不同的维度。
我目前正在尝试编写斯坦福 MURA 论文的基线实现。该数据集有一组标记的研究。一项研究可能包含多个图像。我创建了一个自定义Dataset类,使用torch.stack.
然后将堆叠张量作为输入提供给模型,并对输出列表进行平均以获得单个输出。此实现适用于DataLoaderwhen batch_size=1。但是,当我尝试将 设置batch_size为 8 时,就像原始论文中的情况一样,DataLoader失败了,因为它用于torch.stack堆叠批次并且我的批次中的输入具有可变尺寸(因为每个研究可以有多个图像)。
为了解决这个问题,我尝试实现我的自定义collate_fn函数。
def collate_fn(batch):
imgs = [item['images'] for item in batch]
targets = [item['label'] for item in batch]
targets = torch.LongTensor(targets)
return imgs, targets
Run Code Online (Sandbox Code Playgroud)
然后在我的训练纪元循环中,我像这样循环每个批次:
for image, label in zip(*batch):
label = label.type(torch.FloatTensor)
# wrap them in Variable
image = Variable(image).cuda()
label = Variable(label).cuda()
# forward
output = model(image)
output = torch.mean(output)
loss = criterion(output, label, phase)
Run Code Online (Sandbox Code Playgroud)
但是,这并没有给我任何改进的时间安排,并且仍然需要与批处理大小仅为 1 的时间一样长的时间。我也尝试将批处理大小设置为 32,但这也没有改善时间。
难道我做错了什么?有没有更好的方法来解决这个问题?
小智 1
非常有趣的问题!如果我理解正确的话(并且还检查了论文摘要),您有来自 14,863 项研究的 40,561 张图像,其中每项研究都由放射科医生手动标记为正常或异常。
我相信您遇到问题的原因是,例如,您创建了一个堆栈,
并且您尝试在训练期间使用 8 的批量大小,但当它开始研究 D 时会失败。
因此,我们是否有理由要对研究中的输出列表进行平均以适合单个标签?否则,我会简单地收集所有 40,561 个图像,为同一研究中的所有图像分配相同的标签(以便将 A 中的输出列表与 12 个标签的列表进行比较)。
因此,使用单个数据加载器,您可以在研究中进行洗牌(如果需要)并在训练期间使用所需的批量大小。
我发现这个问题已经存在了一段时间,我希望它对将来的人有所帮助:)
| 归档时间: |
|
| 查看次数: |
5959 次 |
| 最近记录: |