Keras回调继续跳过保存检查点,声称缺少val_acc

xen*_*ity 6 python-3.x keras checkpointing

我将运行一些较大的模型,并尝试中间结果。

因此,我尝试在每个时期之后使用检查点来保存最佳模型。

这是我的代码:

model = Sequential()
model.add(LSTM(700, input_shape=(X_modified.shape[1], X_modified.shape[2]), return_sequences=True))
model.add(Dropout(0.2))
model.add(LSTM(700, return_sequences=True))
model.add(Dropout(0.2))
model.add(LSTM(700))
model.add(Dropout(0.2))
model.add(Dense(Y_modified.shape[1], activation='softmax'))

model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

# Save the checkpoint in the /output folder
filepath = "output/text-gen-best.hdf5"

# Keep only a single checkpoint, the best over test accuracy.
checkpoint = ModelCheckpoint(filepath,
                            monitor='val_acc',
                            verbose=1,
                            save_best_only=True,
                            mode='max')
model.fit(X_modified, Y_modified, epochs=100, batch_size=50, callbacks=[checkpoint])
Run Code Online (Sandbox Code Playgroud)

但是在第一个时期之后,我仍然收到警告:

/usr/local/lib/python3.6/site-packages/keras/callbacks.py:432: RuntimeWarning: Can save best model only with val_acc available, skipping.
  'skipping.' % (self.monitor), RuntimeWarning)
Run Code Online (Sandbox Code Playgroud)

要添加metrics=['accuracy']到模型中还存在其他SO问题(例如,在使用预训练的VGG16模型时无法节省重量)的解决方案,但此处仍然存在错误。

Sre*_* TP 12

您正在尝试使用以下代码检查模型

# Save the checkpoint in the /output folder
filepath = "output/text-gen-best.hdf5"

# Keep only a single checkpoint, the best over test accuracy.
checkpoint = ModelCheckpoint(filepath,
                            monitor='val_acc',
                            verbose=1,
                            save_best_only=True,
                            mode='max')
Run Code Online (Sandbox Code Playgroud)

ModelCheckpoint将考虑该参数monitor来决定是否保存模型。在您的代码中是val_acc。因此,如果增加,它将节省重量val_acc

现在在您适合的代码中,

model.fit(X_modified, Y_modified, epochs=100, batch_size=50, callbacks=[checkpoint])
Run Code Online (Sandbox Code Playgroud)

您尚未提供任何验证数据。ModelCheckpoint无法保存权重,因为它没有monitor要检查的参数。

为了根据val_acc您进行检查,您必须提供一些验证数据。

model.fit(X_modified, Y_modified, validation_data=(X_valid, y_valid), epochs=100, batch_size=50, callbacks=[checkpoint])
Run Code Online (Sandbox Code Playgroud)

如果您不想出于任何原因使用验证数据并实施检查点,则必须ModelCheckpoint根据accloss类似方法更改

# Save the checkpoint in the /output folder
filepath = "output/text-gen-best.hdf5"

# Keep only a single checkpoint, the best over test accuracy.
checkpoint = ModelCheckpoint(filepath,
                            monitor='acc',
                            verbose=1,
                            save_best_only=True,
                            mode='max')
Run Code Online (Sandbox Code Playgroud)

请记住,你必须改变modemin,如果你要monitorloss