IndexError when iterating over a dataset with DataLoader in PyTorch

ger*_*zhu 2 machine-learning computer-vision deep-learning pytorch

I iterated over my dataset using the DataLoader in PyTorch 0.2, as follows:

dataloader = torch.utils.data.DataLoader(...)
data_iter = iter(dataloader)
data = data_iter.next()

but an IndexError was raised.

Traceback (most recent call last):
  File "main.py", line 193, in <module>
    data_target = data_target_iter.next()
  File "/usr/local/lib/python3.5/dist-packages/torch/utils/data/dataloader.py", line 201, in __next__
    return self._process_next_batch(batch)
  File "/usr/local/lib/python3.5/dist-packages/torch/utils/data/dataloader.py", line 221, in _process_next_batch
    raise batch.exc_type(batch.exc_msg)
IndexError: Traceback (most recent call last):
  File "/usr/local/lib/python3.5/dist-packages/torch/utils/data/dataloader.py", line 40, in _worker_loop
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "/usr/local/lib/python3.5/dist-packages/torch/utils/data/dataloader.py", line 40, in <listcomp>
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "/home/asr4/zhuminxian/adversarial/code/dataset/data_loader.py", line 33, in __getitem__
    return self.X_train[idx], self.y_train[idx]
IndexError: index 4196 is out of bounds for axis 0 with size 4135

I would like to know why the index goes out of bounds. Is it a bug in PyTorch?

I tried running my code again and the same error was raised, but at a different iteration and with a different out-of-bounds index.

Sha*_*hai 15

My guess is that your data.Dataset.__len__ is not overloaded properly, and in fact len(dataloader.dataset) returns a number larger than len(self.X_train). The DataLoader's sampler draws indices in the range [0, len(dataset) - 1], so if __len__ overstates the size, some indices will eventually exceed the bounds of X_train, which matches the random, run-to-run varying indices you observed.
Check the implementation of __len__ in '/home/asr4/zhuminxian/adversarial/code/dataset/data_loader.py'.
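The diagnosis above can be illustrated with a minimal sketch. The class names and array sizes here are hypothetical (chosen to mirror the traceback, where X_train has 4135 rows but index 4196 was requested); only `Dataset`, `DataLoader`, and the `__len__`/`__getitem__` protocol come from PyTorch itself:

```python
import torch
from torch.utils.data import Dataset, DataLoader

class BuggyDataset(Dataset):
    """Reproduces the error: __len__ reports more samples than exist."""
    def __init__(self):
        self.X_train = torch.randn(4135, 10)  # only 4135 real samples
        self.y_train = torch.zeros(4135)

    def __len__(self):
        # WRONG: overstates the dataset size, so the sampler will
        # eventually draw an index in [4135, 4199] -> IndexError
        return 4200

    def __getitem__(self, idx):
        return self.X_train[idx], self.y_train[idx]

class FixedDataset(BuggyDataset):
    """The fix: report the true length so indices stay in range."""
    def __len__(self):
        return len(self.X_train)

loader = DataLoader(FixedDataset(), batch_size=32, shuffle=True)
for X, y in loader:
    pass  # iterates cleanly; every index is within [0, 4134]
```

With `shuffle=True` the sampler permutes `range(len(dataset))`, which is why the buggy version fails at a different iteration with a different index on each run.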