我一直在训练TensorFlow模型大约一周,偶尔会进行微调.
今天当我试图微调模型时,我得到了错误:
tensorflow.python.framework.errors_impl.NotFoundError: Key conv_classifier/loss/total_loss/avg not found in checkpoint
[[Node: save/RestoreV2_37 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/cpu:0"](_arg_save/Const_0_0, save/RestoreV2_37/tensor_names, save/RestoreV2_37/shape_and_slices)]]
Run Code Online (Sandbox Code Playgroud)
使用inspect_checkpoint.py我看到检查点文件现在有两个空图层:
...
conv_decode4/ort_weights/Momentum (DT_FLOAT) [7,7,64,64]
loss/cross_entropy/avg (DT_FLOAT) []
loss/total_loss/avg (DT_FLOAT) []
up1/up_filter (DT_FLOAT) [2,2,64,64]
...
Run Code Online (Sandbox Code Playgroud)
我该如何解决这个问题?
解:
为了清楚起见,以下mrry的建议编辑:
code_to_checkpoint_variable_map = {var.op.name: var for var in tf.global_variables()}
for code_variable_name, checkpoint_variable_name in {
"inference/conv_classifier/weight_loss/avg" : "loss/weight_loss/avg",
"inference/conv_classifier/loss/total_loss/avg" : "loss/total_loss/avg",
"inference/conv_classifier/loss/cross_entropy/avg": "loss/cross_entropy/avg",
}.items():
code_to_checkpoint_variable_map[checkpoint_variable_name] = code_to_checkpoint_variable_map[code_variable_name]
del code_to_checkpoint_variable_map[code_variable_name]
saver = tf.train.Saver(code_to_checkpoint_variable_map)
saver.restore(sess, tf.train.latest_checkpoint('./logs'))
Run Code Online (Sandbox Code Playgroud)
幸运的是,看起来您的检查点似乎没有损坏,但是程序中的某些变量已被重命名。我假设"loss/total_loss/avg"应将名为checkpoint的检查点值还原为名为的变量"conv_classifier/loss/total_loss/avg"。您可以通过var_list在创建时传递自定义来解决此问题tf.train.Saver。
name_to_var_map = {var.op.name: var for var in tf.global_variables()}
name_to_var_map["loss/total_loss/avg"] = name_to_var_map[
"conv_classifier/loss/total_loss/avg"]
del name_to_var_map["conv_classifier/loss/total_loss/avg"]
# Depending on how the names have changed, you may also need to do:
# name_to_var_map["loss/cross_entropy/avg"] = name_to_var_map[
# "conv_classifier/loss/cross_entropy/avg"]
# del name_to_var_map["conv_classifier/loss/cross_entropy/avg"]
saver = tf.train.Saver(name_to_var_map)
Run Code Online (Sandbox Code Playgroud)
然后,您可以saver.restore()用来还原模型。或者,您可以使用这种方法来还原模型,并使用默认构造的模型tf.train.Saver将其保存为规范格式。
| 归档时间: |
|
| 查看次数: |
5009 次 |
| 最近记录: |