Aak*_*W.S 5 python pytorch dataloader
如何从 DataLoader 加载整个数据集?我只得到一批数据集。
这是我的代码
dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=64)
images, labels = next(iter(dataloader))
Run Code Online (Sandbox Code Playgroud)
小智 11
另一种选择是直接获取整个数据集,而不使用数据加载器,如下所示:
images, labels = dataset[:]
Run Code Online (Sandbox Code Playgroud)
小智 6
您可以设置batch_size=dataset.__len__()
数据集是 torch Dataset
,否则batch_szie=len(dataset)
应该可以工作。
请注意,这可能需要大量内存,具体取决于您的数据集。
我不确定您是想在网络训练之外的其他地方使用数据集(例如检查图像)还是想在训练期间迭代批次。
遍历数据集
要么按照 Usman Ali 的回答(可能会溢出)你的记忆,要么你可以
for i in range(len(dataset)): # or i, image in enumerate(dataset)
images, labels = dataset[i] # or whatever your dataset returns
Run Code Online (Sandbox Code Playgroud)
你能写dataset[i]
,因为你实现__len__
并__getitem__
在Dataset
类(只要它是Pytorch的一个子Dataset
类)。
从数据加载器获取所有批次
我理解你的问题的方式是你想检索所有批次来训练网络。您应该明白它iter
为您提供了数据加载器的迭代器(如果您不熟悉迭代器的概念,请参阅维基百科条目)。next
告诉迭代器给你下一个项目。
因此,与遍历列表的迭代器相反,数据加载器总是返回下一个项目。列表迭代器在某个时候停止。我假设你有一些类似的时期和每个时期的一些步骤。然后你的代码看起来像这样
for i in range(epochs):
# some code
for j in range(steps_per_epoch):
images, labels = next(iter(dataloader))
prediction = net(images)
loss = net.loss(prediction, labels)
...
Run Code Online (Sandbox Code Playgroud)
小心next(iter(dataloader))
. 如果你想遍历一个列表,这也可能有效,因为 Python 缓存对象,但你可能会在每次再次从索引 0 开始时得到一个新的迭代器。为了避免这种情况,将迭代器取出到顶部,如下所示:
iterator = iter(dataloader)
for i in range(epochs):
for j in range(steps_per_epoch):
images, labels = next(iterator)
Run Code Online (Sandbox Code Playgroud)
归档时间: |
|
查看次数: |
14454 次 |
最近记录: |