Does the default `Trainer` class in HuggingFace transformers use PyTorch or TensorFlow under the hood?

Ala*_*ACK 4 python tensorflow pytorch huggingface-transformers

Question

According to the official documentation, the `Trainer` class "provides an API for feature-complete training in PyTorch for most standard use cases."

However, when I actually try to use `Trainer` in practice, I get the following message, which seems to indicate that TensorFlow is being used under the hood.

tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.

So which one is it? Does the HuggingFace transformers library use PyTorch or TensorFlow for its internal implementation of `Trainer`? Is it possible to switch to using PyTorch only? I can't seem to find a relevant parameter in `TrainingArguments`.

And why does my script keep printing TensorFlow-related messages? Shouldn't `Trainer` be using PyTorch only?

Source code

from transformers import (
    GPT2Tokenizer,
    GPT2LMHeadModel,
    TextDataset,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

import torch

# Load the GPT-2 tokenizer and LM head model
tokenizer    = GPT2Tokenizer.from_pretrained('gpt2')
lmhead_model = GPT2LMHeadModel.from_pretrained('gpt2')

# Load the training dataset and divide blocksize
train_dataset = TextDataset(
    tokenizer=tokenizer,
    file_path='./datasets/tinyshakespeare.txt',
    block_size=64
)

# Create a data collator for preprocessing batches
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False
)

# Defining the training arguments
training_args = TrainingArguments(
    output_dir='./models/tinyshakespeare', # output directory for checkpoints
    overwrite_output_dir=True,             # overwrite any existing content

    per_device_train_batch_size=4,         # sample batch size for training
    dataloader_num_workers=1,              # number of workers for dataloader
    max_steps=100,                         # maximum number of training steps
    save_steps=50,                         # after # steps checkpoints are saved
    save_total_limit=5,                    # maximum number of checkpoints to save

    prediction_loss_only=True,             # only compute loss during prediction
    learning_rate=3e-4,                    # learning rate
    fp16=False,                            # use 16-bit (mixed) precision

    optim='adamw_torch',                   # define the optimizer for training
    lr_scheduler_type='linear',            # define the learning rate scheduler

    logging_steps=5,                       # after # steps logs are printed
    report_to='none',                      # report to wandb, tensorboard, etc.
)

if __name__ == '__main__':
    torch.multiprocessing.freeze_support()

    trainer = Trainer(
        model=lmhead_model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset,
    )

    trainer.train()

alv*_*vas 5

It depends on how the model was trained and how you load it. Most popular models on `transformers` support both PyTorch and TensorFlow (and sometimes JAX as well).

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import TFAutoModelForSeq2SeqLM

model_name = "google/flan-t5-large"

model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# This would work if the model's backend is PyTorch.
print(type(next(model.parameters())))


tf_model = TFAutoModelForSeq2SeqLM.from_pretrained(model_name)

# `model.parameters()` would not work for TensorFlow;
# instead, you can try `.summary()`
tf_model.summary()

[Out]:

<class 'torch.nn.parameter.Parameter'>

Model: "tft5_for_conditional_generation"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 shared (Embedding)          multiple                  32899072  
                                                                 
 encoder (TFT5MainLayer)     multiple                  341231104 
                                                                 
 decoder (TFT5MainLayer)     multiple                  441918976 
                                                                 
 lm_head (Dense)             multiple                  32899072  
                                                                 
=================================================================
Total params: 783,150,080
Trainable params: 783,150,080
Non-trainable params: 0
_________________________________________________________________


Maybe something like this:


def which_backend(model):
    try:
        model.parameters()  # only PyTorch modules expose .parameters()
        return 'torch'
    except AttributeError:
        try:
            model.summary()  # Keras models expose .summary()
            return 'tensorflow'
        except AttributeError:
            return 'I have no idea... Maybe JAX?'

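For the two models loaded above, this should print `torch` and then `tensorflow`:

print(which_backend(model))     # 'torch'
print(which_backend(tf_model))  # 'tensorflow'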

Q: So if I use `Trainer`, it's PyTorch?

A: Yes, most probably the model has a PyTorch backend, and the training loop (optimizer, loss, etc.) uses PyTorch. But `Trainer()` is not the model; it's a wrapper object around it.
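
A quick way to convince yourself, reusing the objects from the question's script (a sketch; `Trainer.create_optimizer()` is the method `Trainer` itself calls to build its optimizer):

import torch
from transformers import Trainer

trainer = Trainer(
    model=lmhead_model,            # objects from the question's script
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
)
trainer.create_optimizer()  # build the optimizer Trainer would train with

print(isinstance(trainer.model, torch.nn.Module))            # True
print(isinstance(trainer.optimizer, torch.optim.Optimizer))  # True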

Q: And if I want `Trainer`-style training for a TensorFlow-backend model, should I use `TFTrainer`?

A: Not really. In recent versions of `transformers`, the `TFTrainer` object has been deprecated; see https://github.com/huggingface/transformers/pull/12706

If you are using a model with a TensorFlow backend, it is suggested that you use Keras' sklearn-style `.fit()` training instead.
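
A minimal sketch of that workflow, assuming a tokenized `tf.data.Dataset` named `tf_train_dataset` already exists (the dataset name is hypothetical):

import tensorflow as tf
from transformers import TFGPT2LMHeadModel

tf_model = TFGPT2LMHeadModel.from_pretrained('gpt2')

# Recent transformers TF models can compute their loss internally,
# so compile() only needs an optimizer here.
tf_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=3e-4))

tf_model.fit(tf_train_dataset, epochs=1)  # tf_train_dataset: hypothetical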

Q: Why does my script keep printing TensorFlow-related errors? Shouldn't `Trainer` use PyTorch only?

A: Try checking your `transformers` version; most probably you are using an outdated one that relies on deprecated objects such as `TextDataset` (see "How to resolve the 'only integer tensors of a single element can be converted to an index' error when creating a dataset to fine-tune a GPT-2 model?").

With a more recent version, e.g. `pip install "transformers>=4.26.1"`, `Trainer` most probably should not trigger the TensorFlow warnings, and using `TFTrainer` will emit a warning advising users to switch to Keras instead.
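
As a side note, the startup message in the question is TensorFlow's informational log, printed whenever an installed TensorFlow gets imported; it is not an error coming from `Trainer` itself. If TensorFlow is installed but you want `transformers` to load only its PyTorch side, one option (a sketch; `USE_TF` is the environment switch `transformers` checks at import time, and `TF_CPP_MIN_LOG_LEVEL` is TensorFlow's own log-level knob) is:

import os

# Ask transformers to skip TensorFlow entirely (set before the import).
os.environ['USE_TF'] = '0'

# Optionally silence TensorFlow's C++ startup logs if TF is imported anyway.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

from transformers import Trainer, TrainingArguments  # PyTorch-only from here on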