为什么这个数据集会尝试遍历最后一个元素
from torch.utils.data.dataset import Dataset
class DumbDataset(Dataset):
def __init__(self, dct):
self.dct = dct
self.mapping = dict(enumerate(dct))
def __getitem__(self, index):
return self.dct[self.mapping[index]]
def __len__(self):
print('called')
return len(self.dct)
ds = DumbDataset({'a': 'aword', 'b': 'another_words'})
for k in ds: print(k)
Run Code Online (Sandbox Code Playgroud)
这引发了 KeyError: 2,我不明白,因为对象的长度是 2。迭代器用完后不应该得到 StopIteration 吗?
为什么你的代码引起的原因KeyError是Dataset 没有实现 __iter__()在一个for循环的Python回落到从索引中使用时,因此0,呼吁__getitem__直到IndexError上升,为讨论在这里。您可以DumbDataset通过IndexError在索引超出范围时提高 an来进行修改以使其像这样工作
def __getitem__(self, index):
if index >= len(self): raise IndexError
return self.dct[self.mapping[index]]
Run Code Online (Sandbox Code Playgroud)
然后你的循环
for k in ds:
print(k)
Run Code Online (Sandbox Code Playgroud)
将按您的预期工作。另一方面,火炬数据集的典型模板是您可以使用索引循环遍历它们
for i in range(len(ds)):
k = ds[k]
print(k)
Run Code Online (Sandbox Code Playgroud)
或者将它们包装在 aDataLoader中,分批返回元素
generator = DataLoader(ds)
for k in generator:
print(k)
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
443 次 |
| 最近记录: |