在PyTorch中调用forward方法与调用模型实例

Shr*_*shi 15 python oop class pytorch

我看过的很多 PyTorch 教程都是这样做的。

定义模型:

class Network(nn.Module):
    def __init__():
        super().__init__()
        self.conv1 = ..
        ... 
    
    def forward(x)
        ...
    ...
Run Code Online (Sandbox Code Playgroud)

一旦网络被实例化 ( net = Network()),教程中的人就会编写net(input_data)而不是net.forward(input_data). 我尝试过net.forward(),它给出了与 相同的结果net()

为什么这是一种常见做法,以及为什么它有效?

Pra*_*kar 18

您应该避免调用Module.forward。\n不同之处在于,所有钩子都在__call__函数see this中分派,因此,如果您调用.forward并在模型中包含钩子,则钩子\xe2\x80\x99不会产生任何效果。

\n

简而言之,当你调用时Module.forward,pytorch hooks不会有任何效果

\n

详细答案可以在这篇 文章中找到

\n