Tensorflow中的C ++等价于python:tf.Graph.get_tensor_by_name()?

boo*_*her 5 c++ tensorflow

Tensorflow中的python tf.Graph.get_tensor_by_name(name)相当于C ++?谢谢!

这是我尝试运行的代码,但是我得到了一个空的output

Status status = NewSession(SessionOptions(), &session); // create new session
ReadBinaryProto(tensorflow::Env::Default(), model, &graph_def); // read Graph
session->Create(graph_def); // add Graph to Tensorflow session 
std::vector<tensorflow::Tensor> output; // create Tensor to store output
std::vector<string> vNames; // vector of names for required graph nodes
vNames.push_back("some_name"); // I checked names and they are presented in loaded Graph

session->Run({}, vNames, {}, &output); // ??? As a result I have empty output
Run Code Online (Sandbox Code Playgroud)

mrr*_*rry 3

您的评论来看,听起来您正在使用 C++ tensorflow::SessionAPI,它将图形表示为GraphDef协议缓冲区。没有等同tf.Graph.get_tensor_by_name()

您无需将类型化tf.Tensor对象传递给,而是传递张量的名称,其形式为,其中与 中的值之一匹配,并且是与要获取的该节点的输出索引相对应的整数。Session::Run()string<NODE NAME>:<N><NODE NAME>NodeDef.nameGraphDef<N>

您问题中的代码看起来大致正确,但我建议两件事:

  1. session->Run()调用返回一个tensorflow::Status值。如果output在调用返回后为空,则几乎可以肯定该调用返回了错误状态以及解释问题的消息。

  2. "some_name"作为要获取的张量的名称传递,但它是节点的名称,而不是张量。此 API 可能要求您显式指定输出索引:尝试将其替换为"some_name:0".