PyTorch 将运算符映射到函数

Tom*_*ale 8 operator-overloading pytorch

PyTorch 的所有运算符是什么,它们的功能等价物是什么?

例如,a @ b等价于a.mm(b)a.matmul(b)

我在寻找运算符 -> 函数映射的规范列表。

我很高兴收到 PyTorch 文档链接作为答案 - 我的 googlefu 无法找到它。

Tom*_*ale 8

Python 文档表Mapping Operators to Functions提供了来自以下方面的规范映射:

运算符 -> __function__()

例如:

Matrix Multiplication        a @ b        matmul(a, b)
Run Code Online (Sandbox Code Playgroud)

在页面的其他地方,您将看到该__matmul__名称作为 的替代名称matmul

PyTorch 的定义__functions__可以在以下位置找到:

您可以在以下位置查找命名函数的文档:

https://pytorch.org/docs/stable/torch.html?#torch.<FUNCTION-NAME>
Run Code Online (Sandbox Code Playgroud)


blu*_*nox 6

这定义了 0.3.1 的张量运算(它还包含其他运算符的定义): https ://pytorch.org/docs/0.3.1/_modules/torch/tensor.html

当前稳定的代码已经重新排列(我猜他们现在在 C 中做了更多的事情),但由于矩阵乘法的行为没有改变,我认为假设这仍然有效是可以的。

请参阅 的定义__matmul__

def __matmul__(self, other):
    if not torch.is_tensor(other):
        return NotImplemented
    return self.matmul(other)
Run Code Online (Sandbox Code Playgroud)

def matmul(self, other):
    r"""Matrix product of two tensors.

    See :func:`torch.matmul`."""
    return torch.matmul(self, other)
Run Code Online (Sandbox Code Playgroud)

该运算符@是随PEP 465引入的,并映射到__matmul__

另请参阅此处:
What is the '@=' symbol for in Python?