MSD*_*aul 8 python pruning conv-neural-network pytorch
如何修剪小于阈值的 CNN(卷积神经网络)模型的权重(让我们考虑修剪所有 <= 1 的权重)。
对于在 pytorch 中以 .pth 格式保存的权重文件,我们如何实现?
Szy*_*zke 15
PyTorch 因为1.4.0
提供了开箱即用的模型修剪,请参阅官方教程。
由于目前threshold
在 PyTorch 中没有修剪方法,你必须自己实现它,尽管一旦你有了整体的想法,这会很容易。
下面是执行修剪的代码:
from torch.nn.utils import prune
class ThresholdPruning(prune.BasePruningMethod):
PRUNING_TYPE = "unstructured"
def __init__(self, threshold):
self.threshold = threshold
def compute_mask(self, tensor, default_mask):
return torch.abs(tensor) > self.threshold
Run Code Online (Sandbox Code Playgroud)
解释:
PRUNING_TYPE
可以是global
, structured
, 之一unstructured
。global
作用于整个模块(例如去除20%
最小值的权重),structured
作用于整个通道/模块。我们需要unstructured
修改特定参数张量中的每个连接(比如weight
或bias
)__init__
- 传递任何你想要或需要让它工作的东西,正常的东西compute_mask
- 用于修剪特定张量的掩码。在我们的例子中,低于阈值的所有参数都应该为零。我用绝对值做了它,因为它更有意义。default_mask
此处不需要,但保留为命名参数,因为 API 需要 atm。此外,继承自prune.BasePruningMethod
定义的方法以将掩码应用于每个参数,使修剪永久等。有关更多信息,请参阅基类文档。
没什么特别的,你可以在这里放任何你想要的东西:
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.first = torch.nn.Linear(50, 30)
self.second = torch.nn.Linear(30, 10)
def forward(self, inputs):
return self.second(torch.relu(self.first(inputs)))
module = MyModule()
Run Code Online (Sandbox Code Playgroud)
module = torch.load('checkpoint.pth')
如果需要,您也可以通过加载模块,没关系。
我们应该定义我们模块的哪个参数(以及它是weight
或bias
)应该被修剪,像这样:
parameters_to_prune = ((module.first, "weight"), (module.second, "weight"))
Run Code Online (Sandbox Code Playgroud)
现在,我们可以将global
我们的unstructured
修剪应用于所有定义的parameters
(threshold
传递kwarg
给__init__
of ThresholdPruning
):
prune.global_unstructured(
parameters_to_prune, pruning_method=ThresholdPruning, threshold=0.1
)
Run Code Online (Sandbox Code Playgroud)
weight
属性要查看效果,first
只需使用以下命令检查子模块的权重:
print(module.first.weight)
Run Code Online (Sandbox Code Playgroud)
这是一个应用了我们修剪技术的权重,但请注意它不再是一个torch.nn.Parameter
!现在它只是我们模型的一个属性,因此它目前不会参与训练或评估。
weight_mask
我们可以通过检查创建的掩码module.first.weight_mask
来查看一切是否正确(在这种情况下它将是二进制的)。
weight_orig
应用修剪会创建一个名为 的torch.nn.Parameter
具有原始权重的新权重name + _orig
,在这种情况下weight_orig
,让我们看看:
print(module.first.weight_orig)
Run Code Online (Sandbox Code Playgroud)
该参数将在当前训练和评估期间使用!. pruning
通过上述方法应用后,forward_pre_hooks
添加了哪些“切换”原始weight
到weight_orig
.
由于这种方法,您可以在training
或inference
不“破坏”原始权重的任何部分定义和应用修剪。
如果您希望永久应用修剪,只需发出:
prune.remove(module.first, "weight")
Run Code Online (Sandbox Code Playgroud)
现在我们的module.first.weight
is 参数再次被适当修剪条目,module.first.weight_mask
被删除等等module.first.weight_orig
。这可能是你所追求的。
您可以迭代children
以使其永久化:
for child in module.children():
prune.remove(child, "weight")
Run Code Online (Sandbox Code Playgroud)
您可以parameters_to_prune
使用相同的逻辑定义:
parameters_to_prune = [(child, "weight") for child in module.children()]
Run Code Online (Sandbox Code Playgroud)
或者,如果您只想convolution
修剪图层(或其他任何东西):
parameters_to_prune = [
(child, "weight")
for child in module.children()
if isinstance(child, torch.nn.Conv2d)
]
Run Code Online (Sandbox Code Playgroud)
threshold
,太高了,现在你所有的权重都为零,渲染结果毫无意义)forward
调用期间使用原始权重,除非您想最终更改为修剪版本(简单调用remove
)Shai
)您可以直接处理保存在以下位置的值state_dict
:
sd = torch.load('saved_weights.pth') # load the state dicd
for k in sd.keys():
if not 'weight' in k:
continue # skip biases and other saved parameters
w = sd[k]
sd[k] = w * (w > thr) # set to zero weights smaller than thr
torch.save(sd, 'pruned_weights.pth')
Run Code Online (Sandbox Code Playgroud)
归档时间: |
|
查看次数: |
3062 次 |
最近记录: |