Ale*_*eee 7 restore tensorflow
我已经保存了一个模型,现在我试图在两个分支中恢复它,如下所示:
我写了这段代码,它引发了ValueError: The same saveable will be restored with two names
. 如何从同一个变量中恢复两个变量?
restore_variables = {}
for varr in tf.global_variables()
if varr.op.name in checkpoint_variables:
restore_variables[varr.op.name.split("_red")[0]] = varr
restore_variables[varr.op.name.split("_blue")[0]] = varr
init_saver = tf.train.Saver(restore_variables, max_to_keep=0)
Run Code Online (Sandbox Code Playgroud)
在 TF 1.15 上测试
基本上,错误是说它在restore_variables
字典中找到对同一变量的多个引用。修复方法很简单。tf.Variable(varr)
使用以下引用之一创建变量的副本。
我认为可以安全地假设您在这里不是在寻找对同一变量的多个引用,而是在寻找两个单独的变量。(我假设这是因为,如果您想多次使用同一个变量,则可以多次使用单个变量)。
with tf.Session() as sess:
saver.restore(sess, './vars/vars.ckpt-0')
restore_variables = {}
checkpoint_variables=['b']
for varr in tf.global_variables():
if varr.op.name in checkpoint_variables:
restore_variables[varr.op.name.split("_red")[0]] = varr
restore_variables[varr.op.name.split("_blue")[0]] = tf.Variable(varr)
print(restore_variables)
init_saver = tf.train.Saver(restore_variables, max_to_keep=0)
Run Code Online (Sandbox Code Playgroud)
您可以在下面找到使用玩具示例复制问题的完整代码。本质上,我们有两个变量a
,b
并且我们正在创建b_red
和b_blue
变量。
# Saving the variables
import tensorflow as tf
import numpy as np
a = tf.placeholder(shape=[None, 3], dtype=tf.float64)
w1 = tf.Variable(np.random.normal(size=[3,2]), name='a')
out = tf.matmul(a, w1)
w2 = tf.Variable(np.random.normal(size=[2,3]), name='b')
out = tf.matmul(out, w2)
saver = tf.train.Saver([w1, w2])
with tf.Session() as sess:
tf.global_variables_initializer().run()
saved_path = saver.save(sess, './vars/vars.ckpt', global_step=0)
Run Code Online (Sandbox Code Playgroud)
# Restoring the variables
with tf.Session() as sess:
saver.restore(sess, './vars/vars.ckpt-0')
restore_variables = {}
checkpoint_variables=['b']
for varr in tf.global_variables():
if varr.op.name in checkpoint_variables:
restore_variables[varr.op.name+"_red"] = varr
# Fixing the issue: Instead of varr, do tf.Variable(varr)
restore_variables[varr.op.name+"_blue"] = varr
print(restore_variables)
init_saver = tf.train.Saver(restore_variables, max_to_keep=0)
Run Code Online (Sandbox Code Playgroud)
归档时间: |
|
查看次数: |
314 次 |
最近记录: |