Pytorch 自定义激活函数?

Zer*_*mXZ 9 python neural-network deep-learning activation-function pytorch

我在 Pytorch 中实现自定义激活函数时遇到问题,例如 Swish。我应该如何在 Pytorch 中实现和使用自定义激活函数?

pat*_*_ai 14

根据您要查找的内容,有四种可能性。你需要问自己两个问题:

Q1)你的激活函数有可学习的参数吗?

如果,则您无法选择将激活函数创建为nn.Module类,因为您需要存储这些权重。

如果没有,您可以随意创建一个普通函数或一个类,具体取决于您方便的内容。

Q2)你的激活函数可以表示为现有 PyTorch 函数的组合吗?

如果,您可以简单地将其编写为现有 PyTorch 函数的组合,而无需创建backward定义渐变的函数。

如果没有,您将需要手动编写渐变。

示例 1:Swish 函数

swish 函数f(x) = x * sigmoid(x)没有任何学习权重,可以完全使用现有的 PyTorch 函数编写,因此您可以简单地将其定义为一个函数:

def swish(x):
    return x * torch.sigmoid(x)
Run Code Online (Sandbox Code Playgroud)

然后像您一样使用它torch.relu或任何其他激活功能。

示例 2:学习斜率的 Swish

在这种情况下,您有一个学习参数,即斜率,因此您需要对其进行分类。

class LearnedSwish(nn.Module):
    def __init__(self, slope = 1):
        super().__init__()
        self.slope = slope * torch.nn.Parameter(torch.ones(1))

    def forward(self, x):
        return self.slope * x * torch.sigmoid(x)
Run Code Online (Sandbox Code Playgroud)

示例 3:向后

如果你有一些需要创建自己的渐变函数的东西,你可以看看这个例子:Pytorch:定义自定义函数