这个函数如何:input.nn.MSECriterion_updateOutput(self,input,target)工作(在Lua/Torch中)?

lar*_*ars 1 lua function lua-table torch

我有这个功能:

    function MSECriterion:updateOutput(input, target)
        return input.nn.MSECriterion_updateOutput(self, input, target)
    end
Run Code Online (Sandbox Code Playgroud)

现在,

   input.nn.MSECriterion_updateOutput(self, input, target)
Run Code Online (Sandbox Code Playgroud)

返回一个数字.我不知道它是怎么做到的.我已经在调试器中一步步走了,似乎这只是计算一个没有中间步骤的数字.

 input is a Tensor of size 1 (say, -.234). And the 

 nn.MSECriterion_updateOutput(self, input, target) looks like it is just the function MSECriterion:updateOutput(input, target).
Run Code Online (Sandbox Code Playgroud)

我对如何计算数字感到困惑.

我很困惑为什么甚至允许这样做.参数输入是一个张量,它甚至没有任何名为nn.MSE input.nn.MSECriterion_updateOutput的方法.

del*_*eil 5

当您执行require "nn"此负载时init.lua,它依次执行require('libnn').这是torch/nn的C扩展.

如果你看一下init.c,你可以找到luaopen_libnn:这是当调用初始化函数libnn.sorequire-ed.

该函数负责初始化torch/nn的所有部分,包括MSECriterionvia nn_FloatMSECriterion_init(L)和的本机部分nn_DoubleMSECriterion_init(L).

如果你看一下generic/MSECriterion.c你可以找到泛型(即扩展为float和的宏double)初始化函数:

static void nn_(MSECriterion_init)(lua_State *L)
{
  luaT_pushmetatable(L, torch_Tensor);
  luaT_registeratname(L, nn_(MSECriterion__), "nn");
  lua_pop(L,1);
}
Run Code Online (Sandbox Code Playgroud)

这个初始化函数修改的任何元表torch.FloatTensor,并torch.DoubleTensor使其充满了下一堆的功能nn键(见Torch7 Lua的C API有详细介绍).这些功能在以前定义:

static const struct luaL_Reg nn_(MSECriterion__) [] = {
  {"MSECriterion_updateOutput", nn_(MSECriterion_updateOutput)},
  {"MSECriterion_updateGradInput", nn_(MSECriterion_updateGradInput)},
  {NULL, NULL}
};
Run Code Online (Sandbox Code Playgroud)

换句话说,由于其metatable,任何张量都附加了这些功能:

luajit -lnn
> print(torch.Tensor().nn.MSECriterion_updateOutput)
function: 0x40921df8
> print(torch.Tensor().nn.MSECriterion_updateGradInput)
function: 0x40921e20
Run Code Online (Sandbox Code Playgroud)

注意:对于具有C本机实现对应的所有torch/nn模块,此机制是相同的.

因此,您可以在generic/MSECriterion.c上看到input.nn.MSECriterion_updateOutput(self, input, target)调用效果.static int nn_(MSECriterion_updateOutput)(lua_State *L)

此函数计算输入张量之间的均方误差.