从元数据文件导入张量流模型时配置input_map

Sev*_*ess 13 tensorflow

我已经训练了一个DCGAN模型,现在想把它加载到一个库中,通过图像空间优化可视化神经元激活的驱动程序.

以下代码有效,但在进行后续图像分析时,迫使我使用(1,宽度,高度,通道)图像,这是一种痛苦(图书馆对网络输入形状的假设).

# creating TensorFlow session and loading the model
graph = tf.Graph()
sess = tf.InteractiveSession(graph=graph)

new_saver = tf.train.import_meta_graph(model_fn)
new_saver.restore(sess, './')
Run Code Online (Sandbox Code Playgroud)

我想更改input_map,在阅读源代码后,我希望这段代码能够正常工作:

graph = tf.Graph()
sess = tf.InteractiveSession(graph=graph)

t_input = tf.placeholder(np.float32, name='images') # define the input tensor
t_preprocessed = tf.expand_dims(t_input, 0)

new_saver = tf.train.import_meta_graph(model_fn, input_map={'images': t_input})
new_saver.restore(sess, './')
Run Code Online (Sandbox Code Playgroud)

但是得到了一个错误:

ValueError:tf.import_graph_def()要求使用非空nameif input_map.

当堆栈到达tf.import_graph_def()name字段时设置为import_scope,所以我尝试了以下操作:

graph = tf.Graph()
sess = tf.InteractiveSession(graph=graph)

t_input = tf.placeholder(np.float32, name='images') # define the input tensor
t_preprocessed = tf.expand_dims(t_input, 0)

new_saver = tf.train.import_meta_graph(model_fn, input_map={'images': t_input}, import_scope='import')
new_saver.restore(sess, './')
Run Code Online (Sandbox Code Playgroud)

这让我了解了以下内容KeyError:

KeyError:"名称'gradients/discriminator/minibatch/map/while/TensorArrayWrite/TensorArrayWriteV3_grad/TensorArrayReadV3/RefEnter:0'指的是一个不存在的Tensor.操作',gradients/discriminator/minibatch/map/while/TensorArrayWrite/TensorArrayWriteV3_grad/TensorArrayReadV3/RefEnter',在图表中不存在."

如果我设置'import_scope',我会得到相同的错误,无论我是否设置'input_map'.

我不知道从哪里开始.

Ish*_*nal 5

在新版本的tensorflow> = 1.2.0中,以下步骤可以正常工作。

t_input = tf.placeholder(np.float32, shape=[None, width, height, channels], name='new_input') # define the input tensor

# here you need to give the name of the original model input placeholder name
# For example if the model has input as; input_original=  tf.placeholder(tf.float32, shape=(1, width, height, channels, name='original_placeholder_name'))
new_saver = tf.train.import_meta_graph(/path/to/checkpoint_file.meta, input_map={'original_placeholder_name:0':  t_input})
new_saver.restore(sess, '/path/to/checkpointfile')
Run Code Online (Sandbox Code Playgroud)