同时运行多个预先训练的Tensorflow网络

den*_*nru 5 python tensorflow

我想做的是同时运行多个预先训练好的Tensorflow网.因为每个网络中的一些变量的名称可以是相同的,所以常见的解决方案是在创建网络时使用名称范围.但问题是我训练了这些模型并将训练过的变量保存在几个检查点文件中.在创建网络时使用名称范围后,我无法从检查点文件加载变量.

例如,我训练了一个AlexNet,我想比较两组变量,一组来自纪元10(保存在文件epoch_10.ckpt中),另一组来自纪元50(保存在文件epoch_50中). CKPT).因为这两者是完全相同的网,所以内部变量的名称是相同的.我可以使用创建两个网

with tf.name_scope("net1"):
    net1 = CreateAlexNet()
with tf.name_scope("net2"):
    net2 = CreateAlexNet()
Run Code Online (Sandbox Code Playgroud)

但是,我无法从.ckpt文件加载训练过的变量,因为当我训练这个网时,我没有使用名称范围.即使我在训练网络时可以将名称范围设置为"net1",这也可以防止我加载net2的变量.

我试过了:

with tf.name_scope("net1"):
    mySaver.restore(sess, 'epoch_10.ckpt')
with tf.name_scope("net2"):
    mySaver.restore(sess, 'epoch_50.ckpt')
Run Code Online (Sandbox Code Playgroud)

这不起作用.

解决这个问题的最佳方法是什么?

mrr*_*rry 13

最简单的解决方案是创建不同的会话,为每个模型使用单独的图形:

# Build a graph containing `net1`.
with tf.Graph().as_default() as net1_graph:
  net1 = CreateAlexNet()
  saver1 = tf.train.Saver(...)
sess1 = tf.Session(graph=net1_graph)
saver1.restore(sess1, 'epoch_10.ckpt')

# Build a separate graph containing `net2`.
with tf.Graph().as_default() as net2_graph:
  net2 = CreateAlexNet()
  saver2 = tf.train.Saver(...)
sess2 = tf.Session(graph=net1_graph)
saver2.restore(sess2, 'epoch_50.ckpt')
Run Code Online (Sandbox Code Playgroud)

如果由于某种原因这不起作用,并且您必须使用单个tf.Session(例如,因为您希望在另一个TensorFlow计算中组合来自两个网络的结果),最佳解决方案是:

  1. 正如您现在所做的那样,在名称范围内创建不同的网络
  2. tf.train.Saver为两个网络创建单独的实例,并使用另一个参数重新映射变量名称.

构建的储户,就可以通过一本字典作为var_list参数,映射变量的名字在检查点(即没有名称范围前缀)给tf.Variable你的每个模型创建的对象.

您可以以var_list编程方式构建,并且您应该能够执行以下操作:

with tf.name_scope("net1"):
  net1 = CreateAlexNet()
with tf.name_scope("net2"):
  net2 = CreateAlexNet()

# Strip off the "net1/" prefix to get the names of the variables in the checkpoint.
net1_varlist = {v.name.lstrip("net1/"): v
                for v in tf.get_collection(tf.GraphKeys.VARIABLES, scope="net1/")}
net1_saver = tf.train.Saver(var_list=net1_varlist)

# Strip off the "net2/" prefix to get the names of the variables in the checkpoint.
net2_varlist = {v.name.lstrip("net2/"): v
                for v in tf.get_collection(tf.GraphKeys.VARIABLES, scope="net2/")}
net2_saver = tf.train.Saver(var_list=net2_varlist)

# ...
net1_saver.restore(sess, "epoch_10.ckpt")
net2_saver.restore(sess, "epoch_50.ckpt")
Run Code Online (Sandbox Code Playgroud)