Keras回调EarlyStopping比较培训和验证损失

Tom*_*ini 5 python neural-network keras tensorflow

我在Python Keras中安装了神经网络。

为了避免过度拟合,我想监视训练/验证损失并创建适当的回调,当训练损失远小于验证损失时,该回调将停止计算。

回调的示例是:

callback = [EarlyStopping(monitor='val_loss', value=45, verbose=0, mode='auto')]
Run Code Online (Sandbox Code Playgroud)

与验证损失相比,训练损失过少时,有什么方法可以停止训练?

先感谢您

Nas*_*Ben 6

您可以根据您的目的创建自定义回调类。

我已经创建了一个应该符合您的需求的:

class CustomEarlyStopping(Callback):
    def __init__(self, ratio=0.0,
                 patience=0, verbose=0):
        super(EarlyStopping, self).__init__()

        self.ratio = ratio
        self.patience = patience
        self.verbose = verbose
        self.wait = 0
        self.stopped_epoch = 0
        self.monitor_op = np.greater

    def on_train_begin(self, logs=None):
        self.wait = 0  # Allow instances to be re-used

    def on_epoch_end(self, epoch, logs=None):
        current_val = logs.get('val_loss')
        current_train = logs.get('loss')
        if current_val is None:
            warnings.warn('Early stopping requires %s available!' %
                          (self.monitor), RuntimeWarning)

        # If ratio current_loss / current_val_loss > self.ratio
        if self.monitor_op(np.divide(current_train,current_val),self.ratio):
            self.wait = 0
        else:
            if self.wait >= self.patience:
                self.stopped_epoch = epoch
                self.model.stop_training = True
            self.wait += 1

    def on_train_end(self, logs=None):
        if self.stopped_epoch > 0 and self.verbose > 0:
            print('Epoch %05d: early stopping' % (self.stopped_epoch))
Run Code Online (Sandbox Code Playgroud)

我冒昧地解释说,如果 和 之间的比率train_loss低于validation_loss某个比率阈值,您想停止。该比率参数应介于0.0和之间1.0。然而,1.0这是危险的,因为验证损失和训练损失在训练开始时可能会以不稳定的方式大幅波动。

您可以添加一个耐心参数,它将等待查看阈值的突破是否会持续一定数量的时期。

使用方法例如:

callbacks = [CustomEarlyStopping(ratio=0.5, patience=2, verbose=1), 
            ... Other callbacks ...]
...
model.fit(..., callbacks=callbacks)
Run Code Online (Sandbox Code Playgroud)

在这种情况下,如果训练损失0.5*val_loss在超过 2 个 epoch 中保持低于 200 倍,它将停止。

这对你有帮助吗?