How to pass keyword arguments to a forward pre-hook?

alv*_*vas 7 python callback pytorch tensor

Given a torch nn.Module with a forward pre-hook, e.g.

import torch
import torch.nn as nn

class NeoEmbeddings(nn.Embedding):
    def __init__(self, num_embeddings:int, embedding_dim:int, padding_idx=-1):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        self.register_forward_pre_hook(self.neo_genesis)

    @staticmethod
    def neo_genesis(self, input, higgs_bosson=0):
        if higgs_bosson:
            input = input + higgs_bosson
        return input

the input tensor can be put through some operations before it reaches the actual forward() function, e.g.

>>> x = NeoEmbeddings(10, 5, 1)
>>> x.forward(torch.tensor([0,2,5,8]))
tensor([[-1.6449,  0.5832, -0.0165, -1.3329,  0.6878],
        [-0.3262,  0.5844,  0.6917,  0.1268,  2.1363],
        [ 1.0772,  0.1748, -0.7131,  0.7405,  1.5733],
        [ 0.7651,  0.4619,  0.4388, -0.2752, -0.3018]],
       grad_fn=<EmbeddingBackward>)

>>> print(x._forward_pre_hooks)
OrderedDict([(25, <function NeoEmbeddings.neo_genesis at 0x1208d10d0>)])

How do we pass the arguments (*args or **kwargs) that the forward pre-hook needs but that the default forward() function does not accept?

Without modifying/overriding the forward() function, this is not possible:

>>> x = NeoEmbeddings(10, 5, 1)
>>> x.forward(torch.tensor([0,2,5,8]), higgs_bosson=2)

----------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-102-8705a40a3cc2> in <module>
      1 x = NeoEmbeddings(10, 5, 1)
----> 2 x.forward(torch.tensor([0,2,5,8]), higgs_bosson=2)

TypeError: forward() got an unexpected keyword argument 'higgs_bosson'

Szy*_*zke 3

Not torchscript compatible (as of 1.2.0)

First of all, your torch.nn.Module example has some minor mistakes (probably by accident).

Secondly, you can pass anything to forward, and register_forward_pre_hook will receive all the arguments that are passed to your torch.nn.Module (be it a layer, a model, or anything else). You indeed cannot do it without modifying the forward call, but why would you want to avoid that? You can simply forward the arguments to the base function, as can be seen below:

import torch


class NeoEmbeddings(torch.nn.Embedding):
    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx=-1):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        self.register_forward_pre_hook(NeoEmbeddings.neo_genesis)

    # First argument should be named something like module, as that's what 
    # you are registering this hook to
    @staticmethod
    def neo_genesis(module, inputs):  # No need for self as first argument
        net_input, higgs_bosson = inputs  # Simply unpack tuple here
        # Return the whole tuple; a lone tensor would be re-wrapped as a
        # 1-tuple and forward() would then be called without higgs_bosson
        return net_input, higgs_bosson

    def forward(self, inputs, higgs_bosson):
        # Do whatever you want here with both arguments, you can ignore 
        # higgs_bosson if it's only needed in the hook as done here
        return super().forward(inputs)


if __name__ == "__main__":
    x = NeoEmbeddings(10, 5, 1)
    # Call the instance () instead of forward() so the hooks actually run
    print(x(torch.tensor([0, 2, 5, 8]), 1))

You could not do it in any more concise way, but the limitation is the base class's forward method, not the hook itself (and I wouldn't want it to be more concise, as it would become unreadable IMO).
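If the hook is actually supposed to act on the input (rather than just unpacking and ignoring the extra argument), the same tuple pattern works: whatever tuple the hook returns replaces the arguments forward() receives. A minimal sketch under that assumption (ShiftedEmbeddings and shift_indices are illustrative names, and shifting the indices by higgs_bosson is just an example operation):

```python
import torch


class ShiftedEmbeddings(torch.nn.Embedding):
    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx=-1):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        self.register_forward_pre_hook(ShiftedEmbeddings.shift_indices)

    @staticmethod
    def shift_indices(module, inputs):
        net_input, higgs_bosson = inputs
        # Use the extra argument inside the hook, then hand the modified
        # tuple back so forward() still receives both arguments
        return net_input + higgs_bosson, higgs_bosson

    def forward(self, inputs, higgs_bosson):
        # higgs_bosson has already been consumed by the hook
        return super().forward(inputs)


if __name__ == "__main__":
    x = ShiftedEmbeddings(10, 5, 1)
    shifted = x(torch.tensor([0, 2, 5]), 3)  # hook looks up indices 3, 5, 8
    direct = x(torch.tensor([3, 5, 8]), 0)
    print(torch.equal(shifted, direct))  # True
```

Since both calls go through the same embedding weights, shifting the indices in the hook gives exactly the same rows as passing the shifted indices directly.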

Torchscript compatible

If you want to use torchscript (tested on 1.2.0), you could use composition instead of inheritance. All you have to change is two lines, and your code could look like this:

import torch

# Inherit from Module and register embedding as submodule
class NeoEmbeddings(torch.nn.Module):
    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx=-1):
        super().__init__()
        # Just use it as a container inside your own class
        self._embedding = torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx)
        self.register_forward_pre_hook(NeoEmbeddings.neo_genesis)

    @staticmethod
    def neo_genesis(module, inputs):
        net_input, higgs_bosson = inputs
        # Return the whole tuple so forward() still receives both arguments
        return net_input, higgs_bosson

    def forward(self, inputs: torch.Tensor, higgs_bosson: torch.Tensor):
        return self._embedding(inputs)


if __name__ == "__main__":
    x = torch.jit.script(NeoEmbeddings(10, 5, 1))
    # All arguments must be tensors in torchscript
    print(x(torch.tensor([0, 2, 5, 8]), torch.tensor([1])))