Ala*_*ACK 4 python tensorflow pytorch huggingface-transformers
根据官方文档,该类Trainer“为 PyTorch 中大多数标准用例的功能完整训练提供了 API”。
然而,当我尝试Trainer在实践中实际使用时,我收到以下错误消息,这似乎表明 TensorFlow 目前正在幕后使用。
tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
Run Code Online (Sandbox Code Playgroud)
那么是哪一个呢?HuggingFace 转换器库是否使用 PyTorch 或 TensorFlow 进行内部实现Trainer?是否可以切换为仅使用 PyTorch?我似乎在 中找不到相关参数TrainingArguments。
为什么我的脚本不断打印出 TensorFlow 相关错误?不应该Trainer只使用 PyTorch 吗?
from transformers import GPT2Tokenizer
from transformers import GPT2LMHeadModel
from transformers import TextDataset
from transformers import DataCollatorForLanguageModeling
from transformers import Trainer
from transformers import TrainingArguments
import torch
# Load the GPT-2 tokenizer and LM head model
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
lmhead_model = GPT2LMHeadModel.from_pretrained('gpt2')
# Load the training dataset and divide blocksize
train_dataset = TextDataset(
tokenizer=tokenizer,
file_path='./datasets/tinyshakespeare.txt',
block_size=64
)
# Create a data collator for preprocessing batches
data_collator = DataCollatorForLanguageModeling(
tokenizer=tokenizer,
mlm=False
)
# Defining the training arguments
training_args = TrainingArguments(
output_dir='./models/tinyshakespeare', # output directory for checkpoints
overwrite_output_dir=True, # overwrite any existing content
per_device_train_batch_size=4, # sample batch size for training
dataloader_num_workers=1, # number of workers for dataloader
max_steps=100, # maximum number of training steps
save_steps=50, # after # steps checkpoints are saved
save_total_limit=5, # maximum number of checkpoints to save
prediction_loss_only=True, # only compute loss during prediction
learning_rate=3e-4, # learning rate
fp16=False, # use 16-bit (mixed) precision
optim='adamw_torch', # define the optimizer for training
lr_scheduler_type='linear', # define the learning rate scheduler
logging_steps=5, # after # steps logs are printed
report_to='none', # report to wandb, tensorboard, etc.
)
if __name__ == '__main__':
torch.multiprocessing.freeze_support()
trainer = Trainer(
model=lmhead_model,
args=training_args,
data_collator=data_collator,
train_dataset=train_dataset,
)
trainer.train()
Run Code Online (Sandbox Code Playgroud)
这取决于模型的训练方式以及加载模型的方式。大多数流行的模型都transformers支持 PyTorch 和 Tensorflow(有时还支持 JAX)。
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import TFAutoModelForSeq2SeqLM
model_name = "google/flan-t5-large"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
# This would work if the model's backend is PyTorch.
print(type(next(model.parameters())))
tf_model = TFAutoModelForSeq2SeqLM.from_pretrained(model_name)
# The `model.parameters()` would not work for Tensorflow,
# instead you can try `.summary()`
tf_model.summary()
Run Code Online (Sandbox Code Playgroud)
[出去]:
<class 'torch.nn.parameter.Parameter'>
Model: "tft5_for_conditional_generation"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
shared (Embedding) multiple 32899072
encoder (TFT5MainLayer) multiple 341231104
decoder (TFT5MainLayer) multiple 441918976
lm_head (Dense) multiple 32899072
=================================================================
Total params: 783,150,080
Trainable params: 783,150,080
Non-trainable params: 0
_________________________________________________________________
Run Code Online (Sandbox Code Playgroud)
也许是这样的:
def which_backend(model):
try:
model.parameters()
return 'torch'
except:
try:
model.summary()
return 'tensorflow'
except:
return 'I have no idea... Maybe JAX?'
Run Code Online (Sandbox Code Playgroud)
Trainer,那就是 PyTorch?答:是的,很可能该模型有 PyTorch 后端,并且训练循环(优化器、损失等)使用 PyTorch。但这Trainer()不是模型,而是包装对象。
TrainerTensorflow 后端模型,我应该使用TFTrainer?并不真地。在最新版本中transformers,该TFTrainer对象已被弃用,请参阅https://github.com/huggingface/transformers/pull/12706
.fit()如果您使用具有 Tensorflow 后端的模型,建议您使用 Keras 的 sklearn 式训练。
尝试检查您的transformers版本,很可能您使用的是过时的版本,该版本使用了一些已弃用的对象,例如 TextDataset (请参阅如何在创建数据集进行微调时解决“仅单个元素的整数张量可以转换为索引”错误GPT2 模型?)
在以后的版本中,pip install transformers>=4.26.1Trainer 很可能不应该激活 TF 警告,而使用 TFTrainer 会发出警告,建议用户改用 Keras。
| 归档时间: |
|
| 查看次数: |
1669 次 |
| 最近记录: |