tf 如何从同一个变量中恢复两个变量

Ale*_*eee 7 restore tensorflow

我已经保存了一个模型,现在我试图在两个分支中恢复它,如下所示:

在此处输入图片说明

我写了这段代码,它引发了ValueError: The same saveable will be restored with two names. 如何从同一个变量中恢复两个变量?

restore_variables = {}
for varr in tf.global_variables()
    if varr.op.name in checkpoint_variables:
        restore_variables[varr.op.name.split("_red")[0]] = varr           
        restore_variables[varr.op.name.split("_blue")[0]] = varr
init_saver = tf.train.Saver(restore_variables, max_to_keep=0)
Run Code Online (Sandbox Code Playgroud)

thu*_*v89 1

在 TF 1.15 上测试

基本上,错误是说它在restore_variables字典中找到对同一变量的多个引用。修复方法很简单。tf.Variable(varr)使用以下引用之一创建变量的副本。

我认为可以安全地假设您在这里不是在寻找对同一变量的多个引用,而是在寻找两个单独的变量。(我假设这是因为,如果您想多次使用同一个变量,则可以多次使用单个变量)。

with tf.Session() as sess:
    saver.restore(sess, './vars/vars.ckpt-0')
    restore_variables = {}
    checkpoint_variables=['b']
    for varr in tf.global_variables():
        if varr.op.name in checkpoint_variables:
            restore_variables[varr.op.name.split("_red")[0]] = varr           
            restore_variables[varr.op.name.split("_blue")[0]] = tf.Variable(varr)
    print(restore_variables)
    init_saver = tf.train.Saver(restore_variables, max_to_keep=0)
Run Code Online (Sandbox Code Playgroud)

您可以在下面找到使用玩具示例复制问题的完整代码。本质上,我们有两个变量ab并且我们正在创建b_redb_blue变量。

# Saving the variables

import tensorflow as tf
import numpy as np
a = tf.placeholder(shape=[None, 3], dtype=tf.float64)
w1 = tf.Variable(np.random.normal(size=[3,2]), name='a')
out = tf.matmul(a, w1)
w2 = tf.Variable(np.random.normal(size=[2,3]), name='b')
out = tf.matmul(out, w2)

saver = tf.train.Saver([w1, w2])

with tf.Session() as sess:
  tf.global_variables_initializer().run()
  saved_path = saver.save(sess, './vars/vars.ckpt', global_step=0)
Run Code Online (Sandbox Code Playgroud)
# Restoring the variables

with tf.Session() as sess:
    saver.restore(sess, './vars/vars.ckpt-0')
    restore_variables = {}
    checkpoint_variables=['b']
    for varr in tf.global_variables():
        if varr.op.name in checkpoint_variables:
            restore_variables[varr.op.name+"_red"] = varr  
            # Fixing the issue: Instead of varr, do tf.Variable(varr)
            restore_variables[varr.op.name+"_blue"] = varr
    print(restore_variables)
    init_saver = tf.train.Saver(restore_variables, max_to_keep=0)
Run Code Online (Sandbox Code Playgroud)