Iterable pytorch dataset with multiple workers

Jai*_*tas 3 python distributed deep-learning pytorch

I have a text file larger than my available memory, and I would like to create a dataset in PyTorch that reads it line by line, so that I never have to load the whole file into memory. I found PyTorch's IterableDataset as a potential solution to my problem. It only works as expected when using 1 worker; if more than one worker is used, duplicate records are created. Let me show you an example:

Having a testfile.txt containing:

0 - Dummy line
1 - Dummy line
2 - Dummy line
3 - Dummy line
4 - Dummy line
5 - Dummy line
6 - Dummy line
7 - Dummy line
8 - Dummy line
9 - Dummy line

We define an IterableDataset:

from torch.utils.data import IterableDataset

class CustomIterableDatasetv1(IterableDataset):

    def __init__(self, filename):

        #Store the filename in object's memory
        self.filename = filename

    def preprocess(self, text):

        ### Do something with text here
        text_pp = text.lower().strip()
        ###

        return text_pp

    def line_mapper(self, line):
        
        #Splits the line into text and label and applies preprocessing to the text
        text, label = line.split('-')
        text = self.preprocess(text)

        return text, label


    def __iter__(self):

        #Create an iterator
        file_itr = open(self.filename)

        #Map each element using the line_mapper
        mapped_itr = map(self.line_mapper, file_itr)
        
        return mapped_itr
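We can now test it. The output below came from wrapping the dataset in DataLoader(dataset, batch_size=1, num_workers=1), which packs each field into a one-element batch. A self-contained sketch (the file-writing step merely recreates the sample file, and the class is the one defined above, condensed):

```python
from torch.utils.data import DataLoader, IterableDataset

# recreate the sample file from above
with open("testfile.txt", "w") as f:
    f.write("\n".join(f"{i} - Dummy line" for i in range(10)))

# condensed version of the class defined above
class CustomIterableDatasetv1(IterableDataset):
    def __init__(self, filename):
        self.filename = filename

    def line_mapper(self, line):
        text, label = line.split('-')
        return text.lower().strip(), label

    def __iter__(self):
        return map(self.line_mapper, open(self.filename))

dataset = CustomIterableDatasetv1("testfile.txt")
dataloader = DataLoader(dataset, batch_size=1, num_workers=1)
for X, y in dataloader:
    print(X, y)
```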

It outputs:
('0',) (' Dummy line\n',)
('1',) (' Dummy line\n',)
('2',) (' Dummy line\n',)
('3',) (' Dummy line\n',)
('4',) (' Dummy line\n',)
('5',) (' Dummy line\n',)
('6',) (' Dummy line\n',)
('7',) (' Dummy line\n',)
('8',) (' Dummy line\n',)
('9',) (' Dummy line',)

That is correct. But if I change the number of workers to 2, the output becomes:

('0',) (' Dummy line\n',)
('0',) (' Dummy line\n',)
('1',) (' Dummy line\n',)
('1',) (' Dummy line\n',)
('2',) (' Dummy line\n',)
('2',) (' Dummy line\n',)
('3',) (' Dummy line\n',)
('3',) (' Dummy line\n',)
('4',) (' Dummy line\n',)
('4',) (' Dummy line\n',)
('5',) (' Dummy line\n',)
('5',) (' Dummy line\n',)
('6',) (' Dummy line\n',)
('6',) (' Dummy line\n',)
('7',) (' Dummy line\n',)
('7',) (' Dummy line\n',)
('8',) (' Dummy line\n',)
('8',) (' Dummy line\n',)
('9',) (' Dummy line',)
('9',) (' Dummy line',)

Which is incorrect, as the data loader creates a copy of every sample for each worker.

Is there a way to solve this with PyTorch, i.e. to create a data loader that does not load the whole file into memory and supports multiple workers?

Jai*_*tas 7

So I found the answer in the torch discussion forum https://discuss.pytorch.org/t/iterable-pytorch-dataset-with-multiple-workers/135475/3, where it was pointed out that I should use the worker info to slice the iterator so that each worker strides through the file by the total number of workers.

The new dataset looks like this:

import itertools

import torch
from torch.utils.data import IterableDataset

class CustomIterableDatasetv1(IterableDataset):

    def __init__(self, filename):

        #Store the filename in object's memory
        self.filename = filename

    def preprocess(self, text):

        ### Do something with text here
        text_pp = text.lower().strip()
        ###

        return text_pp

    def line_mapper(self, line):
        
        #Splits the line into text and label and applies preprocessing to the text
        text, label = line.split('-')
        text = self.preprocess(text)

        return text, label


    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        # get_worker_info() returns None in the main process (num_workers=0)
        worker_total_num = worker_info.num_workers if worker_info is not None else 1
        worker_id = worker_info.id if worker_info is not None else 0
        #Create an iterator
        file_itr = open(self.filename)

        #Map each element using the line_mapper
        mapped_itr = map(self.line_mapper, file_itr)
        
        #Add multiworker functionality
        mapped_itr = itertools.islice(mapped_itr, worker_id, None, worker_total_num)

        return mapped_itr

Special thanks to @Ivan, who also pointed out the slicing solution.

With two workers it returns the same data as with only 1 worker.
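The reason the duplicates disappear is that islice(iterator, worker_id, None, worker_total_num) hands each worker a disjoint stride of the stream. The effect can be seen in plain Python, independent of PyTorch (the line list and worker count here are illustrative):

```python
from itertools import islice

lines = [f"{i} - Dummy line" for i in range(10)]
num_workers = 2

# each worker starts at its own id and then jumps by the worker count
shards = [list(islice(iter(lines), wid, None, num_workers))
          for wid in range(num_workers)]

print(shards[0])  # lines 0, 2, 4, 6, 8
print(shards[1])  # lines 1, 3, 5, 7, 9
```

Note that each worker still opens and parses the entire file and merely skips the records it does not own; for very large files it can be cheaper to assign each worker its own byte range instead.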


Iva*_*van 5

You can access a worker's identifier inside the Dataset's __iter__ function using the torch.utils.data.get_worker_info util. This means you can step through the iterator and add an offset depending on the worker id. You can wrap the iterator with itertools.islice, which lets you start at an index start and advance by a step.

Here is a minimal example:

from itertools import islice

import torch
from torch.utils.data import IterableDataset, DataLoader

class DS(IterableDataset):
    def __init__(self, batch_size):
        super().__init__()
        self.batch_size = batch_size

    def __iter__(self):
        uid = torch.utils.data.get_worker_info().id
        # stride by self.batch_size, which here is chosen equal to num_workers
        itr = islice(range(10), uid, None, self.batch_size)
        return itr

Looping over the data loader yields unique instances even when using num_workers > 1:

>>> for x in DataLoader(DS(batch_size=2), batch_size=2, num_workers=2):
...     print(x)
tensor([0, 2])
tensor([1, 3])
tensor([4, 6])
tensor([5, 7])
tensor([8])
tensor([9])
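The batch order above comes from the DataLoader polling its workers round-robin: worker 0 produces [0, 2], [4, 6], [8], worker 1 produces [1, 3], [5, 7], [9], and the loader alternates between them. A plain-Python sketch of that interleaving (the round-robin order matches what PyTorch does by default, but it is an implementation detail):

```python
from itertools import islice

def worker_batches(uid, num_workers, batch_size=2):
    # values this worker would yield, grouped into batches
    vals = list(islice(range(10), uid, None, num_workers))
    return [vals[i:i + batch_size] for i in range(0, len(vals), batch_size)]

w0 = worker_batches(0, 2)  # [[0, 2], [4, 6], [8]]
w1 = worker_batches(1, 2)  # [[1, 3], [5, 7], [9]]

# alternate between the two workers, like the DataLoader output above
order = [batch for pair in zip(w0, w1) for batch in pair]
print(order)
```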

In your case, you could do:

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()

        # create an iterator
        file_itr = open(self.filename)

        # map each element using the line_mapper
        mapped_itr = map(self.line_mapper, file_itr)

        # wrap the iterator: start at this worker's id and stride by the
        # total number of workers
        step_itr = islice(mapped_itr, worker_info.id, None, worker_info.num_workers)

        return step_itr
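As a final sanity check, the sharded __iter__ from the answers can be simulated without a DataLoader by running its logic once per pretend worker (the file name and worker count here are illustrative):

```python
from itertools import islice

# recreate the small sample file
with open("testfile.txt", "w") as f:
    f.write("\n".join(f"{i} - Dummy line" for i in range(10)))

def worker_iter(filename, worker_id, num_workers):
    # same logic as the __iter__ above, with the worker info passed in explicitly
    mapped = map(lambda line: tuple(line.split('-')), open(filename))
    return islice(mapped, worker_id, None, num_workers)

seen = []
for wid in range(2):
    seen.extend(worker_iter("testfile.txt", wid, 2))

print(len(seen), len(set(seen)))  # 10 unique records, no duplicates
```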