如何从 PyTorch 中的数据加载器获取整个数据集

Aak*_*W.S 5 python pytorch dataloader

如何从 DataLoader 加载整个数据集?我只得到一批数据集。

这是我的代码

dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=64)
images, labels = next(iter(dataloader))
Run Code Online (Sandbox Code Playgroud)

小智 11

另一种选择是直接获取整个数据集,而不使用数据加载器,如下所示:

images, labels = dataset[:]
Run Code Online (Sandbox Code Playgroud)


小智 6

您可以设置batch_size=dataset.__len__()数据集是 torch Dataset,否则batch_szie=len(dataset)应该可以工作。

请注意,这可能需要大量内存,具体取决于您的数据集。

  • 您还可以对 torch 数据集执行“len(dataset)”。能够做到这一点是实现 `-_len__()` 方法的唯一目的。 (5认同)

Flo*_*ume 6

我不确定您是想在网络训练之外的其他地方使用数据集(例如检查图像)还是想在训练期间迭代批次。

遍历数据集

要么按照 Usman Ali 的回答(可能会溢出)你的记忆,要么你可以

for i in range(len(dataset)): # or i, image in enumerate(dataset)
    images, labels = dataset[i] # or whatever your dataset returns
Run Code Online (Sandbox Code Playgroud)

你能写dataset[i],因为你实现__len____getitem__Dataset类(只要它是Pytorch的一个子Dataset类)。

从数据加载器获取所有批次

我理解你的问题的方式是你想检索所有批次来训练网络。您应该明白它iter为您提供了数据加载器的迭代器(如果您不熟悉迭代器的概念,请参阅维基百科条目)。next告诉迭代器给你下一个项目。

因此,与遍历列表的迭代器相反,数据加载器总是返回下一个项目。列表迭代器在某个时候停止。我假设你有一些类似的时期和每个时期的一些步骤。然后你的代码看起来像这样

for i in range(epochs):
    # some code
    for j in range(steps_per_epoch):
        images, labels = next(iter(dataloader))
        prediction = net(images)
        loss = net.loss(prediction, labels)
        ...
Run Code Online (Sandbox Code Playgroud)

小心next(iter(dataloader)). 如果你想遍历一个列表,这也可能有效,因为 Python 缓存对象,但你可能会在每次再次从索引 0 开始时得到一个新的迭代器。为了避免这种情况,将迭代器取出到顶部,如下所示:

iterator = iter(dataloader)
for i in range(epochs):
    for j in range(steps_per_epoch):
        images, labels = next(iterator)
Run Code Online (Sandbox Code Playgroud)