我正在为音频分类任务训练 CNN,并且我正在使用带有自定义训练循环的 TensorFlow 2.0 RC(如其官方网站的本指南中所述)。我会发现有一个不错的进度条真的很方便,类似于通常的 Keras model.fit。
这是我的训练代码的大纲(我使用了 4 个 GPU,采用镜像分布策略):
strategy = distribute.MirroredStrategy()
distr_train_dataset = strategy.experimental_distribute_dataset(train_dataset)
if valid_dataset:
distr_valid_dataset = strategy.experimental_distribute_dataset(valid_dataset)
with strategy.scope():
model = build_model() # build the model
optimizer = # define optimizer
train_loss = # define training loss
train_metrics_1 = # AUC-ROC
train_metrics_2 = # AUC-PR
valid_metrics_1 = # AUC-ROC for validation
valid_metrics_2 = # AUC-PR for validation
# rescale loss
def compute_loss(labels, predictions):
per_example_loss = train_loss(labels, predictions)
return per_example_loss/config.batch_size
def train_step(batch):
audio_batch, …Run Code Online (Sandbox Code Playgroud)