将Tensorflow Checkpoint文件更新为1.0

jbi*_*ird 2 python tensorflow

我有一个在Tensorflow r0.12中训练的模型,它使用创建了检查点文件SaverV2.我的模型是使用rnn_cellrnn_cell.GRUCell来自的RNN tensorflow.python.ops.自从改为1.0后,这个软件包已根据这个答案转移到core_rnn_cell_impltensorflow.contrib.rnn.python.ops

tf_update.py这里运行文件以将我的文件更新为新版本.但是,自更新以来,我的旧检查点文件不起作用.似乎新GRUCell实现所需的某些变量不存在或具有不同的名称.

示例错误(有132个这样的错误):

2017-02-22 11:36:08.037315: W tensorflow/core/framework/op_kernel.cc:993] Not found: Key NLC/Decoder/DecoderAttnCell/gru_cell/candidate/weights not found in checkpoint
2017-02-22 11:36:08.037382: W tensorflow/core/framework/op_kernel.cc:993] Not found: Key NLC/Decoder/DecoderAttnCell/gru_cell/candidate/weights/Adam not found in checkpoint
2017-02-22 11:36:08.037494: W tensorflow/core/framework/op_kernel.cc:993] Not found: Key NLC/Decoder/DecoderAttnCell/gru_cell/gates/biases/Adam not found in checkpoint
2017-02-22 11:36:08.037499: W tensorflow/core/framework/op_kernel.cc:993] Not found: Key NLC/Decoder/DecoderAttnCell/gru_cell/candidate/weights/Adam_1 not found in checkpoint
2017-02-22 11:36:08.037538: W tensorflow/core/framework/op_kernel.cc:993] Not found: Key NLC/Decoder/DecoderAttnCell/gru_cell/gates/weights not found in checkpoint
2017-02-22 11:36:08.037615: W tensorflow/core/framework/op_kernel.cc:993] Not found: Key NLC/Decoder/DecoderAttnCell/gru_cell/gates/biases not found in checkpoint
2017-02-22 11:36:08.037618: W tensorflow/core/framework/op_kernel.cc:993] Not found: Key NLC/Decoder/DecoderAttnCell/gru_cell/gates/biases/Adam_1 not found in checkpoint
2017-02-22 11:36:08.038098: W tensorflow/core/framework/op_kernel.cc:993] Not found: Key NLC/Decoder/DecoderAttnCell/gru_cell/gates/weights/Adam_1 not found in checkpoint
2017-02-22 11:36:08.038121: W tensorflow/core/framework/op_kernel.cc:993] Not found: Key NLC/Decoder/DecoderAttnCell/gru_cell/gates/weights/Adam not found in checkpoint
2017-02-22 11:36:08.038222: W tensorflow/core/framework/op_kernel.cc:993] Not found: Key NLC/Decoder/DecoderCell0/gru_cell/candidate/biases not found in checkpoint
2017-02-22 11:36:08.038229: W tensorflow/core/framework/op_kernel.cc:993] Not found: Key NLC/Decoder/DecoderCell0/gru_cell/candidate/weights not found in checkpoint
2017-02-22 11:36:08.038233: W tensorflow/core/framework/op_kernel.cc:993] Not found: Key NLC/Decoder/DecoderCell0/gru_cell/candidate/biases/Adam_1 not found in checkpoint
Run Code Online (Sandbox Code Playgroud)

保存/加载工作完美,直到更新.如何将旧的检查点文件更新到r1.0?

如果重要,我使用的是python2.7,当使用CUDA的仅CPU张量流或张量流时会发生同样的错误.

ase*_*lle 5

没有简单的方法可以做到这一点......一种方法是使用get_variable_to_shape_map()

  ckpt_reader = tf.train.NewCheckpointReader(filepath)
  ckpt_vars = ckpt_reader.get_variable_to_shape_map()
Run Code Online (Sandbox Code Playgroud)

这将为您提供已保存检查点中形状的变量名称列表.然后...创建一个从旧名称映射到新名称的字典即

old_to_new={}
old_to_new[old_name] = new_name
Run Code Online (Sandbox Code Playgroud)

然后即时启动救星并恢复那些变量

saver = tf.Saver(old_to_new)
saver.restore(filepath)
Run Code Online (Sandbox Code Playgroud)

祝你好运,希望这会有帮助.