知识蒸馏(Knowledge Distillation)

一、Distilling the Knowledge in a Neural Network

Hinton的文章"Distilling the Knowledge in a Neural Network"首次提出了知识蒸馏(暗知识提取)的概念,经过引入与教师网络(teacher network:复杂、但推理性能优越)相关的软目标(soft-target)做为total loss的一部分,以诱导学生网络(student network:精简、低复杂度)的训练,实现知识迁移(knowledge transfer)。python

如上图所示,教师网络(左侧)的预测输出除以温度参数(Temperature)以后、再作softmax变换,能够得到软化的几率分布(软目标),数值介于0~1之间,取值分布较为缓和。Temperature数值越大,分布越缓和;而Temperature数值减少,容易放大错误分类的几率,引入没必要要的噪声。针对较困难的分类或检测任务,Temperature一般取1,确保教师网络中正确预测的贡献。硬目标则是样本的真实标注,能够用one-hot矢量表示。total loss设计为软目标与硬目标所对应的交叉熵的加权平均(表示为KD loss与CE loss),其中软目标交叉熵的加权系数越大,代表迁移诱导越依赖教师网络的贡献,这对训练初期阶段是颇有必要的,有助于让学生网络更轻松的鉴别简单样本,但训练后期须要适当减少软目标的比重,让真实标注帮助鉴别困难样本。另外,教师网络的推理性能一般要优于学生网络,而模型容量则无具体限制,且教师网络推理精度越高,越有利于学生网络的学习。git

教师网络与学生网络也能够联合训练,此时教师网络的暗知识及学习方式都会影响学生网络的学习,具体以下(式中三项分别为教师网络softmax输出的交叉熵loss、学生网络softmax输出的交叉熵loss、以及教师网络数值输出与学生网络softmax输出的交叉熵loss):github

联合训练的Paper地址:https://arxiv.org/abs/1711.05852算法

二、Exploring Knowledge Distillation of Deep Neural Networks for Efficient Hardware Solutions

这篇文章将total loss从新定义以下:缓存

GitHub地址:https://github.com/peterliht/knowledge-distillation-pytorch网络

total loss的Pytorch代码以下,引入了精简网络输出与教师网络输出的KL散度,并在诱导训练期间,先将teacher network的预测输出缓存到CPU内存中,能够减轻GPU显存的overhead:ide

def loss_fn_kd(outputs, labels, teacher_outputs, params):
    """
    Compute the knowledge-distillation (KD) loss given outputs, labels.
    "Hyperparameters": temperature and alpha
    NOTE: the KL Divergence for PyTorch comparing the softmaxs of teacher
    and student expects the input tensor to be log probabilities! See Issue #2
    """
    alpha = params.alpha
    T = params.temperature
    KD_loss = nn.KLDivLoss()(F.log_softmax(outputs/T, dim=1),
                             F.softmax(teacher_outputs/T, dim=1)) * (alpha * T * T) + \
                             F.cross_entropy(outputs, labels) * (1. - alpha)

    return KD_loss

三、Ensemble of Multiple Teachers

第一种算法:多个教师网络输出的soft label按加权组合,构成统一的soft label,而后指导学生网络的训练:性能

第二种算法:因为加权平均方式会弱化、平滑多个教师网络的预测结果,所以能够随机选择某个教师网络的soft label做为guidance:学习

第三种算法:一样地,为避免加权平均带来的平滑效果,首先采用教师网络输出的soft label从新标注样本、增广数据、再用于模型训练,该方法可以让模型学会从更多视角观察同同样本数据的不一样功能:ui

Paper地址:

https://www.researchgate.net/publication/319185356_Efficient_Knowledge_Distillation_from_an_Ensemble_of_Teachers

四、Hint-based Knowledge Transfer

为了可以诱导训练更深、更纤细的学生网络(deeper and thinner FitNet),须要考虑教师网络中间层的Feature Maps(做为Hint),用来指导学生网络中相应的Guided layer。此时须要引入L2 loss指导训练过程,该loss计算为教师网络Hint layer与学生网络Guided layer输出Feature Maps之间的差异,若两者输出的Feature Maps形状不一致,Guided layer须要经过一个额外的回归层,具体以下:

具体训练过程分两个阶段完成:第一个阶段利用Hint-based loss诱导学生网络达到一个合适的初始化状态(只更新W_Guided与W_r);第二个阶段利用教师网络的soft label指导整个学生网络的训练(即知识蒸馏),且total loss中soft target相关部分所占比重逐渐下降,从而让学生网络可以全面辨别简单样本与困难样本(教师网络可以有效辨别简单样本,而困难样本则须要借助真实标注,即hard target):

Paper地址:https://arxiv.org/abs/1412.6550

GitHub地址:https://github.com/adri-romsor/FitNets

五、Attention to Attention Transfer

经过网络中间层的attention map,完成teacher network与student network之间的知识迁移。考虑给定的tensor A,基于activation的attention map能够定义为以下三种之一:

随着网络层次的加深,关键区域的attention-level也随之提升。文章最后采用了第二种形式的attention map,取p=2,而且activation-based attention map的知识迁移效果优于gradient-based attention map,loss定义及迁移过程以下:

Paper地址:https://arxiv.org/abs/1612.03928

GitHub地址:https://github.com/szagoruyko/attention-transfer

六、Flow of the Solution Procedure

暗知识亦可表示为训练的求解过程(FSP: Flow of the Solution Procedure),教师网络或学生网络的FSP矩阵定义以下(Gram形式的矩阵):

训练的第一阶段:最小化教师网络FSP矩阵与学生网络FSP矩阵之间的L2 Loss,初始化学生网络的可训练参数:

训练的第二阶段:在目标任务的数据集上fine-tune学生网络。从而达到知识迁移、快速收敛、以及迁移学习的目的。

Paper地址:

http://openaccess.thecvf.com/content_cvpr_2017/papers/Yim_A_Gift_From_CVPR_2017_paper.pdf

七、Knowledge Distillation with Adversarial Samples Supporting Decision Boundary

从分类的决策边界角度分析,知识迁移过程亦可理解为教师网络诱导学生网络有效鉴别决策边界的过程,鉴别能力越强意味着模型的泛化能力越好:

文章首先利用对抗攻击策略(adversarial attacking)将基准类样本(base class sample)转为目标类样本、且位于决策边界附近(BSS: boundary supporting sample),进而利用对抗生成的样本诱导学生网络的训练,可有效提高学生网络对决策边界的鉴别能力。文章采用迭代方式生成对抗样本,须要沿loss function(基准类得分与目标类得分之差)的梯度负方向调整样本,直到知足中止条件为止:

loss function:

沿loss function的梯度负方向调整样本:

中止条件(只要知足三者之一):

结合对抗生成的样本,利用教师网络训练学生网络所需的total loss包含CE loss、KD loss以及boundary supporting loss(BS loss):

Paper地址:https://arxiv.org/abs/1805.05532

八、Label Refinery:Improving ImageNet Classification through Label Progression

这篇文章经过迭代式的诱导训练,主要解决训练期间样本的crop与label不一致的问题,以加强label的质量,从而进一步加强模型的泛化能力:

诱导过程当中,total loss表示为本次迭代(t>1)网络的预测输出(几率分布)与上一次迭代输出(Label Refinery:相似于教师网络的角色)的KL散度:

文章实验部分代表,不只能够用训练网络做为Label Refinery Network,也能够用其余高质量网络(如Resnet50)做为Label Refinery Network。并在诱导过程当中,可以对抗生成样本,实现数据加强。

GitHub地址:https://github.com/hessamb/label-refinery

九、Miscellaneous

-------- 知识蒸馏能够与量化结合使用,考虑了中间层Feature Maps之间的关系,可参考:

http://www.javashuo.com/article/p-hwqoavyt-dx.html

-------- 知识蒸馏与Hint Learning相结合,能够训练精简的Faster-RCNN,可参考:

http://www.javashuo.com/article/p-rrjfvmsa-he.html

-------- 知识蒸馏在Transformer模型压缩方面,主要采用Self-attention Knowledge Distillation,可参考:

https://blog.csdn.net/nature553863/article/details/106855786

-------- 模型压缩方面,更为详细的讨论,请参考:

http://www.javashuo.com/article/p-gtoxopos-ns.html