当达到特定的验证准确度时如何停止训练?

gls*_*lmn 4 python deep-learning conv-neural-network keras tensorflow

我正在训练一个卷积网络,一旦验证错误达到​​ 90%,我想停止训练。我考虑过使用 EarlyStopping 并将基线设置为 0.90,但是只要验证准确度低于给定时期数的基线(此处仅为 0),它就会停止训练。所以我的代码是:

es=EarlyStopping(monitor='val_acc',mode='auto',verbose=1,baseline=.90,patience=0)
history = model.fit(training_images, training_labels, validation_data=(test_images, test_labels), epochs=30, verbose=2,callbacks=[es])
Run Code Online (Sandbox Code Playgroud)

当我使用此代码时,我的训练在第一个具有给定结果的 epoch 后停止:

训练 60000 个样本,验证 10000 个样本

纪元 1/30 60000/60000 - 7s - 损失:0.4600 - acc:0.8330 - val_loss:0.3426 - val_acc:0.8787

一旦验证准确率达到 90% 或以上,我还能尝试停止训练吗?

下面是代码的其余部分:

  tf.keras.layers.Conv2D(64, (3,3), activation='relu', input_shape=(28, 28, 1)),
  tf.keras.layers.MaxPooling2D(2, 2),
  tf.keras.layers.Conv2D(32, (3,3), activation='relu', input_shape=(28, 28, 1)),
  tf.keras.layers.MaxPooling2D(2, 2),
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(152, activation='relu'),
  tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer=Adam(learning_rate=0.001),loss='sparse_categorical_crossentropy', metrics=['accuracy'])
es=EarlyStopping(monitor='val_acc',mode='auto',verbose=1,baseline=.90,patience=0)
history = model.fit(training_images, training_labels, validation_data=(test_images, test_labels), epochs=30, verbose=2,callbacks=[es])
Run Code Online (Sandbox Code Playgroud)

谢谢!

seb*_*-sz 6

Early Stopping Callback 将搜索停止增加(或减少)的值,因此它不适合您的问题。但是tf.keras允许您使用自定义回调

对于您的示例:

class MyThresholdCallback(tf.keras.callbacks.Callback):
    def __init__(self, threshold):
        super(MyThresholdCallback, self).__init__()
        self.threshold = threshold

    def on_epoch_end(self, epoch, logs=None): 
        val_acc = logs["val_acc"]
        if val_acc >= self.threshold:
            self.model.stop_training = True
Run Code Online (Sandbox Code Playgroud)

对于以上TF 2.3版本,您可能需要使用"val_accuracy"替代"val_acc"。感谢克里斯蒂安·威斯布鲁克在评论中的注释。

上述回调,在每个纪元结束时,将从所有可用日志中提取验证准确度。然后它将它与用户定义的阈值(在您的情况下为 90%)进行比较。如果满足标准,则训练将停止。

有了它,你可以简单地调用:

my_callback = MyThresholdCallback(threshold=0.9)
history = model.fit(training_images, training_labels, validation_data=(test_images, test_labels), epochs=30, verbose=2, callbacks=[my_callback])
Run Code Online (Sandbox Code Playgroud)

或者,def on_batch_end(...)如果您想立即停止,您可以使用。但是,这需要参数batch, logs而不是epoch, logs.

  • 当我在 TensorFlow 2.3 上尝试此自定义回调时,我在调用日志 [“val_acc”] 时遇到 KeyError,并且必须替换该键“val_accuracy”。这个回调很有帮助! (2认同)