Tensorflow:如何替换计算图中的节点?

mda*_*ust 35 python tensorflow

如果你有两个不相交的图,并想要链接它们,转过来:

x = tf.placeholder('float')
y = f(x)

y = tf.placeholder('float')
z = f(y)
Run Code Online (Sandbox Code Playgroud)

进入这个:

x = tf.placeholder('float')
y = f(x)
z = g(y)
Run Code Online (Sandbox Code Playgroud)

有没有办法做到这一点?在某些情况下,它似乎可以使构造更容易.

例如,如果您有一个将输入图像作为a的图形tf.placeholder,并且想要优化输入图像,那么深层梦想的样式是否有办法用占位符替换tf.variable节点?或者在构建图表之前你必须考虑到这一点吗?

mrr*_*rry 33

TL; DR:如果你可以将两个计算定义为Python函数,那么你应该这样做.如果不能,TensorFlow中有更多高级功能可以序列化和导入图形,这使您可以组合来自不同来源的图形.

在TensorFlow中执行此操作的一种方法是将不相交的计算构建为单独的tf.Graph对象,然后使用以下方法将它们转换为序列化协议缓冲区Graph.as_graph_def():

with tf.Graph().as_default() as g_1:
  input = tf.placeholder(tf.float32, name="input")
  y = f(input)
  # NOTE: using identity to get a known name for the output tensor.
  output = tf.identity(y, name="output")

gdef_1 = g_1.as_graph_def()

with tf.Graph().as_default() as g_2:  # NOTE: g_2 not g_1       
  input = tf.placeholder(tf.float32, name="input")
  z = g(input)
  output = tf.identity(y, name="output")

gdef_2 = g_2.as_graph_def()
Run Code Online (Sandbox Code Playgroud)

然后你可以编写gdef_1gdef_2进入第三个图表,使用tf.import_graph_def():

with tf.Graph().as_default() as g_combined:
  x = tf.placeholder(tf.float32, name="")

  # Import gdef_1, which performs f(x).
  # "input:0" and "output:0" are the names of tensors in gdef_1.
  y, = tf.import_graph_def(gdef_1, input_map={"input:0": x},
                           return_elements=["output:0"])

  # Import gdef_2, which performs g(y)
  z, = tf.import_graph_def(gdef_2, input_map={"input:0": y},
                           return_elements=["output:0"]
Run Code Online (Sandbox Code Playgroud)

  • 不幸的是,这是正确的.解决方法是对tf.get_default_graph()中的op执行`vars = op.outputs [0].如果op.type =="Variable"则执行get_operations()``然后将`var_list = vars`传递给`minimize()` . (5认同)

Min*_*ark 6

如果要合并训练后的模型(例如在新模型中重用预先训练的模型的一部分),可以使用Saver来保存第一个模型的检查点,然后将该模型(全部或部分)还原到另一个模型中。

例如,假设您要w在模型2中重用模型1的权重,并且还要x从占位符转换为变量:

with tf.Graph().as_default() as g1:
    x = tf.placeholder('float')
    w = tf.Variable(1., name="w")
    y = x * w
    saver = tf.train.Saver()

with tf.Session(graph=g1) as sess:
    w.initializer.run()
    # train...
    saver.save(sess, "my_model1.ckpt")

with tf.Graph().as_default() as g2:
    x = tf.Variable(2., name="v")
    w = tf.Variable(0., name="w")
    z = x + w
    restorer = tf.train.Saver([w]) # only restore w

with tf.Session(graph=g2) as sess:
    x.initializer.run()  # x now needs to be initialized
    restorer.restore(sess, "my_model1.ckpt") # restores w=1
    print(z.eval())  # prints 3.
Run Code Online (Sandbox Code Playgroud)


Jon*_*iev 5

事实证明,tf.train.import_meta_graph将所有其他参数传递给import_scoped_meta_graph具有该input_map参数的基础,并在对自身(内部)调用时加以利用import_graph_def

它没有记录,并且花了我很多时间才找到它,但是它有效!