小编pyg*_*rix的帖子

TensorFlow 2.0:在自定义训练循环中显示进度条

我正在为音频分类任务训练 CNN,并且我正在使用带有自定义训练循环的 TensorFlow 2.0 RC(如其官方网站的本指南中所述)。我会发现有一个不错的进度条真的很方便,类似于通常的 Keras model.fit

这是我的训练代码的大纲(我使用了 4 个 GPU,采用镜像分布策略):

strategy = distribute.MirroredStrategy()

distr_train_dataset = strategy.experimental_distribute_dataset(train_dataset)

if valid_dataset:
    distr_valid_dataset = strategy.experimental_distribute_dataset(valid_dataset)

with strategy.scope():

    model = build_model() # build the model

    optimizer = # define optimizer
    train_loss = # define training loss
    train_metrics_1 = # AUC-ROC
    train_metrics_2 = # AUC-PR
    valid_metrics_1 = # AUC-ROC for validation
    valid_metrics_2 = # AUC-PR for validation

    # rescale loss
    def compute_loss(labels, predictions):
        per_example_loss = train_loss(labels, predictions)
        return per_example_loss/config.batch_size

    def train_step(batch):
        audio_batch, …
Run Code Online (Sandbox Code Playgroud)

python progress-bar tensorflow

4
推荐指数
2
解决办法
4011
查看次数

标签 统计

progress-bar ×1

python ×1

tensorflow ×1