Torch 数据集循环太远

Sam*_*fer 3 python pytorch

为什么这个数据集会尝试遍历最后一个元素

from torch.utils.data.dataset import Dataset
class DumbDataset(Dataset):
    def __init__(self, dct):
        self.dct = dct
        self.mapping = dict(enumerate(dct))
    def __getitem__(self, index):
        return self.dct[self.mapping[index]]

    def __len__(self):
        print('called')
        return len(self.dct)

ds = DumbDataset({'a': 'aword', 'b': 'another_words'})

for k in ds: print(k)
Run Code Online (Sandbox Code Playgroud)

这引发了 KeyError: 2,我不明白,因为对象的长度是 2。迭代器用完后不应该得到 StopIteration 吗?

Jon*_*ton 5

为什么你的代码引起的原因KeyErrorDataset 没有实现 __iter__()在一个for循环的Python回落到从索引中使用时,因此0,呼吁__getitem__直到IndexError上升,为讨论在这里。您可以DumbDataset通过IndexError在索引超出范围时提高 an来进行修改以使其像这样工作

def __getitem__(self, index):
    if index >= len(self): raise IndexError
    return self.dct[self.mapping[index]]
Run Code Online (Sandbox Code Playgroud)

然后你的循环

for k in ds:
    print(k)
Run Code Online (Sandbox Code Playgroud)

将按您的预期工作。另一方面,火炬数据集的典型模板是您可以使用索引循环遍历它们

for i in range(len(ds)):
    k = ds[k]
    print(k)
Run Code Online (Sandbox Code Playgroud)

或者将它们包装在 aDataLoader中,分批返回元素

generator = DataLoader(ds)
for k in generator:
    print(k)
Run Code Online (Sandbox Code Playgroud)