调用 super 的 forward() 方法

dkv*_*dkv 13 python super pytorch

调用forward()父方法的最合适方法是Module什么?例如,如果我对nn.Linear模块进行子类化,我可能会执行以下操作

class LinearWithOtherStuff(nn.Linear):
    def forward(self, x):
        y = super(Linear, self).forward(x)
        z = do_other_stuff(y)
        return z
Run Code Online (Sandbox Code Playgroud)

但是,文档说不要forward()直接调用该方法:

尽管需要在此函数中定义前向传递的方法,但应该在之后调用 Module 实例而不是 this,因为前者负责运行已注册的钩子,而后者则默默地忽略它们。

这让我觉得super(Linear, self).forward(x)可能会导致一些意想不到的错误。这是真的还是我误解了继承?

Szy*_*zke 5

太长了;

super().forward(...)即使使用钩子,甚至使用实例中注册的钩子,您也可以自由使用super()

解释

正如此答案 所述,__call__因此注册的挂钩(例如register_forward_hook)将运行。

如果您继承并想重用基类的forward,例如:

import torch


class Parent(torch.nn.Module):
    def forward(self, tensor):
        return tensor + 1


class Child(Parent):
    def forward(self, tensor):
        return super(Child, self).forward(tensor) + 1


module = Child()
# Increment output by 1 so we should get `4`
module.register_forward_hook(lambda module, input, output: output + 1)
print(module(torch.tensor(1))) # and it is 4 indeed
print(module.forward(torch.tensor(1))) # here it is 3 still
Run Code Online (Sandbox Code Playgroud)

如果您调用__call__方法,则完全没问题,forward不会运行挂钩(因此您将得到3如上所示的结果)。

您不太可能希望register_hook在 的实例上这样做super,但让我们考虑这样的示例:

def increment_by_one(module, input, output):
    return output + 1


class Parent(torch.nn.Module):
    def forward(self, tensor):
        return tensor + 1


class Child(Parent):
    def forward(self, tensor):
        # Increment by `1` from Parent
        super().register_forward_hook(increment_by_one)
        return super().forward(tensor) + 1


module = Child()
# Increment output by 1 so we should get `5` in total
module.register_forward_hook(increment_by_one)
print(module(torch.tensor(1)))  # and it is 5 indeed
print(module.forward(torch.tensor(1)))  # here is 3
Run Code Online (Sandbox Code Playgroud)

你使用得很好,甚至钩子也能正常工作(这是使用而不是的super().forward(...)主要思想)。__call__forward

顺便提一句。 调用super().__call__(...)会引发InifiniteRecursion错误。


pro*_*sti 1

M0这是PyTorch 中的最小模块。那里什么也没有(没有其他模块)。他们说的forward()是你不应该直接调用它,而是在实例化模块并执行模块时自动调用它m0()

import torch
import torch.nn as nn

class M0(nn.Module):    
    def __init__(self):
        super().__init__()
    
    def forward(self)->None: 
        print("empty module:forward")
        
# we create a module instance m1
m0 = M0()
m0()
# ??m0.__call__ # has forward() inside
Run Code Online (Sandbox Code Playgroud)

出去:

empty module:forward
Run Code Online (Sandbox Code Playgroud)

如果您想要子模块,您可以聚合它们:

import torch
import torch.nn as nn

class M0(nn.Module):    
    def __init__(self):
        super().__init__()
    
    def forward(self)->None: 
        print("empty module:forward")
        
# we create a module instance m1
m0 = M0()
m0()
# ??m0.__call__ # has forward() inside
Run Code Online (Sandbox Code Playgroud)

出去:

M1(
  (l1): Linear(in_features=10, out_features=100, bias=True)
)
M1:forward
torch.Size([1, 100])
Run Code Online (Sandbox Code Playgroud)

一旦聚合了其他模块,就可以调用forward()来执行它们。forward()将需要输入并返回一些输出。

这个模型最初是用 Lua 编程语言提出的,PyTorch 只是使用了它。


这让我认为 super(Linear, self).forward(x) 可能会导致一些意外的错误

这正是为什么forward()不直接调用来抑制这些意外错误的原因。相反,模块是可调用的,就像我们在示例中所做的那样:

self.l1(x)
Run Code Online (Sandbox Code Playgroud)