我通过以下方式在 Python 中的 tf 2.2.0 中保存了 keras 模型:
model.save('model', save_format='tf')
Run Code Online (Sandbox Code Playgroud)
它在“model”目录中给了我一个saved_model.pb。我想通过 c_api 进行推理,并且使用以下函数的代码: TF_LoadSessionFromSavedModel 工作正常。
int main() {
TF_Graph *Graph = TF_NewGraph();
TF_Status *Status = TF_NewStatus();
TF_SessionOptions *SessionOpts = TF_NewSessionOptions();
TF_Buffer *RunOpts = NULL;
const char *saved_model_dir = "model/";
const char *tags = "serve";
int ntags = 1;
TF_Session *Session = TF_LoadSessionFromSavedModel(SessionOpts, RunOpts, saved_model_dir, &tags, ntags, Graph, NULL, Status);
if (TF_GetCode(Status) == TF_OK)
{
printf("TF_LoadSessionFromSavedModel OK\n");
}
else
{
printf("%s", TF_Message(Status));
}
return 0;
}
Run Code Online (Sandbox Code Playgroud)
但是,如果我想通过 TF_GraphImportGraphDef 直接使用“model”目录中的 saving_model.pb,则会出现“Invalid GraphDef”错误: …