PyTorch DataLoader shows odd behavior with a string dataset

sta*_*010 6 python pytorch dataloader

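I'm working on an NLP problem and am using PyTorch. For some reason, my dataloader is returning malformed batches. My input data consists of sentences and integer labels. The sentences can be either a list of sentences or a list of lists of tokens. I will later convert the tokens to integers in a downstream component.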

list_labels = [ 0, 1, 0]

# List of sentences.
list_sentences = [ 'the movie is terrible',
                   'The Film was great.',
                   'It was just awful.']

# Or list of list of tokens.
list_sentences = [['the', 'movie', 'is', 'terrible'],
                  ['The', 'Film', 'was', 'great.'],
                  ['It', 'was', 'just', 'awful.']]

I created the following custom dataset:

import torch
from torch.utils.data import DataLoader, Dataset

class MyDataset(torch.utils.data.Dataset):

    def __init__(self, sentences, labels):

        self.sentences = sentences
        self.labels = labels

    def __getitem__(self, i):
        result = {}
        result['sentences'] = self.sentences[i]
        result['label'] = self.labels[i]
        return result

    def __len__(self):
        return len(self.labels)

When I provide the input as a list of sentences, the dataloader correctly returns batches of complete sentences. Note that batch_size=2:

list_sentences = [ 'the movie is terrible', 'The Film was great.', 'It was just awful.']
list_labels = [ 0, 1, 0]


dataset = MyDataset(list_sentences, list_labels)
dataloader = DataLoader(dataset, batch_size=2)

batch = next(iter(dataloader))
print(batch)
# {'sentences': ['the movie is terrible', 'The Film was great.'], <-- Great! 2 sentences in batch!
#  'label': tensor([0, 1])}

The batch correctly contains two sentences and two labels because batch_size=2.

However, when I instead pass the sentences pre-tokenized, as a list of lists of tokens, I get weird results:

list_sentences = [['the', 'movie', 'is', 'terrible'], ['The', 'Film', 'was', 'great.'], ['It', 'was', 'just', 'awful.']]
list_labels = [ 0, 1, 0]


dataset = MyDataset(list_sentences, list_labels)
dataloader = DataLoader(dataset, batch_size=2)

batch = next(iter(dataloader))
print(batch)
# {'sentences': [('the', 'The'), ('movie', 'Film'), ('is', 'was'), ('terrible', 'great.')], <-- WHAT?
#  'label': tensor([0, 1])}

Note that in this batch, sentences is a single list of tuples of word pairs. I expected sentences to be a list of two lists, like this:

{'sentences': [['the', 'movie', 'is', 'terrible'], ['The', 'Film', 'was', 'great.']]}

What is going on here?

Ber*_*iel 5

This behavior occurs because the default collate_fn does the following when it has to collate lists (which is the case for ['sentences']):

# [...]
elif isinstance(elem, container_abcs.Sequence):
    # check to make sure that the elements in batch have consistent size
    it = iter(batch)
    elem_size = len(next(it))
    if not all(len(elem) == elem_size for elem in it):
        raise RuntimeError('each element in list of batch should be of equal size')
    transposed = zip(*batch)
    return [default_collate(samples) for samples in transposed]

The "problem" happens because, in the last two lines, it recursively calls zip(*batch) whenever the batch is a container_abcs.Sequence (and list is one), and zip behaves like this.

As you can see:

batch = [['the', 'movie', 'is', 'terrible'], ['The', 'Film', 'was', 'great.']]
list(zip(*batch))

# [('the', 'The'), ('movie', 'Film'), ('is', 'was'), ('terrible', 'great.')]
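To make the recursion easier to follow, here is a minimal pure-Python sketch of that branch. The name toy_collate is hypothetical, and this is a simplified stand-in rather than PyTorch's actual implementation: it only handles strings and sequences, and it assumes a batch of strings is returned unchanged, which is consistent with the outputs shown in the question.

```python
from collections.abc import Sequence


def toy_collate(batch):
    """Simplified, hypothetical stand-in for the Sequence branch of the default collate."""
    elem = batch[0]
    if isinstance(elem, str):
        # Assumption: a batch of strings is handed back as-is.
        return batch
    if isinstance(elem, Sequence):
        # Same consistency check as in the snippet above.
        elem_size = len(elem)
        if not all(len(e) == elem_size for e in batch):
            raise RuntimeError('each element in list of batch should be of equal size')
        # zip(*batch) transposes the batch: one tuple per token position.
        return [toy_collate(samples) for samples in zip(*batch)]
    raise TypeError(f'unsupported element type: {type(elem)}')


# A batch of plain sentences hits the string branch and stays intact.
plain = toy_collate(['the movie is terrible', 'The Film was great.'])
print(plain)
# ['the movie is terrible', 'The Film was great.']

# A batch of token lists hits the Sequence branch and is transposed.
tokenized = toy_collate([['the', 'movie', 'is', 'terrible'],
                         ['The', 'Film', 'was', 'great.']])
print(tokenized)
# [('the', 'The'), ('movie', 'Film'), ('is', 'was'), ('terrible', 'great.')]
```

The same check also means that token lists of different lengths would raise the RuntimeError instead of being transposed.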

There isn't much you can do other than implement a new collator and pass it to the DataLoader(..., collate_fn=mycollator). For example, a simple, ugly one could be:

def mycollator(batch):
    assert all('sentences' in x for x in batch)
    assert all('label' in x for x in batch)
    return {
        'sentences': [x['sentences'] for x in batch],
        'label': torch.tensor([x['label'] for x in batch])
    }
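For completeness, here is a sketch of how that collator might be wired into the question's setup. It repeats MyDataset and mycollator (without the asserts) so that it runs on its own; nothing here goes beyond the DataLoader collate_fn argument mentioned above.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class MyDataset(Dataset):
    def __init__(self, sentences, labels):
        self.sentences = sentences
        self.labels = labels

    def __getitem__(self, i):
        return {'sentences': self.sentences[i], 'label': self.labels[i]}

    def __len__(self):
        return len(self.labels)


def mycollator(batch):
    # Keep each token list whole instead of letting it be transposed.
    return {
        'sentences': [x['sentences'] for x in batch],
        'label': torch.tensor([x['label'] for x in batch]),
    }


list_sentences = [['the', 'movie', 'is', 'terrible'],
                  ['The', 'Film', 'was', 'great.'],
                  ['It', 'was', 'just', 'awful.']]
list_labels = [0, 1, 0]

dataset = MyDataset(list_sentences, list_labels)
dataloader = DataLoader(dataset, batch_size=2, collate_fn=mycollator)

batch = next(iter(dataloader))
print(batch['sentences'])
# [['the', 'movie', 'is', 'terrible'], ['The', 'Film', 'was', 'great.']]
print(batch['label'])
# tensor([0, 1])
```

Because the collator only builds a Python list of token lists, it should also accept sentences with different numbers of tokens, which the default collate rejects.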