如何限制pytorch中的参数范围?

t-s*_*art 3 pytorch

所以通常在pytorch中,模型中的参数没有严格的限制,但是如果我希望它们保持在[0,1]范围内怎么办?有没有办法阻止参数更新超出该范围?

小智 9

一些生成对抗网络(其中一些要求判别器的参数在一定范围内)中使用的一个技巧是在每次梯度更新后限制值。例如:

model = YourPyTorchModule()

for _ in range(epochs):
    loss = ...
    optimizer.step()
    for p in model.parameters():
        p.data.clamp_(-1.0, 1.0)
Run Code Online (Sandbox Code Playgroud)