Pytorch 线性模块类定义中的常量

cod*_*der 7 pytorch

https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html__constants__class Linear(Module):定义的pytorch是什么?

它的功能是什么,为什么要使用它?

我一直在四处寻找,但没有找到任何文档。请注意,这并不意味着__constants__in torch 脚本。

kHa*_*hit 6

__constants__你所谈论的是,事实上,一个有关TorchScript。您可以git blame 通过在 GitHub 上使用(添加时间和添加者)来确认。例如,对于torch/nn/modules/linear.py,检查它的git blame

TorchScript还提供了一种使用 Python 中定义的常量的方法。这些可用于将超参数硬编码到函数中,或定义通用常量。

-- ScriptModule 的属性可以通过将它们列为类的constants属性的成员来标记为常量:

class Foo(torch.jit.ScriptModule):
    __constants__ = ['a']

    def __init__(self):
        super(Foo, self).__init__(False)
        self.a = 1 + 4

   @torch.jit.script_method
   def forward(self, input):
       return self.a + input
Run Code Online (Sandbox Code Playgroud)