我得到了一个冻结(.pb)图形,它具有输入的动态形状(例如“无、无、无、3”或“?x?x?x3”)。我想将这些设置为静态形状(例如“1, 320, 320, 3”),但是我不确定如何将形状更改为输入占位符,以便将更改应用于随后的所有图层。在这种特殊情况下,我没有可用的代码或 ckpt 文件,因此必须在冻结 ( .pb) 图形上进行这项工作。
我已经尝试过什么?
我创建了一个简单的示例代码来制作一个简单的图形并将其保存为可用于测试不同方法的冻结图形。该图的创建方式如下:
import tensorflow as tf
def simple_cnn_graph():
graph = tf.Graph()
with graph.as_default():
input_layer = tf.placeholder(shape=[None, None, None, 3], dtype=tf.float32)
conv1 = tf.layers.conv2d(
inputs=input_layer,
filters=32,
kernel_size=[5, 5],
padding="same",
activation=tf.nn.relu)
pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2, name='pool1')
conv2 = tf.layers.conv2d(
inputs=pool1,
filters=16,
kernel_size=[5, 5],
padding="same",
activation=tf.nn.relu)
pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2, name='pool2')
return graph, pool2
if __name__=='__main__':
graph, output = simple_cnn_graph()
with tf.Session(graph=graph) as sess:
sess.run(tf.global_variables_initializer())
graph_def = tf.graph_util.convert_variables_to_constants(sess, \
tf.get_default_graph().as_graph_def(), [output.name.split(':')[0]])
frozen_file='./frozen.pb'
with open(frozen_file, 'wb') as f:
f.write(graph_def.SerializeToString())
print([n.name for n in graph.as_graph_def().node])
Run Code Online (Sandbox Code Playgroud)
我尝试了两种方法:
1)我已经尝试在以下位置使用 transform_graph 工具:https : //github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/graph_transforms/README.md 使用 strip_unused_nodes,但是它不起作用,因为我没有任何转换为占位符的张量。
2)我在链接中的评论之后取得了一些成功:https :
//github.com/tensorflow/tensorflow/issues/5680#issuecomment-405128390
在那里我能够使用tf.import_graph_def
'sinput_map
来映射一个新的占位符,但是我我正在寻找一个更简单和可推广的解决方案,将来可以应用于任何此类冻结的网络图(例如类似于 transform_graph 的东西)。下面是我使用该tf.import_graph_def
方法的代码
import tensorflow as tf
def load_frozen_graph(frozen_file='frozen.pb'):
graph = tf.Graph()
with graph.as_default():
od_graph_def = tf.GraphDef()
with tf.gfile.GFile(frozen_file, 'rb') as fid:
serialized_graph = fid.read()
od_graph_def.ParseFromString(serialized_graph)
tf.import_graph_def(od_graph_def, name='')
return graph
graph = load_frozen_graph('./frozen.pb')
print('Tensor shapes before import map')
input_tensor = graph.get_tensor_by_name('Placeholder:0')
print(input_tensor)
output_tensor = graph.get_tensor_by_name('pool2/MaxPool:0')
print(output_tensor)
new_graph = tf.Graph()
with new_graph.as_default():
new_input = tf.placeholder(dtype=tf.float32, shape=[1, 320, 320, 3], name='Placeholder')
tf.import_graph_def(graph.as_graph_def(), name='', input_map={'Placeholder': new_input})
print('Tensor shapes after import map')
input_tensor = new_graph.get_tensor_by_name('Placeholder:0')
print(input_tensor)
output_tensor = new_graph.get_tensor_by_name('pool2/MaxPool:0')
print(output_tensor)
Run Code Online (Sandbox Code Playgroud)
打印输出是:
Tensor shapes before import map
Tensor("Placeholder:0", shape=(?, ?, ?, 3), dtype=float32)
Tensor("pool2/MaxPool:0", shape=(?, ?, ?, 16), dtype=float32)
Tensor shapes after import map
Tensor("Placeholder:0", shape=(1, 320, 320, 3), dtype=float32)
Tensor("pool2/MaxPool:0", shape=(1, 80, 80, 16), dtype=float32)
Run Code Online (Sandbox Code Playgroud)
如果我在上面的代码/帖子中犯了任何错误或理解了关于 tf 形状的任何错误,如果有人能指出我正确的方向或纠正我,我将非常感激。
归档时间: |
|
查看次数: |
1512 次 |
最近记录: |