TF保存/恢复图在tf.GraphDef.ParseFromString()失败

rgr*_*rgr 6 tensorflow

基于这个转换训练的张量流模型到protobuf我试图保存/恢复TF图没有成功.

这是救星:

with tf.Graph().as_default():
    variable_node = tf.Variable(1.0, name="variable_node")
    output_node = tf.mul(variable_node, 2.0, name="output_node")
    sess = tf.Session()
    init = tf.initialize_all_variables()
    sess.run(init)
    output = sess.run(output_node)
    tf.train.write_graph(sess.graph.as_graph_def(), summ_dir, 'model_00_g.pbtxt', as_text=True)
    #self.assertNear(2.0, output, 0.00001)
    saver = tf.train.Saver()
    saver.save(sess, saver_path)
Run Code Online (Sandbox Code Playgroud)

它产生model_00_g.pbtxt了文本图形描述.几乎从freeze_graph_test.py复制粘贴.

这是读者:

with tf.Session() as sess:

    with tf.Graph().as_default():
        graph_def = tf.GraphDef()
        graph_path = '/mnt/code/test_00/log/2016-02-11.22-37-46/model_00_g.pbtxt'
        with open(graph_path, "rb") as f:
            proto_b = f.read()
            #print proto_b   # -> I can see it
            graph_def.ParseFromString(proto_b) # no luck..
            _ = tf.import_graph_def(graph_def, name="")

    print sess.graph_def
Run Code Online (Sandbox Code Playgroud)

哪个失败graph_def.ParseFromString()DecodeError: Tag had invalid wire type.

我在码头集装箱上b.gcr.io/tensorflow/tensorflow:latest-devel,以防它有任何区别.

mrr*_*rry 16

GraphDef.ParseFromString()方法(以及通常,ParseFromString()任何Python protobuf包装器上的方法)需要二进制协议缓冲区格式的字符串.如果传递as_text=Falsetf.train.write_graph(),则文件将采用适当的格式.

否则,您可以执行以下操作来阅读基于文本的格式:

from google.protobuf import text_format
# ...
graph_def = tf.GraphDef()
text_format.Merge(proto_b, graph_def) 
Run Code Online (Sandbox Code Playgroud)