Che*_* Wu 7 warnings lstm tensorflow
我正在使用 tensorflow2.4,并且是 tensorflow 的新手
这是代码
model = Sequential()
model.add(LSTM(32, input_shape=(X_train.shape[1:])))
model.add(Dropout(0.2))
model.add(Dense(1, activation='linear'))
model.compile(optimizer='rmsprop', loss='mean_absolute_error', metrics='mae')
model.summary()
save_weights_at = 'basic_lstm_model'
save_best = ModelCheckpoint(save_weights_at, monitor='val_loss', verbose=0,
save_best_only=True, save_weights_only=False, mode='min',
period=1)
history = model.fit(x=X_train, y=y_train, batch_size=16, epochs=20,
verbose=1, callbacks=[save_best], validation_data=(X_val, y_val),
shuffle=True)
Run Code Online (Sandbox Code Playgroud)
你知道我为什么会收到这个警告吗?
Ach*_*age 20
我认为这个警告可以被安全地忽略,因为即使在tensorflow给出的教程中你也可以找到相同的警告。在保存自定义模型(例如图神经网络)时,我经常看到此警告。只要您不想访问那些不可调用的函数,您就应该可以继续。
但是,如果您对这一大段文本感到恼火,可以通过在代码顶部添加以下内容来抑制此警告。
import absl.logging
absl.logging.set_verbosity(absl.logging.ERROR)
Run Code Online (Sandbox Code Playgroud)
小智 11
此警告告诉您的是,您正在模型架构中使用自定义层和损失。
回调ModelCheckpoint会在每个时期后保存您的模型,并实现较低的验证损失。模型可以保存为 HDF5 格式或 SavedModel 格式(默认,特定于 TensorFlow 和 Keras)。您在此处使用 SavedModel 格式,因为您没有明确指定.h5扩展名。
每次新纪元达到较低的验证损失时,您的模型都会自动保存,但不会跟踪您的自定义对象(层和损失)。顺便说一句,这就是为什么仅在几个训练周期后才提示警告的原因。
如果没有跟踪的自定义对象,您将无法使用 成功重新加载模型keras.models.load_model()。
如果您不打算将来重新加载最佳模型,则可以安全地忽略此警告。无论如何,训练后您仍然可以在当前本地环境中使用您的最佳模型。
以 H5 格式保存模型似乎对我有用。
model.save(filepath, save_format="h5")
Run Code Online (Sandbox Code Playgroud)
以下是如何将 H5 与模型检查点结合使用(我还没有对此进行广泛测试,买者自负!)
from tensorflow.keras.callbacks import ModelCheckpoint
class ModelCheckpointH5(ModelCheckpoint):
# There is a bug saving models in TF 2.4
# https://github.com/tensorflow/tensorflow/issues/47479
# This forces the h5 format for saving
def __init__(self,
filepath,
monitor='val_loss',
verbose=0,
save_best_only=False,
save_weights_only=False,
mode='auto',
save_freq='epoch',
options=None,
**kwargs):
super(ModelCheckpointH5, self).__init__(filepath,
monitor='val_loss',
verbose=0,
save_best_only=False,
save_weights_only=False,
mode='auto',
save_freq='epoch',
options=None,
**kwargs)
def _save_model(self, epoch, logs):
from tensorflow.python.keras.utils import tf_utils
logs = logs or {}
if isinstance(self.save_freq,
int) or self.epochs_since_last_save >= self.period:
# Block only when saving interval is reached.
logs = tf_utils.to_numpy_or_python_type(logs)
self.epochs_since_last_save = 0
filepath = self._get_file_path(epoch, logs)
try:
if self.save_best_only:
current = logs.get(self.monitor)
if current is None:
logging.warning('Can save best model only with %s available, '
'skipping.', self.monitor)
else:
if self.monitor_op(current, self.best):
if self.verbose > 0:
print('\nEpoch %05d: %s improved from %0.5f to %0.5f,'
' saving model to %s' % (epoch + 1, self.monitor,
self.best, current, filepath))
self.best = current
if self.save_weights_only:
self.model.save_weights(
filepath, overwrite=True, options=self._options)
else:
self.model.save(filepath, overwrite=True, options=self._options,save_format="h5") # NK edited here
else:
if self.verbose > 0:
print('\nEpoch %05d: %s did not improve from %0.5f' %
(epoch + 1, self.monitor, self.best))
else:
if self.verbose > 0:
print('\nEpoch %05d: saving model to %s' % (epoch + 1, filepath))
if self.save_weights_only:
self.model.save_weights(
filepath, overwrite=True, options=self._options)
else:
self.model.save(filepath, overwrite=True, options=self._options,save_format="h5") # NK edited here
self._maybe_remove_file()
except IOError as e:
# `e.errno` appears to be `None` so checking the content of `e.args[0]`.
if 'is a directory' in six.ensure_str(e.args[0]).lower():
raise IOError('Please specify a non-directory filepath for '
'ModelCheckpoint. Filepath used is an existing '
'directory: {}'.format(filepath))
# Re-throw the error for any other causes.
raise
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
4214 次 |
| 最近记录: |