理解 C++ 中新 Tensorflow 运算符的定义

ric*_*cvo 5 operators tensorflow

我正在尝试遵循在 tensorflow 中定义新运算符的官方指南。https://www.tensorflow.org/extend/adding_an_op

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"

using namespace tensorflow;

REGISTER_OP("ZeroOut")
    .Input("to_zero: int32")
    .Output("zeroed: int32")
    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c){
      c->set_output(0, c->input(0));
      return Status::OK();
    });
Run Code Online (Sandbox Code Playgroud)

但是我找不到这段代码的逐行解释,特别是我不明白 .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) 的作用是什么及其语法。另外我我对 InferenceContext 感到困惑,我猜这是一种连续传递任何数组元素的方法.. 我在任何地方都找不到明确的定义,也许我找错了地方,有人可以帮助我解释或参考?我想深入了解这段代码在幕后做了什么。

Pet*_*den 2

您注意到这里关于形状推断函数的部分了吗? https://www.tensorflow.org/extend/adding_an_op#shape_functions_in_c

其中对 ShapeInferenceContext 类以及编写自己的函数的机制进行了大量讨论。如果这不包括您感兴趣的内容,您能否提供更多详细信息?