我在对Word2Vec和PyTorch的介绍中遇到了一些我不太熟悉的代码。我以前从未见过这种类型的代码结构。
>>> import torch
>>> from torch import nn
>>> # an Embedding module containing 10 tensors of size 3
>>> embedding = nn.Embedding(10, 3)
>>> # a batch of 2 samples of 4 indices each
>>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]])
>>> embedding(input)
tensor([[[-0.0251, -1.6902, 0.7172],
[-0.6431, 0.0748, 0.6969],
[ 1.4970, 1.3448, -0.9685],
[-0.3677, -2.7265, -0.1685]],
[[ 1.4970, 1.3448, -0.9685],
[ 0.4362, -0.4004, 0.9400],
[-0.6431, 0.0748, 0.6969],
[ 0.9124, -2.3616, 1.1151]]])
Run Code Online (Sandbox Code Playgroud)
我对以下代码行有些困惑。
>>> embedding(input)
Run Code Online (Sandbox Code Playgroud)
过去我可能无意中忽略了这种语法,但是我不记得以前曾经将变量传递给类实例吗?请参阅PyTorch 文档中Class Embedding()定义的位置,是否通过装饰器@weak_script_method包装启用了此行为def forward()?下面的代码表明可能是这种情况?
>>> torch.manual_seed(2)
>>> torch.eq(embedding(input), embedding.forward(input)).all()
tensor(1, dtype=torch.uint8)
Run Code Online (Sandbox Code Playgroud)
为什么@weak_script_method在这种情况下最好使用装饰器?
不,@weak_script_method与它无关。embedding(input)遵循Python函数调用语法,该语法可以与“传统”函数以及定义__call__(self, *args, **kwargs)magic函数的对象一起使用。所以这段代码
class Greeter:
def __init__(self, name):
self.name = name
def __call__(self, name):
print('Hello to ' + name + ' from ' + self.name + '!')
greeter = Greeter('Jatentaki')
greeter('EBB')
Run Code Online (Sandbox Code Playgroud)
将导致Hello to EBB from Jatentaki!打印到标准输出。同样,Embedding您通过告诉它应该包含多少个嵌入物,它们的维数是多少来构造它,然后在构造它之后,可以像调用函数一样调用它,以检索嵌入的所需部分。
您__call__在nn.Embedding源代码中看不到的原因是它的子类nn.Module,它提供了一个自动__call__实现,该实现在执行forward之前和之后都委派并调用一些额外的东西(请参阅文档)。因此,通话module_instance(arguments)大致等于通话module_instance.forward(arguments)。
该@weak_script_method装饰有小,用它做。它与jit兼容性有关,@weak_script_method是@script_method为PyTorch内部使用而设计的一种变体- 如果您想使用它,则给您的唯一消息应该nn.Embedding是与兼容jit。
| 归档时间: |
|
| 查看次数: |
807 次 |
| 最近记录: |