如何在 PyTorch 中修剪小于阈值的权重?

MSD*_*aul 8 python pruning conv-neural-network pytorch

如何修剪小于阈值的 CNN(卷积神经网络)模型的权重(让我们考虑修剪所有 <= 1 的权重)。

对于在 pytorch 中以 .pth 格式保存的权重文件,我们如何实现?

Szy*_*zke 15

PyTorch 因为1.4.0提供了开箱即用的模型修剪,请参阅官方教程

由于目前threshold在 PyTorch 中没有修剪方法,你必须自己实现它,尽管一旦你有了整体的想法,这会很容易。

阈值修剪方法

下面是执行修剪的代码:

from torch.nn.utils import prune


class ThresholdPruning(prune.BasePruningMethod):
    PRUNING_TYPE = "unstructured"

    def __init__(self, threshold):
        self.threshold = threshold

    def compute_mask(self, tensor, default_mask):
        return torch.abs(tensor) > self.threshold
Run Code Online (Sandbox Code Playgroud)

解释:

  • PRUNING_TYPE可以是global, structured, 之一unstructuredglobal作用于整个模块(例如去除20%最小值的权重),structured作用于整个通道/模块。我们需要unstructured修改特定参数张量中的每个连接(比如weightbias
  • __init__ - 传递任何你想要或需要让它工作的东西,正常的东西
  • compute_mask- 用于修剪特定张量的掩码。在我们的例子中,低于阈值的所有参数都应该为零。我用绝对值做了它,因为它更有意义。default_mask此处不需要,但保留为命名参数,因为 API 需要 atm。

此外,继承自prune.BasePruningMethod定义的方法以将掩码应用于每个参数,使修剪永久等。有关更多信息,请参阅基类文档

示例模块

没什么特别的,你可以在这里放任何你想要的东西:

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.first = torch.nn.Linear(50, 30)
        self.second = torch.nn.Linear(30, 10)

    def forward(self, inputs):
        return self.second(torch.relu(self.first(inputs)))


module = MyModule()
Run Code Online (Sandbox Code Playgroud)

module = torch.load('checkpoint.pth') 如果需要,您也可以通过加载模块,没关系。

修剪模块的参数

我们应该定义我们模块的哪个参数(以及它是weightbias)应该被修剪,像这样:

parameters_to_prune = ((module.first, "weight"), (module.second, "weight"))
Run Code Online (Sandbox Code Playgroud)

现在,我们可以将global我们的unstructured修剪应用于所有定义的parametersthreshold传递kwarg__init__of ThresholdPruning):

prune.global_unstructured(
    parameters_to_prune, pruning_method=ThresholdPruning, threshold=0.1
)
Run Code Online (Sandbox Code Playgroud)

结果

weight 属性

要查看效果,first只需使用以下命令检查子模块的权重:

print(module.first.weight)
Run Code Online (Sandbox Code Playgroud)

这是一个应用了我们修剪技术的权重,但请注意它不再是一个torch.nn.Parameter现在它只是我们模型的一个属性,因此目前不会参与训练或评估

weight_mask

我们可以通过检查创建的掩码module.first.weight_mask来查看一切是否正确(在这种情况下它将是二进制的)。

weight_orig

应用修剪会创建一个名为 的torch.nn.Parameter具有原始权重的新权重name + _orig,在这种情况下weight_orig,让我们看看:

print(module.first.weight_orig)
Run Code Online (Sandbox Code Playgroud)

该参数将在当前训练和评估期间使用!. pruning通过上述方法应用后,forward_pre_hooks添加了哪些“切换”原始weightweight_orig.

由于这种方法,您可以在traininginference不“破坏”原始权重的任何部分定义和应用修剪。

永久应用修剪

如果您希望永久应用修剪,只需发出:

prune.remove(module.first, "weight")
Run Code Online (Sandbox Code Playgroud)

现在我们的module.first.weightis 参数再次被适当修剪条目,module.first.weight_mask被删除等等module.first.weight_orig这可能是你所追求的

您可以迭代children以使其永久化:

for child in module.children():
    prune.remove(child, "weight")
Run Code Online (Sandbox Code Playgroud)

您可以parameters_to_prune使用相同的逻辑定义:

parameters_to_prune = [(child, "weight") for child in module.children()]
Run Code Online (Sandbox Code Playgroud)

或者,如果您只想convolution修剪图层(或其他任何东西):

parameters_to_prune = [
    (child, "weight")
    for child in module.children()
    if isinstance(child, torch.nn.Conv2d)
]
Run Code Online (Sandbox Code Playgroud)

好处

  • 使用“PyTorch 修剪方式”,因此可以更轻松地将您的意图传达给其他程序员
  • 在每个张量的基础上定义修剪,单一职责而不是完成所有事情
  • 仅限于预定义的方式
  • 修剪不是永久性的,因此您可以根据需要从中恢复。模块可以用修剪掩码和原始权重来保存,这样你就可以留出一些空间来恢复最终的错误(例如threshold,太高了,现在你所有的权重都为零,渲染结果毫无意义)
  • forward调用期间使用原始权重,除非您想最终更改为修剪版本(简单调用remove

缺点

  • IMO 修剪 API 可能更清晰
  • 你可以做得更短(由 提供Shai
  • 对于那些不知道 PyTorch“定义”这样的东西的人来说可能会感到困惑(仍然有教程和文档,所以我认为这不是一个主要问题)

  • 如果您对*结构化修剪*(神经元修剪而不是单个权重)感兴趣,您可以考虑外部库,例如 [TorchPruner](https://github.com/marcoancona/TorchPruner/) (2认同)

Sha*_*hai 3

您可以直接处理保存在以下位置的值state_dict

sd = torch.load('saved_weights.pth')  # load the state dicd
for k in sd.keys():
  if not 'weight' in k:
    continue  # skip biases and other saved parameters
  w = sd[k]
  sd[k] = w * (w > thr)  # set to zero weights smaller than thr 
torch.save(sd, 'pruned_weights.pth')
Run Code Online (Sandbox Code Playgroud)