Mar*_*cus 6 python recommendation-engine python-3.x keras tensorflow
我正在为推荐系统(项目推荐)进行多类分类,我目前正在使用sparse_categorical_crossentropy
损失训练我的网络。因此,EarlyStopping
通过监控我的验证损失来执行是合理的,val_loss
例如:
tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=10)
Run Code Online (Sandbox Code Playgroud)
它按预期工作。然而,网络(推荐系统)的性能是由 Average-Precision-at-10 来衡量的,并在训练期间作为一个指标进行跟踪,如average_precision_at_k10
. 因此,我还可以使用此指标执行提前停止:
tf.keras.callbacks.EarlyStopping(monitor='average_precision_at_k10', patience=10)
Run Code Online (Sandbox Code Playgroud)
这也按预期工作。
我的问题: 有时验证损失会增加,而 10 的平均精度正在提高,反之亦然。因此,当且仅当两者都在恶化时,我需要监控两者并提前停止。我想做什么:
tf.keras.callbacks.EarlyStopping(monitor=['val_loss', 'average_precision_at_k10'], patience=10)
Run Code Online (Sandbox Code Playgroud)
这显然不起作用。任何想法如何做到这一点?
在指导下在上面Gerry P,我成功创建了自己的自定义 EarlyStopping 回调,并认为我将其发布在这里,以防其他人想要实现类似的东西。
如果验证损失和10 时的 平均精度对于epoch 数没有改善,则执行早期停止。patience
class CustomEarlyStopping(keras.callbacks.Callback):
def __init__(self, patience=0):
super(CustomEarlyStopping, self).__init__()
self.patience = patience
self.best_weights = None
def on_train_begin(self, logs=None):
# The number of epoch it has waited when loss is no longer minimum.
self.wait = 0
# The epoch the training stops at.
self.stopped_epoch = 0
# Initialize the best as infinity.
self.best_v_loss = np.Inf
self.best_map10 = 0
def on_epoch_end(self, epoch, logs=None):
v_loss=logs.get('val_loss')
map10=logs.get('val_average_precision_at_k10')
# If BOTH the validation loss AND map10 does not improve for 'patience' epochs, stop training early.
if np.less(v_loss, self.best_v_loss) and np.greater(map10, self.best_map10):
self.best_v_loss = v_loss
self.best_map10 = map10
self.wait = 0
# Record the best weights if current results is better (less).
self.best_weights = self.model.get_weights()
else:
self.wait += 1
if self.wait >= self.patience:
self.stopped_epoch = epoch
self.model.stop_training = True
print("Restoring model weights from the end of the best epoch.")
self.model.set_weights(self.best_weights)
def on_train_end(self, logs=None):
if self.stopped_epoch > 0:
print("Epoch %05d: early stopping" % (self.stopped_epoch + 1))
Run Code Online (Sandbox Code Playgroud)
然后将其用作:
model.fit(
x_train,
y_train,
batch_size=64,
steps_per_epoch=5,
epochs=30,
verbose=0,
callbacks=[CustomEarlyStopping(patience=10)],
)
Run Code Online (Sandbox Code Playgroud)
归档时间: |
|
查看次数: |
1455 次 |
最近记录: |