我正在尝试遵循在 tensorflow 中定义新运算符的官方指南。https://www.tensorflow.org/extend/adding_an_op
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
using namespace tensorflow;
REGISTER_OP("ZeroOut")
.Input("to_zero: int32")
.Output("zeroed: int32")
.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c){
c->set_output(0, c->input(0));
return Status::OK();
});
Run Code Online (Sandbox Code Playgroud)
但是我找不到这段代码的逐行解释,特别是我不明白 .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) 的作用是什么及其语法。另外我我对 InferenceContext 感到困惑,我猜这是一种连续传递任何数组元素的方法.. 我在任何地方都找不到明确的定义,也许我找错了地方,有人可以帮助我解释或参考?我想深入了解这段代码在幕后做了什么。