论文笔记:蒸馏网络(Distilling the Knowledge in Neural Network)

Distilling the Knowledge in Neural Network
Geoffrey Hinton, Oriol Vinyals, Jeff Dean
preprint arXiv:1503.02531, 2015
NIPS 2014 Deep Learning Workshop算法

简单总结

主要工做(What)网络

  1. “蒸馏”(distillation):把大网络的知识压缩成小网络的一种方法
  2. “专用模型”(specialist models):对于一个大网络,能够训练多个专用网络来提高大网络的模型表现

具体作法(How)机器学习

  1. 蒸馏:先训练好一个大网络,在最后的softmax层使用合适的温度参数T,最后训练获得的几率称为“软目标”。以这个软目标和真实标签做为目标,去训练一个比较小的网络,训练的时候也使用在大模型中肯定的温度参数T
  2. 专用模型:对于一个已经训练好的大网络,能够训练一系列的专用模型,每一个专用模型只训练一部分专用的类以及一个“不属于这些专用类的其它类”,好比专用模型1训练的类包括“显示器”,“鼠标”,“键盘”,...,“其它”;专用模型2训练的类包括“玻璃杯”,“保温杯”,“塑料杯”,“其它“。最后以专用模型和大网络的预测输出做为目标,训练一个最终的网络来拟合这个目标。

意义(Why)函数

  1. 蒸馏把大网络压成小网络,这样就能够先在训练阶段花费大精力训练一个大网络,而后在部署阶段以较小的计算代价来产生一个较小的网络,同时保持必定的网络预测表现。
  2. 对于一个已经训练好的大网络,若是要去作集成的话计算开销是很大的,能够在这个基础上训练一系列专用模型,由于这些模型一般比较小,因此训练会快不少,并且有了这些专用模型的输出能够获得一个软目标,实验证实使用软目标训练能够减少过拟合。最后根据这个大网络和一系列专用模型的输出做为目标,训练一个最终的网络,能够获得不错的表现,并且不须要对大网络作大量的集成计算

Abstract

提升机器学习算法表现的一个简单方法就是,训练不一样模型而后对预测结果取平均。
可是要训练多个模型会带来太高的计算复杂度和部署难度。
能够将集成的知识压缩在单一的模型中。
论文使用这种方法在MNIST上作实验,发现取得了不错的效果。
论文还介绍了一种新型的集成,包括一个或多个完整模型和专用模型,可以学习区分完整模型容易混淆的细粒度的类别。学习

1 Introduction

昆虫有幼虫期和成虫期,幼虫期主要行为是吸取营养,成虫期主要行为是生长繁殖。
相似地,大规模机器学习应用能够分为训练阶段和部署阶段,训练阶段不要求实时操做,容许训练一个复杂缓慢的模型,这个模型能够是分别训练多个模型的集成,也能够是单独的一个很大的带有强正则好比dropout的模型。
一旦模型训练好,能够用不一样的训练,这里称为“蒸馏”,去把知识转移到更适合部署的小模型上。测试

复杂模型学习区分大量的类,一般的训练目标是最大化正确答案的平均log几率,这么作有一个反作用就是训练模型同时也会给全部的错误答案分配几率,即便这个几率很小,而有一些几率会比其它的大不少。错误答案的相对几率体现了复杂模型的泛化能力。举个例子,宝马的图像被错认为垃圾箱的几率很低,可是这被个错认为垃圾桶的几率相比于被错认为胡萝卜的几率来讲,是很大的。(能够认为模型不止学到了训练集中的宝马图像特征,还学到了一些别的特征,好比和垃圾桶共有的一些特征,这样就可能捕捉到在新的测试集上的宝马出现这些的特征,这就是泛化能力的体现)google

将复杂模型转为小模型须要保留模型的泛化能力,一个方法就是用复杂模型产生的分类几率做为“软目标”来训练小模型。
当软目标的熵值较高时,相对于硬目标,每一个训练样本提供更多的信息,训练样本之间会有更小的梯度方差。
因此小模型常常能够被训练在小数据集上,并且可使用更高的学习率。ci

像MNIST这种分类任务,复杂模型能够产生很好的表现,大部分信息分布在小几率的软目标中。
为了规避这个问题,Caruana和他的合做者们使用softmax输出前的units值,而不是softmax后的几率,最小化复杂模型和简单模型的units的平方偏差来训练小模型。
而更通用的方法,蒸馏法,先提升softmax的温度参数直到模型能产生合适的软目标。而后在训练小模型匹配软目标的时候使用相同的温度T。部署

被用于训练小模型的转移训练集能够包括未打标签的数据(能够没有原始的实际标签,由于能够经过复杂模型获取一个软目标做为标签),或者使用原始的数据集,使用原始数据集能够获得更好的表现。get

2 Distillation

softmax公式: $ q_{i} = \frac{exp(z_{i}/T)}{\sum_{j}^{ }exp(z_{j}/T)} $
其中温度参数T一般设置为1,T越大能够获得更“软”的几率分布。
T越大,不一样激活值的几率差别越小,全部激活值的几率趋于相同;T越小,不一样激活值的几率差别越大
在蒸馏训练的时候使用较大的T的缘由是,较小的T对于那些远小于平均激活值的单元会给予更少的关注,而这些单元是有用的,使用较高的T可以捕捉这些信息

最简单的蒸馏形式就是,训练小模型的时候,以复杂模型获得的“软目标”为目标,采用复杂模型中的较高的T,训练完以后把T改成1。

当部分或所有转移训练集的正确标签已知时,蒸馏获得的模型会更优。一个方法就是使用正确标签来修改软目标。
可是咱们发现一个更好的方法,简单对两个不一样的目标函数进行权重平均,第一个目标函数是和复杂模型的软目标作一个交叉熵,使用的复杂模型的温度T;第二个目标函数是和正确标签的交叉熵,温度设置为1。咱们发现第二个目标函数被分配一个低权重时一般会取得最好的结果。

3 Preliminary experiments on MNIST

net layers units of each layer activation regularization test errors
single net1 2 1600 relu dropout 67
single net2 2 800 relu no 146

(防止表格黏在一块儿)

net large net small net temperature test errors
distilled net single net1 single net2 20 74

第一个表格中是两个单独的网络,一个大网络和一个小网络。
第二个表格是使用了蒸馏的方法,先训练大网络,而后根据大网络的“软目标”结果和温度T来训练小网络。
能够看到,经过蒸馏的方法将大网络中的知识压缩到小网络中,取得了不错的效果。

4 Experiments on speech recognition

system Test Frame Accuracy Word Error Rate on dev set
baseline 58.9% 10.9%
10XEnsemble 61.1% 10.7%
Distilled model 60.8% 10.7%

其中basline的配置为

  • 8 层,每层2560个relu单元
  • softmax层的单元数为14000
  • 训练样本大小约为 700M,2000个小时的语音文本数据

10XEnsemble是对baseline训练10次(随机初始化为不一样参数)而后取平均

蒸馏模型的配置为

  • 使用的候选温度为{1, 2, 5, 10}, 其中T为2时表现最好
  • hard target 的目标函数给予0.5的相对权重

能够看到,相对于10次集成后的模型表现提高,蒸馏保留了超过80%的效果提高

5 Training ensembles of specialists on very big datasets

训练一个大的集成模型能够利用并行计算来训练,训练完成后把大模型蒸馏成小模型,可是另外一个问题就是,训练自己就要花费大量的时间,这一节介绍的是如何学习专用模型集合,集合中的每一个模型集中于不一样的容易混淆的子类集合,这样能够减少计算需求。专用模型的主要问题是容易集中于区分细粒度特征而致使过拟合,可使用软目标来防止过拟合。

5.1 JFT数据集

JFT是一个谷歌的内部数据集,有1亿的图像,15000个标签。google用一个深度卷积神经网络,训练了将近6个月。
咱们须要更快的方法来提高baseline模型。

5.2 专用模型

将一个复杂模型分为两部分,一部分是一个用于训练全部数据的通用模型,另外一部分是不少个专用模型,每一个专用模型训练的数据集是一个容易混淆的子类集合。这些专用模型的softmax结合全部不关心的类为一类来使模型更小。

为了减小过拟合,共享学习到的低水平特征,每一个专用模型用通用模型的权重进行初始化。另外,专用模型的训练样本一半来自专用子类集合,另外一半从剩余训练集中随机抽取。

5.3 将子类分配到专用模型

专用模型的子类分组集中于容易混淆的那些类别,虽然计算出了混淆矩阵来寻找聚类,可是可使用一种更简单的办法,不须要使用真实标签来构建聚类。对通用模型的预测结果计算协方差,根据协方差把常常一块儿预测的类做为其中一个专用模型的要预测的类别。几个简单的例子以下。

JFT 1: Tea party; Easter; Bridal shower; Baby shower; Easter Bunny; ...
JFT 2: Bridge; Cable-stayed bridge; Suspension bridge; Viaduct; Chimney; ...
JFT 3: Toyota Corolla E100; Opel Signum; Opel Astra; Mazda Familia; ...

5.4 实验表现

system Conditional Test Accuracy Test Accuracy
baseline 43.1% 25.0%
61 specialist models 45.9% 26.1%

6 Soft Targets as Regularizers

对于前面提到过的,对于大量数据训练好的语音baseline模型,用更少的数据去拟合这个模型的时候,使用软目标能够达到更好的效果,减少过拟合。实验结果以下。

system & training set Train Frame Accuracy Test Frame Accuracy
baseline(100% training set) 63.4% 58.9%
baseline(3% training set) 67.3% 44.5%
soft targets(3% training set) 65.4% 57.0%
相关文章
相关标签/搜索