I have a text file larger than my memory, and I want to create a dataset in PyTorch that reads it line by line, so that I don't have to load the whole file into memory at once. I found PyTorch's IterableDataset as a potential solution to my problem. It works as expected when using 1 worker, but with multiple workers it creates duplicate records. Let me show you an example:
Given a testfile.txt containing:
0 - Dummy line
1 - Dummy line
2 - Dummy line
3 - Dummy line
4 - Dummy line
5 - Dummy line
6 - Dummy line
7 - Dummy line
8 - Dummy line
9 - Dummy line
I define an IterableDataset:
from torch.utils.data import IterableDataset

class CustomIterableDatasetv1(IterableDataset):

    def __init__(self, filename):
        # Store the filename in the object's memory
        self.filename = filename

    def preprocess(self, text):
        # Do something with the text here
        text_pp = text.lower().strip()
        return text_pp

    def line_mapper(self, line):
        # Split the line into text and label, and preprocess the text
        text, label = line.split('-')
        text = self.preprocess(text)
        return text, label

    def __iter__(self):
        # Create an iterator over the lines of the file
        file_itr = open(self.filename)
        # Map each line through line_mapper
        mapped_itr = map(self.line_mapper, file_itr)
        return mapped_itr
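We can now test it, for example with a driver along these lines (a sketch, assuming a DataLoader with batch_size=1 and num_workers=1):

from torch.utils.data import DataLoader

base_dataset = CustomIterableDatasetv1("testfile.txt")
dataloader = DataLoader(base_dataset, batch_size=1, num_workers=1)

for X, y in dataloader:
    print(X, y)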
It outputs:
('0',) (' Dummy line\n',)
('1',) (' Dummy line\n',)
('2',) (' Dummy line\n',)
('3',) (' Dummy line\n',)
('4',) (' Dummy line\n',)
('5',) (' Dummy line\n',)
('6',) (' Dummy line\n',)
('7',) (' Dummy line\n',)
('8',) (' Dummy line\n',)
('9',) (' Dummy line',)
That is correct. But if I change the number of workers to 2, the output becomes:
('0',) (' Dummy line\n',)
('0',) (' Dummy line\n',)
('1',) (' Dummy line\n',)
('1',) (' Dummy line\n',)
('2',) (' Dummy line\n',)
('2',) (' Dummy line\n',)
('3',) (' Dummy line\n',)
('3',) (' Dummy line\n',)
('4',) (' Dummy line\n',)
('4',) (' Dummy line\n',)
('5',) (' Dummy line\n',)
('5',) (' Dummy line\n',)
('6',) (' Dummy line\n',)
('6',) (' Dummy line\n',)
('7',) (' Dummy line\n',)
('7',) (' Dummy line\n',)
('8',) (' Dummy line\n',)
('8',) (' Dummy line\n',)
('9',) (' Dummy line',)
('9',) (' Dummy line',)
This is incorrect, because the dataloader creates a copy of every sample for each worker. Is there a way to solve this with PyTorch, so that a dataloader can be created that does not load the whole file into memory and still supports multiple workers?
I found the answer in the PyTorch discussion forum, https://discuss.pytorch.org/t/iterable-pytorch-dataset-with-multiple-workers/135475/3, where they pointed out that I should use the worker info to step through the iterator in strides equal to the total number of workers. The new dataset looks like this:
import itertools
import torch
from torch.utils.data import IterableDataset

class CustomIterableDatasetv1(IterableDataset):

    def __init__(self, filename):
        # Store the filename in the object's memory
        self.filename = filename

    def preprocess(self, text):
        # Do something with the text here
        text_pp = text.lower().strip()
        return text_pp

    def line_mapper(self, line):
        # Split the line into text and label, and preprocess the text
        text, label = line.split('-')
        text = self.preprocess(text)
        return text, label

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        # get_worker_info() returns None with num_workers=0
        # (single-process loading), so fall back to a single "worker"
        if worker_info is None:
            worker_total_num, worker_id = 1, 0
        else:
            worker_total_num, worker_id = worker_info.num_workers, worker_info.id
        # Create an iterator over the lines of the file
        file_itr = open(self.filename)
        # Map each line through line_mapper
        mapped_itr = map(self.line_mapper, file_itr)
        # Add multi-worker functionality: each worker starts at its own id
        # and strides by the total number of workers, so the workers read
        # disjoint sets of lines
        mapped_itr = itertools.islice(mapped_itr, worker_id, None, worker_total_num)
        return mapped_itr
Special thanks to @Ivan, who also pointed out the slicing solution. With two workers it now returns the same data as with just 1 worker.
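As a quick check, a sketch assuming the same testfile.txt and the driver from above, with num_workers raised to 2:

from torch.utils.data import DataLoader

base_dataset = CustomIterableDatasetv1("testfile.txt")
# Each worker now reads a disjoint set of lines, so no duplicates appear
for X, y in DataLoader(base_dataset, batch_size=1, num_workers=2):
    print(X, y)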
You can access the worker identifier inside the Dataset's __iter__ function using the torch.utils.data.get_worker_info util. This means you can step through the iterator while adding an offset depending on the worker id. You can wrap the iterator with itertools.islice, which lets you pass a start index as well as a step.
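As a tiny standalone illustration of those start and step arguments (not part of the original answer):

from itertools import islice

# start=1, stop=None, step=2: every second element, beginning at index 1
print(list(islice(range(10), 1, None, 2)))  # [1, 3, 5, 7, 9]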
Here is a minimal example:
import torch
from itertools import islice
from torch.utils.data import IterableDataset

class DS(IterableDataset):
    def __init__(self, batch_size):
        super().__init__()
        self.batch_size = batch_size

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        # Stride by the number of workers (not the batch size; the two only
        # coincide here because both are 2), starting at this worker's id,
        # so each worker yields a disjoint slice of the data
        return islice(range(10), worker_info.id, None, worker_info.num_workers)
The dataloader loop yields unique instances, even when we are using num_workers > 1:
>>> for x in DataLoader(DS(batch_size=2), batch_size=2, num_workers=2):
... print(x)
tensor([0, 2])
tensor([1, 3])
tensor([4, 6])
tensor([5, 7])
tensor([8])
tensor([9])
For your case, you could do:
def __iter__(self):
    worker_info = torch.utils.data.get_worker_info()
    # create an iterator over the file
    file_itr = open(self.filename)
    # map each line through line_mapper
    mapped_itr = map(self.line_mapper, file_itr)
    # wrap the iterator: start at this worker's id and stride by the
    # total number of workers
    step_itr = islice(mapped_itr, worker_info.id, None, worker_info.num_workers)
    return step_itr