dkv*_*dkv 13 python super pytorch
调用forward()父方法的最合适方法是Module什么?例如,如果我对nn.Linear模块进行子类化,我可能会执行以下操作
class LinearWithOtherStuff(nn.Linear):
def forward(self, x):
y = super(Linear, self).forward(x)
z = do_other_stuff(y)
return z
Run Code Online (Sandbox Code Playgroud)
但是,文档说不要forward()直接调用该方法:
尽管需要在此函数中定义前向传递的方法,但应该在之后调用 Module 实例而不是 this,因为前者负责运行已注册的钩子,而后者则默默地忽略它们。
这让我觉得super(Linear, self).forward(x)可能会导致一些意想不到的错误。这是真的还是我误解了继承?
super().forward(...)即使使用钩子,甚至使用实例中注册的钩子,您也可以自由使用super()。
正如此答案 所述,__call__因此注册的挂钩(例如register_forward_hook)将运行。
如果您继承并想重用基类的forward,例如:
import torch
class Parent(torch.nn.Module):
def forward(self, tensor):
return tensor + 1
class Child(Parent):
def forward(self, tensor):
return super(Child, self).forward(tensor) + 1
module = Child()
# Increment output by 1 so we should get `4`
module.register_forward_hook(lambda module, input, output: output + 1)
print(module(torch.tensor(1))) # and it is 4 indeed
print(module.forward(torch.tensor(1))) # here it is 3 still
Run Code Online (Sandbox Code Playgroud)
如果您调用__call__方法,则完全没问题,forward不会运行挂钩(因此您将得到3如上所示的结果)。
您不太可能希望register_hook在 的实例上这样做super,但让我们考虑这样的示例:
def increment_by_one(module, input, output):
return output + 1
class Parent(torch.nn.Module):
def forward(self, tensor):
return tensor + 1
class Child(Parent):
def forward(self, tensor):
# Increment by `1` from Parent
super().register_forward_hook(increment_by_one)
return super().forward(tensor) + 1
module = Child()
# Increment output by 1 so we should get `5` in total
module.register_forward_hook(increment_by_one)
print(module(torch.tensor(1))) # and it is 5 indeed
print(module.forward(torch.tensor(1))) # here is 3
Run Code Online (Sandbox Code Playgroud)
你使用得很好,甚至钩子也能正常工作(这是使用而不是的super().forward(...)主要思想)。__call__forward
顺便提一句。 调用super().__call__(...)会引发InifiniteRecursion错误。
M0这是PyTorch 中的最小模块。那里什么也没有(没有其他模块)。他们说的forward()是你不应该直接调用它,而是在实例化模块并执行模块时自动调用它m0()
import torch
import torch.nn as nn
class M0(nn.Module):
def __init__(self):
super().__init__()
def forward(self)->None:
print("empty module:forward")
# we create a module instance m1
m0 = M0()
m0()
# ??m0.__call__ # has forward() inside
Run Code Online (Sandbox Code Playgroud)
出去:
empty module:forward
Run Code Online (Sandbox Code Playgroud)
如果您想要子模块,您可以聚合它们:
import torch
import torch.nn as nn
class M0(nn.Module):
def __init__(self):
super().__init__()
def forward(self)->None:
print("empty module:forward")
# we create a module instance m1
m0 = M0()
m0()
# ??m0.__call__ # has forward() inside
Run Code Online (Sandbox Code Playgroud)
出去:
M1(
(l1): Linear(in_features=10, out_features=100, bias=True)
)
M1:forward
torch.Size([1, 100])
Run Code Online (Sandbox Code Playgroud)
一旦聚合了其他模块,就可以调用forward()来执行它们。forward()将需要输入并返回一些输出。
这个模型最初是用 Lua 编程语言提出的,PyTorch 只是使用了它。
这让我认为 super(Linear, self).forward(x) 可能会导致一些意外的错误
这正是为什么forward()不直接调用来抑制这些意外错误的原因。相反,模块是可调用的,就像我们在示例中所做的那样:
self.l1(x)
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
1802 次 |
| 最近记录: |