Distilling the Knowledge in Neural Network
Geoffrey Hinton, Oriol Vinyals, Jeff Dean
preprint arXiv:1503.02531, 2015
NIPS 2014 Deep Learning Workshop算法
主要工做(What)网络
具体作法(How)机器学习
意义(Why)函数
提升机器学习算法表现的一个简单方法就是,训练不一样模型而后对预测结果取平均。
可是要训练多个模型会带来太高的计算复杂度和部署难度。
能够将集成的知识压缩在单一的模型中。
论文使用这种方法在MNIST上作实验,发现取得了不错的效果。
论文还介绍了一种新型的集成,包括一个或多个完整模型和专用模型,可以学习区分完整模型容易混淆的细粒度的类别。学习
昆虫有幼虫期和成虫期,幼虫期主要行为是吸取营养,成虫期主要行为是生长繁殖。
相似地,大规模机器学习应用能够分为训练阶段和部署阶段,训练阶段不要求实时操做,容许训练一个复杂缓慢的模型,这个模型能够是分别训练多个模型的集成,也能够是单独的一个很大的带有强正则好比dropout的模型。
一旦模型训练好,能够用不一样的训练,这里称为“蒸馏”,去把知识转移到更适合部署的小模型上。测试
复杂模型学习区分大量的类,一般的训练目标是最大化正确答案的平均log几率,这么作有一个反作用就是训练模型同时也会给全部的错误答案分配几率,即便这个几率很小,而有一些几率会比其它的大不少。错误答案的相对几率体现了复杂模型的泛化能力。举个例子,宝马的图像被错认为垃圾箱的几率很低,可是这被个错认为垃圾桶的几率相比于被错认为胡萝卜的几率来讲,是很大的。(能够认为模型不止学到了训练集中的宝马图像特征,还学到了一些别的特征,好比和垃圾桶共有的一些特征,这样就可能捕捉到在新的测试集上的宝马出现这些的特征,这就是泛化能力的体现)google
将复杂模型转为小模型须要保留模型的泛化能力,一个方法就是用复杂模型产生的分类几率做为“软目标”来训练小模型。
当软目标的熵值较高时,相对于硬目标,每一个训练样本提供更多的信息,训练样本之间会有更小的梯度方差。
因此小模型常常能够被训练在小数据集上,并且可使用更高的学习率。ci
像MNIST这种分类任务,复杂模型能够产生很好的表现,大部分信息分布在小几率的软目标中。
为了规避这个问题,Caruana和他的合做者们使用softmax输出前的units值,而不是softmax后的几率,最小化复杂模型和简单模型的units的平方偏差来训练小模型。
而更通用的方法,蒸馏法,先提升softmax的温度参数直到模型能产生合适的软目标。而后在训练小模型匹配软目标的时候使用相同的温度T。部署
被用于训练小模型的转移训练集能够包括未打标签的数据(能够没有原始的实际标签,由于能够经过复杂模型获取一个软目标做为标签),或者使用原始的数据集,使用原始数据集能够获得更好的表现。get
softmax公式: $ q_{i} = \frac{exp(z_{i}/T)}{\sum_{j}^{ }exp(z_{j}/T)} $
其中温度参数T一般设置为1,T越大能够获得更“软”的几率分布。
(T越大,不一样激活值的几率差别越小,全部激活值的几率趋于相同;T越小,不一样激活值的几率差别越大)
(在蒸馏训练的时候使用较大的T的缘由是,较小的T对于那些远小于平均激活值的单元会给予更少的关注,而这些单元是有用的,使用较高的T可以捕捉这些信息)
最简单的蒸馏形式就是,训练小模型的时候,以复杂模型获得的“软目标”为目标,采用复杂模型中的较高的T,训练完以后把T改成1。
当部分或所有转移训练集的正确标签已知时,蒸馏获得的模型会更优。一个方法就是使用正确标签来修改软目标。
可是咱们发现一个更好的方法,简单对两个不一样的目标函数进行权重平均,第一个目标函数是和复杂模型的软目标作一个交叉熵,使用的复杂模型的温度T;第二个目标函数是和正确标签的交叉熵,温度设置为1。咱们发现第二个目标函数被分配一个低权重时一般会取得最好的结果。
net | layers | units of each layer | activation | regularization | test errors |
---|---|---|---|---|---|
single net1 | 2 | 1600 | relu | dropout | 67 |
single net2 | 2 | 800 | relu | no | 146 |
(防止表格黏在一块儿)
net | large net | small net | temperature | test errors |
---|---|---|---|---|
distilled net | single net1 | single net2 | 20 | 74 |
(第一个表格中是两个单独的网络,一个大网络和一个小网络。)
(第二个表格是使用了蒸馏的方法,先训练大网络,而后根据大网络的“软目标”结果和温度T来训练小网络。)
(能够看到,经过蒸馏的方法将大网络中的知识压缩到小网络中,取得了不错的效果。)
system | Test Frame Accuracy | Word Error Rate on dev set |
---|---|---|
baseline | 58.9% | 10.9% |
10XEnsemble | 61.1% | 10.7% |
Distilled model | 60.8% | 10.7% |
其中basline的配置为
10XEnsemble是对baseline训练10次(随机初始化为不一样参数)而后取平均
蒸馏模型的配置为
能够看到,相对于10次集成后的模型表现提高,蒸馏保留了超过80%的效果提高
训练一个大的集成模型能够利用并行计算来训练,训练完成后把大模型蒸馏成小模型,可是另外一个问题就是,训练自己就要花费大量的时间,这一节介绍的是如何学习专用模型集合,集合中的每一个模型集中于不一样的容易混淆的子类集合,这样能够减少计算需求。专用模型的主要问题是容易集中于区分细粒度特征而致使过拟合,可使用软目标来防止过拟合。
JFT是一个谷歌的内部数据集,有1亿的图像,15000个标签。google用一个深度卷积神经网络,训练了将近6个月。
咱们须要更快的方法来提高baseline模型。
将一个复杂模型分为两部分,一部分是一个用于训练全部数据的通用模型,另外一部分是不少个专用模型,每一个专用模型训练的数据集是一个容易混淆的子类集合。这些专用模型的softmax结合全部不关心的类为一类来使模型更小。
为了减小过拟合,共享学习到的低水平特征,每一个专用模型用通用模型的权重进行初始化。另外,专用模型的训练样本一半来自专用子类集合,另外一半从剩余训练集中随机抽取。
专用模型的子类分组集中于容易混淆的那些类别,虽然计算出了混淆矩阵来寻找聚类,可是可使用一种更简单的办法,不须要使用真实标签来构建聚类。对通用模型的预测结果计算协方差,根据协方差把常常一块儿预测的类做为其中一个专用模型的要预测的类别。几个简单的例子以下。
JFT 1: Tea party; Easter; Bridal shower; Baby shower; Easter Bunny; ...
JFT 2: Bridge; Cable-stayed bridge; Suspension bridge; Viaduct; Chimney; ...
JFT 3: Toyota Corolla E100; Opel Signum; Opel Astra; Mazda Familia; ...
system | Conditional Test Accuracy | Test Accuracy |
---|---|---|
baseline | 43.1% | 25.0% |
61 specialist models | 45.9% | 26.1% |
对于前面提到过的,对于大量数据训练好的语音baseline模型,用更少的数据去拟合这个模型的时候,使用软目标能够达到更好的效果,减少过拟合。实验结果以下。
system & training set | Train Frame Accuracy | Test Frame Accuracy |
---|---|---|
baseline(100% training set) | 63.4% | 58.9% |
baseline(3% training set) | 67.3% | 44.5% |
soft targets(3% training set) | 65.4% | 57.0% |