知识蒸馏Knowledge Distillation

知识蒸馏是模型压缩的一个重要方法,本文简要介绍了什么是知识蒸馏。git

知识蒸馏Knowledge Distillation

1.什么是知识蒸馏

我浅谈一些个人见解,详细内容能够参考这篇文章
[https://zhuanlan.zhihu.com/p/90049906]web

简单来讲,就是咱们通常训练模型时,可能为了有一个好的效果,就会加大网络深度,或者用一些复杂的网络,这样参数量就会好大…那么这么一个模型怎么弄到移动端呢?怎么能使得运行速度实时呢?网络

因此人们提出搞模型压缩!svg

知识蒸馏就是一类模型压缩方法,先训练大模型,再去引导作一个小模型。函数

它是怎么干的呢?xml

大模型会训练出一系列的softmax几率值,这样,原来咱们须要让新模型的softmax分布与真实标签匹配,如今只须要让新模型与原模型在给定输入下的softmax分布匹配了。直观来看,后者比前者具备这样一个优点:通过训练后的原模型,其softmax分布包含有必定的知识——真实标签只能告诉咱们,某个图像样本是一辆宝马,不是一辆垃圾车,也不是一颗萝卜;而通过训练的softmax可能会告诉咱们,它最多是一辆宝马,不大多是一辆垃圾车,但毫不多是一颗萝卜。blog

随后怎么作,简单来讲就是能够小模型的几率z要逼近原模型的v,直接用下面损失也能够
在这里插入图片描述图片

2.知识蒸馏怎么作

  • 第一步:在训练集上训练好一个大模型A(一般叫作teacher model)
  • 第二步:在transfer set(能够和训练集是同一个数据集)上利用大模型A产生给每个样本生成一个soft target(有利用一个temperature参数对logits进行平滑)
  • 第三步:在transfer set上对student model B进行训练,损失函数由两部分组成,都是交叉熵损失,只不过一个是拟合soft target,另一个是拟合ground truth的hard target(如二分类中的0和1),其中在拟合hard target的损失函数和普通分类损失保持一致,在拟合soft target的损失函数时也利用了一个一样的temperature参数T
  • 第四步:保留student model进行线上预测,这个时候去掉soft target那一路,只保留普通分类的softmax
    在这里插入图片描述

参考文献

[1]Distilling the Knowledge in a Neural Network论文笔记
https://zhuanlan.zhihu.com/p/74901192
[2]知识蒸馏是什么?一份入门随笔
https://zhuanlan.zhihu.com/p/90049906get