pytorch代码仓库html
pytorch在19年11月份的时候合入了这部分剪枝的代码。pytorch提供一些直接可用的api,用户只须要传入须要剪枝的module实例和须要剪枝的参数名字,系统自动帮助完成剪枝操做,看起来接口挺简单。好比 def random_structured(module, name, amount, dim)git
pytorch支持的几种类型的剪枝策略:

详细分析
- pytorch提供了一个剪枝的抽象基类‘‘class BasePruningMethod(ABC)’,全部剪枝策略都须要继承该基类,并重载部分函数就能够了
- 通常状况下须要重载__init__和compute_mask方法,__call__, apply_mask, apply, prune和remove不须要重载,例如官方提供的RandomUnstructured剪枝方法


- 剪枝的API接口,能够看到支持用户自定义的剪枝mask,接口为custom_from_mask

- API的实现,使用classmethod的方法,剪枝策略的实例化在框架内部完成,不须要用户实例化
-
剪枝的大只过程:github
- 根据用户选择的剪枝API生成对应的策略实例,此时会判断须要作剪枝操做的module上是否已经挂有前向回调函数,没有则生成新的,有了就在老的上面添加,而且生成PruningContainer。从这里能够看出,对于同一个module使用多个剪枝策略时,pytorch经过PruningContainer来对剪枝策略进行管理。PruningContainer自己也是继承自BasePruningMethod。同时设置前向计算的回调,便于后续训练时调用。
- 接着根据用户输入的module和name,找到对应的参数tensor。若是是第一次剪枝,那么须要生成_orig结尾的tensor,而后删除原始的module上的tensor。如name为bias,那么生成bias_orig存起来,而后删除module.bias属性。
- 获取defaultmask,而后调用method.computemask生成当前策略的mask值。生成的mask会被存在特定的缓存module.register_buffer(name + "_mask", mask)。这里的compute_mask多是两种状况:若是只有一个策略,那么调用的时候对应剪枝策略的compute_mask方法,若是一个module有多个剪枝策略组合,那么调用的应该是PruningContainer的compute_mask

4. 执行剪枝,保存剪枝结果到module的属性,注册训练时的剪枝回调函数,剪枝完成。新的mask应用在orig的tensor上面生成新的tensor保存的对应的name属性

pytorch还提供各种一个remove接口,目的是把以前的剪枝结果持久化,具体操做就是删除以前生成的跟剪枝相关的缓存或者是回调hook接口,设置被剪枝的name参数(如bias)为最后一次训练的值。
api
-
本身写一个剪枝策略接口也是能够的:
缓存
- 先写一个剪枝策略类继承BasePruningMethod
- 而后重载基类的compute_mask方法,写本身的计算mask方法
官方完整教程在这里app