在包含“while_loop”的提取的 Tensorflow 子图中计算梯度

Tri*_*een 5 python deep-learning tensorflow

在一些深学习工作流,它是有用的训练模式,提取出来使用其图形的tf.graph_util.convert_variables_to_constantstf.graph_util.extract_sub_graph所以训练相关的张量被排除在外,然后连接所提取的子图,以其它模型(一个或多个)通过tf.import_graph_def。通过这种方式,经过训练的模型可以作为更大设置中的构建块。

通常,我们希望通过新的复合模型进行反向传播,以便对其进行微调、优化输入等。

但是,似乎无法通过while_loop导入图中的tensorflow 操作定义梯度,因为它依赖于“外部上下文”,即添加到元图集合中的对象(参见TF 问题 #7404)。稍微修改这个 Github 问题中的例子,这是我想要做的一个例子:

import tensorflow as tf
g1=tf.Graph()
sess1=tf.Session(graph=g1)
with g1.as_default():
    with sess1.as_default():
        i=tf.constant(0, name="input")
        out=tf.while_loop(lambda i: tf.less(i,5), lambda i: [tf.add(i,1)], [i], name="output")
        loss=tf.square(out,name='loss')
        graph_def = tf.graph_util.convert_variables_to_constants(sess1,g1.as_graph_def(),['output/Exit'])

g2 = tf.Graph()
with g2.as_default():
    tf.import_graph_def(graph_def,name='')
    i_imported = g2.get_tensor_by_name("input:0")
    out_imported = g2.get_tensor_by_name("output/Exit:0")
    tf.gradients(out_imported, i_imported)
Run Code Online (Sandbox Code Playgroud)

最后一行引发AttributeError: 'NoneType' object has no attribute 'outer_context'错误。

Tensorflow的解决这个问题是使用tf.train.export_meta_graphtf.train.import_meta_graph这样的外部背景下被复制,但这份整个图形,没有编辑工作。在这种最小情况下,不会删除“损失”张量。

我尝试将缺少的上下文复制到新图表中:

g2.add_to_collection('while_context',g1.get_collection('while_context'))
Run Code Online (Sandbox Code Playgroud)

但这并不能解决问题。

有没有办法克服这个限制,或者它是一个无法修复的 Tensorflow 设计缺陷?