当精度已经达到1.0时,停止在Keras中进行训练

Aly*_*ono 3 python machine-learning neural-network keras

当精度已经达到1.0时,如何停止Keras Training?我尝试监视损失值,但是在精度已经达到1的情况下,我没有尝试停止训练。

我没有运气尝试下面的代码:

stopping_criterions =[
    EarlyStopping(monitor='loss', min_delta=0, patience = 1000),
    EarlyStopping(monitor='acc', base_line=1.0, patience =0)

]

model.summary()
model.compile(Adam(), loss='binary_crossentropy', metrics=['accuracy']) 
model.fit(scaled_train_samples, train_labels, batch_size=1000, epochs=1000000, callbacks=[stopping_criterions], shuffle = True, verbose=2)
Run Code Online (Sandbox Code Playgroud)

更新:

即使精度仍然不是1.0,训练也会立即从第一个纪元停止。

在此处输入图片说明

请帮忙。

小智 9

据我所知,使用带有基线回调的 EarlyStopping 在这里不起作用。“基线”是您应该继续训练的监控变量的最小值(此处为准确度)。这里的基线是 1.0,在第一个纪元结束时基线小于“准确度”(显然你不能指望第一个纪元本身的“准确度”为 1.0)并且由于耐心设置为零,训练停止在第一个纪元本身,因为基线大于准确度。使用自定义回调在这里完成工作。

class MyThresholdCallback(tf.keras.callbacks.Callback):
    def __init__(self, threshold):
        super(MyThresholdCallback, self).__init__()
        self.threshold = threshold

    def on_epoch_end(self, epoch, logs=None): 
        accuracy = logs["acc"]
        if accuracy >= self.threshold:
            self.model.stop_training = True
Run Code Online (Sandbox Code Playgroud)

并在 model.fit 中调用回调

callback=MyThresholdCallback(threshold=1.0)
model.fit(scaled_train_samples, train_labels, batch_size=1000, epochs=1000000, callbacks=[callback], shuffle = True, verbose=2)
Run Code Online (Sandbox Code Playgroud)


tod*_*day 5

更新:我不知道为什么EarlyStopping在这种情况下不起作用。相反,我定义了一个自定义回调,该回调在acc(或val_acc)达到指定的基线时停止训练:

from keras.callbacks import Callback

class TerminateOnBaseline(Callback):
    """Callback that terminates training when either acc or val_acc reaches a specified baseline
    """
    def __init__(self, monitor='acc', baseline=0.9):
        super(TerminateOnBaseline, self).__init__()
        self.monitor = monitor
        self.baseline = baseline

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        acc = logs.get(self.monitor)
        if acc is not None:
            if acc >= self.baseline:
                print('Epoch %d: Reached baseline, terminating training' % (epoch))
                self.model.stop_training = True
Run Code Online (Sandbox Code Playgroud)

您可以像这样使用它:

callbacks = [TerminateOnBaseline(monitor='acc', baseline=0.8)]
callbacks = [TerminateOnBaseline(monitor='val_acc', baseline=0.95)]
Run Code Online (Sandbox Code Playgroud)

注意:此解决方案不起作用。

如果要在训练(或验证)准确性完全达到100%时停止训练,请使用EarlyStoppingcallback并将baseline参数设置为1.0并patience设置为零:

EarlyStopping(monitor='acc', baseline=1.0, patience=0)  # use 'val_acc' instead to monitor validation accuarcy
Run Code Online (Sandbox Code Playgroud)

  • @Eliyah 对不起!我没有测试该解决方案。我已经更新了我的答案并添加了另一个我已经测试过并保证有效的解决方案。请看一下。 (2认同)
  • 只是提一下,它应该是“val_accuracy”。 (2认同)

JBS*_*rro 5

这个名字baseline有误导性。虽然从下面的源码中不太容易解读,但baseline应该理解为:

当监测值比基线差1patience时,请保持最大epoch 的训练时间更长。如果效果更好,请提高基线并重复。

1即精度较低,损耗较高。


相关(修剪过的)源代码EarlyStopping

self.best = baseline  # in initialization
...
def on_epoch_end(self, epoch, logs=None):
  current = self.get_monitor_value(logs)
  if self.monitor_op(current - self.min_delta, self.best):  # read as `current > self.best` (for accuracy)
     self.best = current
     self.wait = 0
  else:
     self.wait += 1
     if self.wait >= self.patience:
        self.model.stop_training = True
Run Code Online (Sandbox Code Playgroud)

那么你的例子 EarlyStopping(monitor='acc', base_line=1.0, patience=0)意味着:虽然监测值比 1.0 更差(它总是如此),但继续训练 0 epoch(即立即终止)。


如果您想要这些语义: 当监控值比基线差时,请继续训练。patience如果更好的话,继续训练,直到连续 epoch没有任何进展,并且还保留 的所有特征EarlyStopping,我可以这样建议:

class MyEarlyStopping(EarlyStopping):
    def __init__(self, *args, **kw):
        super().__init__(*args, **kw)
        self.baseline_attained = False

    def on_epoch_end(self, epoch, logs=None):
        if not self.baseline_attained:
            current = self.get_monitor_value(logs)
            if current is None:
                return

            if self.monitor_op(current, self.baseline):
                if self.verbose > 0:
                    print('Baseline attained.')
                self.baseline_attained = True
            else:
                return

        super(MyEarlyStopping, self).on_epoch_end(epoch, logs)
Run Code Online (Sandbox Code Playgroud)