PyTorchweak_script_method装饰器

Jos*_*r98 3 python pytorch

我在对Word2Vec和PyTorch的介绍中遇到了一些我不太熟悉的代码。我以前从未见过这种类型的代码结构。

>>> import torch
>>> from torch import nn

>>> # an Embedding module containing 10 tensors of size 3
>>> embedding = nn.Embedding(10, 3)
>>> # a batch of 2 samples of 4 indices each
>>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]])
>>> embedding(input)

tensor([[[-0.0251, -1.6902,  0.7172],
         [-0.6431,  0.0748,  0.6969],
         [ 1.4970,  1.3448, -0.9685],
         [-0.3677, -2.7265, -0.1685]],

        [[ 1.4970,  1.3448, -0.9685],
         [ 0.4362, -0.4004,  0.9400],
         [-0.6431,  0.0748,  0.6969],
         [ 0.9124, -2.3616,  1.1151]]])
Run Code Online (Sandbox Code Playgroud)

我对以下代码行有些困惑。

>>> embedding(input)
Run Code Online (Sandbox Code Playgroud)

过去我可能无意中忽略了这种语法,但是我不记得以前曾经将变量传递给类实例吗?请参阅PyTorch 文档Class Embedding()定义的位置,是否通过装饰器@weak_script_method包装启用了此行为def forward()?下面的代码表明可能是这种情况?

>>> torch.manual_seed(2)
>>> torch.eq(embedding(input), embedding.forward(input)).all()

tensor(1, dtype=torch.uint8)
Run Code Online (Sandbox Code Playgroud)

为什么@weak_script_method在这种情况下最好使用装饰器?

Jat*_*aki 5

不,@weak_script_method与它无关。embedding(input)遵循Python函数调用语法,该语法可以与“传统”函数以及定义__call__(self, *args, **kwargs)magic函数的对象一起使用。所以这段代码

class Greeter:
    def __init__(self, name):
        self.name = name

    def __call__(self, name):
        print('Hello to ' + name + ' from ' + self.name + '!')

greeter = Greeter('Jatentaki')
greeter('EBB')
Run Code Online (Sandbox Code Playgroud)

将导致Hello to EBB from Jatentaki!打印到标准输出。同样,Embedding您通过告诉它应该包含多少个嵌入物,它们的维数是多少来构造它,然后在构造它之后,可以像调用函数一样调用它,以检索嵌入的所需部分。

__call__nn.Embedding源代码中看不到的原因是它的子类nn.Module,它提供了一个自动__call__实现,该实现在执行forward之前和之后都委派并调用一些额外的东西(请参阅文档)。因此,通话module_instance(arguments)大致等于通话module_instance.forward(arguments)

@weak_script_method装饰有小,用它做。它与jit兼容性有关,@weak_script_method@script_method为PyTorch内部使用而设计的一种变体- 如果您想使用它,则给您的唯一消息应该nn.Embedding是与兼容jit