Zer*_*mXZ 9 python neural-network deep-learning activation-function pytorch
我在 Pytorch 中实现自定义激活函数时遇到问题,例如 Swish。我应该如何在 Pytorch 中实现和使用自定义激活函数?
pat*_*_ai 14
根据您要查找的内容,有四种可能性。你需要问自己两个问题:
Q1)你的激活函数有可学习的参数吗?
如果是,则您无法选择将激活函数创建为nn.Module
类,因为您需要存储这些权重。
如果没有,您可以随意创建一个普通函数或一个类,具体取决于您方便的内容。
Q2)你的激活函数可以表示为现有 PyTorch 函数的组合吗?
如果是,您可以简单地将其编写为现有 PyTorch 函数的组合,而无需创建backward
定义渐变的函数。
如果没有,您将需要手动编写渐变。
示例 1:Swish 函数
swish 函数f(x) = x * sigmoid(x)
没有任何学习权重,可以完全使用现有的 PyTorch 函数编写,因此您可以简单地将其定义为一个函数:
def swish(x):
return x * torch.sigmoid(x)
Run Code Online (Sandbox Code Playgroud)
然后像您一样使用它torch.relu
或任何其他激活功能。
示例 2:学习斜率的 Swish
在这种情况下,您有一个学习参数,即斜率,因此您需要对其进行分类。
class LearnedSwish(nn.Module):
def __init__(self, slope = 1):
super().__init__()
self.slope = slope * torch.nn.Parameter(torch.ones(1))
def forward(self, x):
return self.slope * x * torch.sigmoid(x)
Run Code Online (Sandbox Code Playgroud)
示例 3:向后
如果你有一些需要创建自己的渐变函数的东西,你可以看看这个例子:Pytorch:定义自定义函数