Huggingface Trainer(): K-Fold Cross Validation

Max*_*rat 5 python cross-validation bert-language-model k-fold huggingface-transformers

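I am following a TowardsDataScience tutorial on text classification with the Huggingface Trainer. To get a more robust model, I want to do K-fold cross-validation, but I am unsure how to do this with the Huggingface Trainer. Does Trainer have a built-in feature for this, or how can I do the cross-validation here?

Thanks in advance!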

Rua*_*uan 1

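Rather than writing your own dataset reader from scratch as the tutorial does, the best approach is to use the Hugging Face Datasets library, which is already integrated with Hugging Face Transformers.

Here is a step-by-step guide to adapting the tutorial to the Datasets library:

First, we have to convert the raw CSV from the tutorial into something that can be loaded with the load_dataset function. We preprocess the original train.csv and save the results as new_train.csv and validation.csv.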

import pandas as pd
from sklearn.model_selection import train_test_split

data = pd.read_csv("train.csv")
# The Trainer expects the target column to be named "label"
data["label"] = data["sentiment"]
# Hold out 20% of the rows for validation
train, validation = train_test_split(data, test_size=0.2)
train.to_csv("new_train.csv")
validation.to_csv("validation.csv")

The documentation provides an example of how to create your own cross-validation splits. Here we adapt it to our use case:

import datasets

# Each call returns a list of 10 datasets, one per 10% slice
val_ds = datasets.load_dataset("csv", data_files={"validation": "validation.csv"}, split=[f"validation[{k}%:{k+10}%]" for k in range(0, 100, 10)])
train_ds = datasets.load_dataset("csv", data_files={"train": "new_train.csv"}, split=[f"train[:{k}%]+train[{k+10}%:]" for k in range(0, 100, 10)])
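Note that the split strings above take the validation folds from validation.csv and the training folds from new_train.csv, so a validation fold is not the portion held out of its training fold. If you want the two to be complementary, one option is to build both lists from the same split name, as the Datasets documentation does for its cross-validation example. Below is a minimal sketch of generating such split strings; fold_split_strings is a hypothetical helper written for illustration, not part of the datasets library, and it assumes the number of folds divides 100 evenly.

```python
def fold_split_strings(split_name, n_folds=10):
    """Build percent-slice split strings for k-fold cross-validation.

    Hypothetical helper (not part of `datasets`). Returns two lists:
    training specs and validation specs, both sliced from `split_name`,
    so that each validation fold is exactly what its training fold omits.
    """
    if n_folds < 2 or 100 % n_folds != 0:
        raise ValueError("n_folds must be >= 2 and divide 100 evenly")
    step = 100 // n_folds
    train_specs, val_specs = [], []
    for start in range(0, 100, step):
        end = start + step
        # Validation: the slice [start%, end%)
        val_specs.append(f"{split_name}[{start}%:{end}%]")
        # Training: everything before and after that slice
        train_specs.append(f"{split_name}[:{start}%]+{split_name}[{end}%:]")
    return train_specs, val_specs


train_specs, val_specs = fold_split_strings("train", n_folds=5)
print(val_specs[0])    # train[0%:20%]
print(train_specs[1])  # train[:20%]+train[40%:]
```

Both lists can then be passed as the split argument of load_dataset with the same data_files mapping, in the same way as the two calls above.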

Now we tokenize the splits:

def preprocess_function(examples):
    # Tokenize the texts
    args = ((examples["review"],))
    result = tokenizer(*args, padding=True, max_length=128, truncation=True)
    result["label"] = examples["label"]
    return result

for idx, item in enumerate(train_ds):
    train_ds[idx] = train_ds[idx].map(
        preprocess_function,
        batched=True,
        desc="Running tokenizer on dataset",
    )

for idx, item in enumerate(val_ds):
    val_ds[idx] = val_ds[idx].map(
        preprocess_function,
        batched=True,
        desc="Running tokenizer on dataset",
    )

After that, you just loop over the splits and pass them to your Trainer.

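from transformers import EarlyStoppingCallback, Trainer

# `model`, `args` (TrainingArguments) and `compute_metrics` come from the tutorial
for train_dataset, val_dataset in zip(train_ds, val_ds):
    # NOTE: re-create `model` before each fold; otherwise weights learned on
    # earlier folds carry over and the per-fold scores are not independent.
    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_metrics,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
    )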
...
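Once every fold has been trained, you will usually want a single cross-validated score. A minimal sketch, assuming you append the dictionary returned by trainer.evaluate() to a list after each fold: summarize_fold_metrics is a hypothetical helper, and the numbers in the example are placeholders, not real results.

```python
from statistics import mean, stdev


def summarize_fold_metrics(fold_metrics):
    """Return {metric_name: (mean, sample_std)} across folds.

    Hypothetical helper. `fold_metrics` is a list of dicts with the same
    numeric keys, e.g. one `trainer.evaluate()` result per fold.
    """
    summary = {}
    for key in fold_metrics[0]:
        values = [metrics[key] for metrics in fold_metrics]
        # stdev needs at least two values; report 0.0 for a single fold
        spread = stdev(values) if len(values) > 1 else 0.0
        summary[key] = (mean(values), spread)
    return summary


# Placeholder values for illustration only
example = [
    {"eval_loss": 0.5, "eval_accuracy": 0.8},
    {"eval_loss": 0.3, "eval_accuracy": 0.9},
]
for name, (avg, spread) in summarize_fold_metrics(example).items():
    print(f"{name}: {avg:.3f} +/- {spread:.3f}")
```

Reporting the spread alongside the mean shows how sensitive the model is to the particular split, which is the main reason to run cross-validation in the first place.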