torchtext BucketIterator 最小填充

pau*_*l41 1 python pytorch torchtext

我正在尝试使用 torchtext 中的 BucketIterator.splits 函数从 csv 文件加载数据以用于 CNN。除非我有一批最长的句子比最大的过滤器尺寸短,否则一切正常。

在我的示例中,我有大小为 3、4 和 5 的过滤器,因此如果最长的句子没有至少 5 个单词,我会收到错误消息。有没有办法让 BucketIterator 动态设置批次的填充,同时还设置最小填充长度?

这是我用于 BucketIterator 的代码:

train_iter, val_iter, test_iter = BucketIterator.splits((train, val, test), sort_key=lambda x: len(x.text), batch_size=batch_size, repeat=False, device=device)
Run Code Online (Sandbox Code Playgroud)

我希望有一种方法可以在 sort_key 或类似的东西上设置最小长度?

我试过这个,但它不起作用:

FILTER_SIZES = [3,4,5]
train_iter, val_iter, test_iter = BucketIterator.splits((train, val, test), sort_key=lambda x: len(x.text) if len(x.text) >= FILTER_SIZES[-1] else FILTER_SIZES[-1], batch_size=batch_size, repeat=False, device=device) 
Run Code Online (Sandbox Code Playgroud)

pau*_*l41 6

我查看了 torchtext 源代码以更好地了解 sort_key 正在做什么,并了解为什么我最初的想法行不通。

我不确定这是否是最好的解决方案,但我想出了一个有效的解决方案。我创建了一个标记器函数,如果文本短于最长过滤器长度,则填充文本,然后从那里创建 BucketIterator。

FILTER_SIZES = [3,4,5]
spacy_en = spacy.load('en')

def tokenizer(text):
    token = [t.text for t in spacy_en.tokenizer(text)]
    if len(token) < FILTER_SIZES[-1]:
        for i in range(0, FILTER_SIZES[-1] - len(token)):
            token.append('<PAD>')
    return token

TEXT = Field(sequential=True, tokenize=tokenizer, lower=True, tensor_type=torch.cuda.LongTensor)

train_iter, val_iter, test_iter = BucketIterator.splits((train, val, test), sort_key=lambda x: len(x.text), batch_size=batch_size, repeat=False, device=device)
Run Code Online (Sandbox Code Playgroud)