如何在数据加载器中使用“collat​​e_fn”?

Sam*_*m V 5 python pytorch dataloader huggingface-transformers

我正在尝试使用3 个输入、3 个input_masks 和一个标签作为我的训练数据集的张量来训练一个预训练的 roberta 模型。

我使用以下代码执行此操作:

from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
batch_size = 32
# Create the DataLoader for our training set.
train_data = TensorDataset(train_AT, train_BT, train_CT, train_maskAT, train_maskBT, train_maskCT, labels_trainT)
train_dataloader = DataLoader(train_data, batch_size=batch_size)

# Create the Dataloader for our validation set.
validation_data = TensorDataset(val_AT, val_BT, val_CT, val_maskAT, val_maskBT, val_maskCT, labels_valT)
val_dataloader = DataLoader(validation_data, batch_size=batch_size)

# Pytorch Training
training_args = TrainingArguments(
    output_dir='C:/Users/samvd/Documents/Master/AppliedMachineLearning/FinalProject/results',          # output directory
    num_train_epochs=1,              # total # of training epochs
    per_device_train_batch_size=32,  # batch size per device during training
    per_device_eval_batch_size=32,   # batch size for evaluation
    warmup_steps=500,                # number of warmup steps for learning rate scheduler
    weight_decay=0.01,               # strength of weight decay
    logging_dir='C:/Users/samvd/Documents/Master/AppliedMachineLearning/FinalProject/logs',            # directory for storing logs
)

trainer = Trainer(
    model=model,                          # the instantiated  Transformers model to be trained
    args=training_args,                   # training arguments, defined above
    train_dataset = train_data,           # training dataset
    eval_dataset = validation_data,       # evaluation dataset
)

trainer.train()
Run Code Online (Sandbox Code Playgroud)

但是,这给了我以下错误:

类型错误:vars() 参数必须具有dict属性

现在我发现这可能是因为我在使用collate_fn时没有使用DataLoader,但我真的找不到可以帮助我正确定义它的来源,以便训练师了解我输入的不同张量。

任何人都可以指出我正确的方向吗?

Abh*_*25t 13

基本上,collate_fn如果__getitem__来自 Dataset 子类的函数返回一个元组,则接收一个元组列表,或者如果您的 Dataset 子类仅返回一个元素,则接收一个普通列表。它的主要目标是创建您的批次,而无需花费大量时间手动实施。尝试将其视为一种胶水,您可以指定示例在批处理中粘合在一起的方式。如果您不使用它,PyTorch 只会batch_size像使用 torch.stack 一样将示例放在一起(不完全是这样,但它很简单)。

例如,假设您想要批量创建一系列不同维度的张量。下面的代码用 0 填充序列直到批处理的最大序列大小,这就是我们需要 collat​​e_fn 的原因,因为torch.stack在这种情况下标准批处理算法(仅使用)将不起作用,我们需要手动填充不同的序列在创建批处理之前将可变长度设置为相同的大小。

def collate_fn(data):
    """
       data: is a list of tuples with (example, label, length)
             where 'example' is a tensor of arbitrary shape
             and label/length are scalars
    """
    _, labels, lengths = zip(*data)
    max_len = max(lengths)
    n_ftrs = data[0][0].size(1)
    features = torch.zeros((len(data), max_len, n_ftrs))
    labels = torch.tensor(labels)
    lengths = torch.tensor(lengths)

    for i in range(len(data)):
        j, k = data[i][0].size(0), data[i][0].size(1)
        features[i] = torch.cat([data[i][0], torch.zeros((max_len - j, k))])

    return features.float(), labels.long(), lengths.long()
Run Code Online (Sandbox Code Playgroud)

上面的函数被提供给 DataLoader 中的 collat​​e_fn 参数,如下例:

DataLoader(toy_dataset, collate_fn=collate_fn, batch_size=5)
Run Code Online (Sandbox Code Playgroud)

使用此 collat​​e_fn 函数,您将始终拥有一个张量,其中所有示例都具有相同的大小。因此,当您将这些数据提供给 forward() 函数时,您需要使用长度来获取原始数据,而不是在计算中使用那些无意义的零。

来源:Pytorch 论坛