是否可以用现有图形中的常量替换占位符?

Cro*_*rKZ 8 python tensorflow

我有一个受过训练的模型的冻结图,它有一个tf.placeholder我总是提供相同的值.

我想知道是否有可能替换它tf.constant.如果它是某种方式 - 任何例子将不胜感激!

编辑:以下是代码的外观,以帮助可视化问题

我正在使用预先训练的(由其他人)模型进行推理.模型在本地存储为具有.pb扩展名的冻结图形文件.

代码如下所示:

# load graph
graph = load_graph('frozen.pb')
session = tf.Session(graph=graph)

# Get input and output tensors
images_placeholder = graph.get_tensor_by_name("input:0")
output = graph.get_tensor_by_name("output:0")
phase_train_placeholder = graph.get_tensor_by_name("phase_train:0")

feed_dict = {images_placeholder: images, phase_train_placeholder: False}

result = session.run(output, feed_dict=feed_dict)
Run Code Online (Sandbox Code Playgroud)

问题是我总是phase_train_placeholder: False以我的目的为食,所以我想知道是否有可能消除占位符并用类似的东西替换它tf.constant(False, dtype=bool, shape=[])

Cro*_*rKZ 10

所以我没有设法找到任何正确的方法,但设法以一种黑客的方式,通过重建图形def并替换我需要替换的节点.灵感来自这个代码.

这是代码(超级hacky,使用风险自负):

INPUT_GRAPH_DEF_FILE = 'path/to/file'
OUTPUT_GRAPH_DEF_FILE = 'another/one'

# Get NodeDef of a constant tensor we want to put in place of 
# the placeholder. 
# (There is probably a better way to do this)
example_graph = tf.Graph()
with tf.Session(graph=example_graph):
    c = tf.constant(False, dtype=bool, shape=[], name='phase_train')
    for node in example_graph.as_graph_def().node:
        if node.name == 'phase_train':
            c_def = node

# load our graph
graph = load_graph(INPUT_GRAPH_DEF_FILE)
graph_def = graph.as_graph_def()

# Create new graph, and rebuild it from original one
# replacing phase train node def with constant
new_graph_def = graph_pb2.GraphDef()
for node in graph_def.node:
    if node.name == 'phase_train':
        new_graph_def.node.extend([c_def])
    else:
        new_graph_def.node.extend([copy.deepcopy(node)])

# save new graph
with tf.gfile.GFile(OUTPUT_GRAPH_DEF_FILE, "wb") as f:
    f.write(new_graph_def.SerializeToString())
Run Code Online (Sandbox Code Playgroud)


ban*_*men 5

我最近不得不重写上面的答案。

import tensorflow as tf
import sys
from tensorflow.core.framework import graph_pb2
import copy


INPUT_GRAPH_DEF_FILE = sys.argv[1]
OUTPUT_GRAPH_DEF_FILE = sys.argv[2]

# load our graph
def load_graph(filename):
    graph_def = tf.GraphDef()
    with tf.gfile.FastGFile(filename, 'rb') as f:
        graph_def.ParseFromString(f.read())
    return graph_def
graph_def = load_graph(INPUT_GRAPH_DEF_FILE)

target_node_name = sys.argv[3]
c = tf.constant(False, dtype=bool, shape=[], name=target_node_name)

# Create new graph, and rebuild it from original one
# replacing phase train node def with constant
new_graph_def = graph_pb2.GraphDef()
for node in graph_def.node:
    if node.name == target_node_name:
        new_graph_def.node.extend([c.op.node_def])
    else:
        new_graph_def.node.extend([copy.deepcopy(node)])

# save new graph
with tf.gfile.GFile(OUTPUT_GRAPH_DEF_FILE, "wb") as f:
    f.write(new_graph_def.SerializeToString())
Run Code Online (Sandbox Code Playgroud)

  • 当不需要`c`必须与原始节点具有完全相同的名称时,这很好用。如果需要,则应将“ new_graph_def”创建为非默认图形,并在创建“ c”时将“ with custom_graph.as_default()”使用。如果您不这样做,tensorflow将自动在原始节点的名称后附加一个整数。 (3认同)