PyTorch模块如何做后撑

NoS*_*ult 5 python metaprogramming python-3.x pytorch

在遵循有关扩展PyTorch的说明(添加模块)的同时,我注意到在扩展Module时,我们实际上不必实现向后功能。我们唯一需要做的就是将Function实例应用于前向函数,而PyTorch可以在执行向后道具时自动在Function实例中向后调用一个。这对我来说似乎很神奇,因为我们甚至没有注册我们使用的Function实例。我查看了源代码,但未找到任何相关内容。任何人都可以请我指出所有实际发生的地方吗?

Mac*_*ero 6

不必实施backward()是PyTorch或任何其他DL框架如此有价值的原因。实际上,backward()仅应在需要弄乱网络梯度的非常特殊的情况下(或当您创建无法使用PyTorch的内置函数表示的自定义函数时)执行实现。

PyTorch使用计算图计算后向渐变,该计算图可跟踪您在前向通过期间执行了哪些操作。对Variable隐式执行的任何操作都将在此处注册。然后,需要从调用该变量的位置向后遍历该图,然后应用导数链规则来计算梯度。

PyTorch的“ 关于”页面可以很好地显示图形以及图形的一般工作方式。如果您需要更多详细信息,我还建议您在Google上查找计算图和自动分级机制。

编辑:发生所有这些情况的源代码将在PyTorch的代码库的C部分中,在其中实现了实际图形。经过一番挖掘,我发现了这一点

/// Evaluates the function on the given inputs and returns the result of the
/// function call.
variable_list operator()(const variable_list& inputs) {
    profiler::RecordFunction rec(this);
    if (jit::tracer::isTracingVar(inputs)) {
        return traced_apply(inputs);
    }
    return apply(inputs);
}
Run Code Online (Sandbox Code Playgroud)

因此,在每个函数中,PyTorch首先检查其输入是否需要跟踪,并执行在此实现的trace_apply()。您可以看到正在创建的节点并将其附加到图中:

// Insert a CppOp in the trace.
auto& graph = state->graph;
std::vector<VariableFlags> var_flags;
for(auto & input: inputs) {
    var_flags.push_back(VariableFlags::of(input));
}
auto* this_node = graph->createCppOp(get_shared_ptr(), std::move(var_flags));
// ...
for (auto& input: inputs) {
    this_node->addInput(tracer::getValueTrace(state, input));
}
graph->appendNode(this_node);
Run Code Online (Sandbox Code Playgroud)

我最好的猜测是,每个Function对象在执行时都会注册其自身及其输入(如果需要)。每个非功能性调用(例如variable.dot())都仅顺应于相应的函数,因此这仍然适用。

注意:我不参与PyTorch的开发,也不是其架构方面的专家。任何更正或增加将受到欢迎。