pytorch 中的 apply(fn) 函数如何与没有 return 语句作为参数的函数一起工作?

JaM*_*oin 3 python-3.x pytorch

我对以下代码片段有一些疑问:

>>> def init_weights(m):
        print(m)
        if type(m) == nn.Linear:
            m.weight.data.fill_(1.0)
            print(m.weight)

>>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
>>> net.apply(init_weights)
Run Code Online (Sandbox Code Playgroud)

apply() 是 pytorch.nn 包的一部分。您可以在此包的文档中找到代码。最后的问题: 1. 为什么这个代码示例可以工作,尽管在将它提供给 apply() 时没有在 init_weights() 中添加参数或括号?2. init_weights(m) 函数从哪里得到它的参数 m,当它作为参数提供给函数 apply() 而不带括号和 m 时?

Arm*_*ali 6

我们在以下文档中找到了您问题的答案:torch.nn.Module.apply(fn)

适用fn递归到每个子模块(如返回。孩子()以及个体经营)。典型用途包括初始化模型的参数(另请参阅torch-nn-init)。

  1. 为什么此代码示例有效,尽管在将 init_weights() 提供给 apply() 时没有添加参数或括号?
    • 给定的函数init_weights在调用之前不会被apply调用,正是因为没有括号,而是对 的引用init_weights被赋予给apply,并且apply稍后仅从内部init_weights调用。
  2. 当函数 init_weights(m) 作为参数提供给函数 apply() 而不带括号和 m 时,函数 init_weights(m) 从哪里获取它的参数 m?
    • 它在 内的每个调用中获取其参数apply,并且,正如文档所述,由于方法 call ,它被调用以迭代(在这种情况下)的每个子模块net以及它net自身net.apply(…)