mda*_*ust 35 python tensorflow
如果你有两个不相交的图,并想要链接它们,转过来:
x = tf.placeholder('float')
y = f(x)
y = tf.placeholder('float')
z = f(y)
Run Code Online (Sandbox Code Playgroud)
进入这个:
x = tf.placeholder('float')
y = f(x)
z = g(y)
Run Code Online (Sandbox Code Playgroud)
有没有办法做到这一点?在某些情况下,它似乎可以使构造更容易.
例如,如果您有一个将输入图像作为a的图形tf.placeholder
,并且想要优化输入图像,那么深层梦想的样式是否有办法用占位符替换tf.variable
节点?或者在构建图表之前你必须考虑到这一点吗?
mrr*_*rry 33
TL; DR:如果你可以将两个计算定义为Python函数,那么你应该这样做.如果不能,TensorFlow中有更多高级功能可以序列化和导入图形,这使您可以组合来自不同来源的图形.
在TensorFlow中执行此操作的一种方法是将不相交的计算构建为单独的tf.Graph
对象,然后使用以下方法将它们转换为序列化协议缓冲区Graph.as_graph_def()
:
with tf.Graph().as_default() as g_1:
input = tf.placeholder(tf.float32, name="input")
y = f(input)
# NOTE: using identity to get a known name for the output tensor.
output = tf.identity(y, name="output")
gdef_1 = g_1.as_graph_def()
with tf.Graph().as_default() as g_2: # NOTE: g_2 not g_1
input = tf.placeholder(tf.float32, name="input")
z = g(input)
output = tf.identity(y, name="output")
gdef_2 = g_2.as_graph_def()
Run Code Online (Sandbox Code Playgroud)
然后你可以编写gdef_1
并gdef_2
进入第三个图表,使用tf.import_graph_def()
:
with tf.Graph().as_default() as g_combined:
x = tf.placeholder(tf.float32, name="")
# Import gdef_1, which performs f(x).
# "input:0" and "output:0" are the names of tensors in gdef_1.
y, = tf.import_graph_def(gdef_1, input_map={"input:0": x},
return_elements=["output:0"])
# Import gdef_2, which performs g(y)
z, = tf.import_graph_def(gdef_2, input_map={"input:0": y},
return_elements=["output:0"]
Run Code Online (Sandbox Code Playgroud)
如果要合并训练后的模型(例如在新模型中重用预先训练的模型的一部分),可以使用Saver
来保存第一个模型的检查点,然后将该模型(全部或部分)还原到另一个模型中。
例如,假设您要w
在模型2中重用模型1的权重,并且还要x
从占位符转换为变量:
with tf.Graph().as_default() as g1:
x = tf.placeholder('float')
w = tf.Variable(1., name="w")
y = x * w
saver = tf.train.Saver()
with tf.Session(graph=g1) as sess:
w.initializer.run()
# train...
saver.save(sess, "my_model1.ckpt")
with tf.Graph().as_default() as g2:
x = tf.Variable(2., name="v")
w = tf.Variable(0., name="w")
z = x + w
restorer = tf.train.Saver([w]) # only restore w
with tf.Session(graph=g2) as sess:
x.initializer.run() # x now needs to be initialized
restorer.restore(sess, "my_model1.ckpt") # restores w=1
print(z.eval()) # prints 3.
Run Code Online (Sandbox Code Playgroud)
事实证明,tf.train.import_meta_graph
将所有其他参数传递给import_scoped_meta_graph
具有该input_map
参数的基础,并在对自身(内部)调用时加以利用import_graph_def
。
它没有记录,并且花了我很多时间才找到它,但是它有效!
归档时间: |
|
查看次数: |
15761 次 |
最近记录: |