use*_*700 3 python-2.7 tensorflow
我正在尝试从检查点文件中恢复一些变量,如果相同的变量名在当前模型中.
我发现Tensorfow Github有一些方法
所以我想做的是检查检查点文件中的变量名称,has_tensor("variable.name")如下所示,
...
reader = tf.train.NewCheckpointReader(ckpt_path)
for v in tf.trainable_variables():
print v.name
if reader.has_tensor(v.name):
print 'has tensor'
...
Run Code Online (Sandbox Code Playgroud)
但我发现v.name返回变量name和colon+number.例如,我有变量名W_o,b_o然后v.name返回W_o:0, b_o:0.
但是reader.has_tensor()要求name没有colon和number作为W_o, b_o.
我的问题是:如何去除colon并number在变量名,以读取变量结束了吗?
有没有更好的方法来恢复这些变量?
您可以使用string.split()来获取张量名称:
...
reader = tf.train.NewCheckpointReader(ckpt_path)
for v in tf.trainable_variables():
tensor_name = v.name.split(':')[0]
print tensor_name
if reader.has_tensor(tensor_name):
print 'has tensor'
...
Run Code Online (Sandbox Code Playgroud)
接下来,让我用一个例子来说明如何从.cpkt文件中恢复每个可能的变量.首先,让我们保存v2并v3进入tmp.ckpt:
import tensorflow as tf
v1 = tf.Variable(tf.ones([1]), name='v1')
v2 = tf.Variable(2 * tf.ones([1]), name='v2')
v3 = tf.Variable(3 * tf.ones([1]), name='v3')
saver = tf.train.Saver({'v2': v2, 'v3': v3})
with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
saver.save(sess, 'tmp.ckpt')
Run Code Online (Sandbox Code Playgroud)
这就是我将如何恢复出现的每个变量(属于一个新图形) tmp.ckpt:
with tf.Graph().as_default():
assert len(tf.trainable_variables()) == 0
v1 = tf.Variable(tf.zeros([1]), name='v1')
v2 = tf.Variable(tf.zeros([1]), name='v2')
reader = tf.train.NewCheckpointReader('tmp.ckpt')
restore_dict = dict()
for v in tf.trainable_variables():
tensor_name = v.name.split(':')[0]
if reader.has_tensor(tensor_name):
print('has tensor ', tensor_name)
restore_dict[tensor_name] = v
saver = tf.train.Saver(restore_dict)
with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
saver.restore(sess, 'tmp.ckpt')
print(sess.run([v1, v2])) # prints [array([ 0.], dtype=float32), array([ 2.], dtype=float32)]
Run Code Online (Sandbox Code Playgroud)
此外,您可能希望确保形状和dtypes匹配.