next() 和 iter() 在 PyTorch 的 DataLoader() 中做什么

Leo*_*ckl 8 python iterator next pytorch dataloader

我有以下代码:

import torch
import numpy as np
import pandas as pd
from torch.utils.data import TensorDataset, DataLoader

# Load dataset
df = pd.read_csv(r'../iris.csv')

# Extract features and target
data = df.drop('target',axis=1).values
labels = df['target'].values

# Create tensor dataset
iris = TensorDataset(torch.FloatTensor(data),torch.LongTensor(labels))

# Create random batches
iris_loader = DataLoader(iris, batch_size=105, shuffle=True)

next(iter(iris_loader))
Run Code Online (Sandbox Code Playgroud)

上面的代码做了什么next()iter()做什么?我已经阅读了PyTorch 的文档,但仍然可以完全理解这里是什么next()iter()做什么。谁能帮忙解释一下?提前谢谢了。

Sco*_*ork 9

这些是 python 的内置函数,用于处理可迭代对象。

基本上iter()调用__iter__()上的方法iris_loader,它返回一个迭代。next()然后调用该__next__()迭代器上的方法以获取第一次迭代。next()再次运行将获得迭代器的第二项,以此类推。

这种逻辑经常发生在“幕后”,例如在运行for循环时。它调用__iter__()迭代__next__()器上的方法,然后调用返回的迭代器,直到到达迭代器的末尾。然后它引发 astopIteration并且循环停止。

有关更多详细信息和一些细微差别,请参阅文档:https : //docs.python.org/3/library/functions.html#iter

  • 这里的术语很重要,“iris_loader”是一个可迭代对象,将其传递给“iter()”会返回一个可以迭代的迭代器。您可以将这两个函数分开以更好地了解正在发生的情况。`i = iter(iris_loader)` 然后是 `next(i)`。如果您在笔记本中以交互方式运行此命令,请尝试多运行几次“next(i)”。每次运行“next(i)”时,它将返回迭代器的下一批大小为 105 的迭代器,直到没有剩余的批次为止。 (3认同)
  • 谢谢@ScootCork。所以简单来说,我可以说“iter()”只是迭代一个新的随机批次,而“next()”在输出中显示这个新的随机批次吗? (2认同)

eri*_*ric 5

接受的答案是正确的。当我对这个主题和迭代器/迭代器感到困惑时,我只是想给出一个补充答案。

我最初认为数据加载器是一个迭代器,所以这个想法iter(data_loader)似乎是多余的。但数据加载器是可迭代的,而不是迭代器。同样,列表不是迭代器,而是可迭代对象。next(x)如果您尝试直接在列表上运行,x您将得到TypeError: 'list' object is not an iterator. 要迭代列表,首先必须使用iter(x): 将其转换为迭代器,然后可以使用 开始迭代它next()

同样的逻辑也适用于数据加载器:它们是可迭代的,而不是迭代器,并且您可以使用iter(data_loader). 然后使用next()操作来逐步执行它们。您可以轻松地将其分解为多个步骤:

# define data loader (iterable)
iris_loader = DataLoader(iris, batch_size=105, shuffle=True)

# define iterator for use in training
iris_iterator = iter(iris_loader)

# extract batch
data_batch = next(iris_iterator)
Run Code Online (Sandbox Code Playgroud)