使用 Pytorch 数据加载器加载特定样本的简单方法

Flo*_*sch 6 python machine-learning deep-learning pytorch

我目前正在训练一个 3D CNN,用于具有相对稀疏标签的二元分类(标签数据中约 1% 的体素对应于目标类)。

为了在训练期间执行基本的健全性检查(例如,网络是否完全学习?),向网络展示一个小的、精心挑选的训练示例子集,其目标类别标签的比例高于平均水平。

正如 Pytorch 文档所建议的那样,我实现了自己的dataset类(继承自torch.utils.data.Dataset),它通过它的__get_item__方法向torch.utils.data.DataLoader.

在我发现的pytorch 教程中DataLoader用作迭代器来生成训练循环,如下所示:

for i, data in enumerate(self.dataloader):

    # Get training data
    inputs, labels = data

    # Train the network
    # [...]
Run Code Online (Sandbox Code Playgroud)

我现在想知道的是是否存在一种简单的方法来加载单个或几个特定的​​训练示例(使用Dataset's__get_item__方法理解的线性索引)。但是,DataLoader没有__get_item__方法并反复调用__next__直到达到所需的索引似乎并不优雅。

显然,解决此问题的一种可能方法是定义自定义samplerbatch_sampler从抽象继承torch.utils.data.Sampler。但这似乎超出了检索一些特定样本的要求。

我想我在这里忽略了一些非常简单和明显的东西。任何建议表示赞赏!

Flo*_*sch 5

以防万一有类似问题的人在某个时候遇到这个问题:

我最终使用的快速而肮脏的解决方法是dataloader通过直接访问它的关联dataset属性来绕过训练循环中的。假设我们想快速检查我们的网络是否完全学习,方法是重复向它展示一个带有线性索引sample_idx(由数据集类定义)的单一、精心挑选的训练示例。

然后可以做这样的事情:

for i, _ in enumerate(self.dataloader):

    # Get training data
    # inputs, labels = data

    inputs, labels = self.dataloader.dataset[sample_idx]
    inputs = inputs.unsqueeze(0)
    labels = labels.unsqueeze(0)

    # Train the network
    # [...]
Run Code Online (Sandbox Code Playgroud)

编辑:

一个简短的评论,因为有些人似乎发现这种解决方法很有帮助:在使用此 hack 时,我发现DataLoader使用num_workers = 0. 否则,可能会出现内存分段错误,在这种情况下,您最终可能会得到看起来非常奇怪的训练数据。