如何让TensorFlow的'import_graph_def'返回Tensors

oro*_*ome 6 python restore machine-learning tensorflow

如果我尝试导入已保存的TensorFlow图形定义

import tensorflow as tf
from tensorflow.python.platform import gfile

with gfile.FastGFile(FLAGS.model_save_dir.format(log_id) + '/graph.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
x, y, y_ = tf.import_graph_def(graph_def, 
                               return_elements=['data/inputs',
                                                'output/network_activation',
                                                'data/correct_outputs'],
                               name='')
Run Code Online (Sandbox Code Playgroud)

返回的值不是Tensor预期的,而是其他的东西:相反,例如,获取xas

Tensor("data/inputs:0", shape=(?, 784), dtype=float32)
Run Code Online (Sandbox Code Playgroud)

我明白了

name: "data/inputs_1"
op: "Placeholder"
attr {
  key: "dtype"
  value {
    type: DT_FLOAT
  }
}
attr {
  key: "shape"
  value {
    shape {
    }
  }
}
Run Code Online (Sandbox Code Playgroud)

也就是说,而不是得到预期的张量x,x.op.这让我感到困惑,因为文档似乎说我应该得到一个Tensor(虽然有一堆那些让它很难理解).

如何tf.import_graph_def返回Tensor我可以使用的特定s(例如,在输入加载的模型或运行分析时)?

mrr*_*rry 6

名称'data/inputs','output/network_activation''data/correct_outputs'实际上是操作名称。要tf.import_graph_def()返回tf.Tensor对象,您应该将输出索引附加到操作名称,这通常':0'用于单输出操作:

x, y, y_ = tf.import_graph_def(graph_def, 
                               return_elements=['data/inputs:0',
                                                'output/network_activation:0',
                                                'data/correct_outputs:0'],
                               name='')
Run Code Online (Sandbox Code Playgroud)