点击上方“AI公园”,关注公众号,选择加“星标“或“置顶”git
做者:Sayak Paulweb
编译:ronghuaiyang
微信
从各个层次给你们讲解模型的知识蒸馏的相关内容,并经过实际的代码给你们进行演示。网络
公众号后台回复“模型蒸馏”,下载已打包好的代码。app
本报告讨论了很是厉害模型优化技术 —— 知识蒸馏,并给你们过了一遍相关的TensorFlow的代码。
编辑器
“模型集成是一个至关有保证的方法,能够得到2%的准确性。“ —— Andrej Karpathy分布式
我绝对赞成!然而,部署重量级模型的集成在许多状况下并不老是可行的。有时,你的单个模型可能太大(例如GPT-3),以致于一般不可能将其部署到资源受限的环境中。这就是为何咱们一直在研究一些模型优化方法 ——量化和剪枝。在这个报告中,咱们将讨论一个很是厉害的模型优化技术 —— 知识蒸馏。函数
Softmax告诉了咱们什么?
当处理一个分类问题时,使用softmax做为神经网络的最后一个激活单元是很是典型的用法。这是为何呢?由于softmax函数接受一组logit为输入并输出离散类别上的几率分布。好比,手写数字识别中,神经网络可能有较高的置信度认为图像为1。不过,也有轻微的可能性认为图像为7。若是咱们只处理像[1,0]这样的独热编码标签(其中1和0分别是图像为1和7的几率),那么这些信息就没法得到。性能
人类已经很好地利用了这种相对关系。更多的例子包括,长得很像猫的狗,棕红色的,猫同样的老虎等等。正如Hinton等人所认为的学习
一辆宝马被误认为是一辆垃圾车的可能性很小,但被误认为是一个胡萝卜的可能性仍然要高不少倍。
这些知识能够帮助咱们在各类状况下进行极好的归纳。这个思考过程帮助咱们更深刻地了解咱们的模型对输入数据的想法。它应该与咱们考虑输入数据的方式一致。
因此,如今该作什么?一个迫在眉睫的问题可能会忽然出如今咱们的脑海中 —— 咱们在神经网络中使用这些知识的最佳方式是什么?让咱们在下一节中找出答案。
使用Softmax的信息来教学 —— 知识蒸馏
softmax信息比独热编码标签更有用。在这个阶段,咱们能够获得:
-
训练数据 -
训练好的神经网络在测试数据上表现良好
咱们如今感兴趣的是使用咱们训练过的网络产生的输出几率。
考虑教人去认识MNIST数据集的英文数字。你的学生可能会问 —— 那个看起来像7吗?若是是这样的话,这绝对是个好消息,由于你的学生,确定知道1和7是什么样子。做为一名教师,你可以把你的数字知识传授给你的学生。这种想法也有可能扩展到神经网络。
知识蒸馏的高层机制
因此,这是一个高层次的方法:
-
训练一个在数据集上表现良好神经网络。这个网络就是“教师”模型。 -
使用教师模型在相同的数据集上训练一个学生模型。这里的问题是,学生模型的大小应该比老师的小得多。
本工做流程简要阐述了知识蒸馏的思想。
为何要小?这不是咱们想要的吗?将一个轻量级模型部署到生产环境中,从而达到足够的性能。
用图像分类的例子来学习
对于一个图像分类的例子,咱们能够扩展前面的高层思想:
-
训练一个在图像数据集上表现良好的教师模型。在这里,交叉熵损失将根据数据集中的真实标签计算。 -
在相同的数据集上训练一个较小的学生模型,可是使用来自教师模型(softmax输出)的预测做为ground-truth标签。这些softmax输出称为软标签。稍后会有更详细的介绍。
咱们为何要用软标签来训练学生模型?
请记住,在容量方面,咱们的学生模型比教师模型要小。所以,若是你的数据集足够复杂,那么较小的student模型可能不太适合捕捉训练目标所需的隐藏表示。咱们在软标签上训练学生模型来弥补这一点,它提供了比独热编码标签更有意义的信息。在某种意义上,咱们经过暴露一些训练数据集来训练学生模型来模仿教师模型的输出。
但愿这能让大家对知识蒸馏有一个直观的理解。在下一节中,咱们将更详细地了解学生模型的训练机制。
知识蒸馏中的损失函数
为了训练学生模型,咱们仍然可使用教师模型的软标签以及学生模型的预测来计算常规交叉熵损失。学生模型颇有可能对许多输入数据点都有信心,而且它会预测出像下面这样的几率分布:
扩展Softmax
这些弱几率的问题是,它们没有捕捉到学生模型有效学习所需的信息。例如,若是几率分布像[0.99, 0.01]
,几乎不可能传递图像具备数字7的特征的知识。
Hinton等人解决这个问题的方法是,在将原始logits传递给softmax以前,将教师模型的原始logits按必定的温度进行缩放。这样,就会在可用的类标签中获得更普遍的分布。而后用一样的温度用于训练学生模型。
咱们能够把学生模型的修正损失函数写成这个方程的形式:
其中,pi是教师模型获得软几率分布,si的表达式为:
def get_kd_loss(student_logits, teacher_logits,
true_labels, temperature,
alpha, beta):
teacher_probs = tf.nn.softmax(teacher_logits / temperature)
kd_loss = tf.keras.losses.categorical_crossentropy(
teacher_probs, student_logits / temperature,
from_logits=True)
return kd_loss
使用扩展Softmax来合并硬标签
Hinton等人还探索了在真实标签(一般是独热编码)和学生模型的预测之间使用传统交叉熵损失的想法。当训练数据集很小,而且软标签没有足够的信号供学生模型采集时,这一点尤为有用。
当它与扩展的softmax相结合时,这种方法的工做效果明显更好,而总体损失函数成为二者之间的加权平均。
def get_kd_loss(student_logits, teacher_logits,
true_labels, temperature,
alpha, beta):
teacher_probs = tf.nn.softmax(teacher_logits / temperature)
kd_loss = tf.keras.losses.categorical_crossentropy(
teacher_probs, student_logits / temperature,
from_logits=True)
ce_loss = tf.keras.losses.sparse_categorical_crossentropy(
true_labels, student_logits, from_logits=True)
total_loss = (alpha * kd_loss) + (beta * ce_loss)
return total_loss / (alpha + beta)
建议β的权重小于α。
在原始Logits上进行操做
Caruana等人操做原始logits,而不是softmax值。这个工做流程以下:
-
这部分保持相同 —— 训练一个教师模型。这里交叉熵损失将根据数据集中的真实标签计算。 -
如今,为了训练学生模型,训练目标变成分别最小化来自教师和学生模型的原始对数之间的平均平方偏差。
mse = tf.keras.losses.MeanSquaredError()
def mse_kd_loss(teacher_logits, student_logits):
return mse(teacher_logits, student_logits)
使用这个损失函数的一个潜在缺点是它是无界的。原始logits能够捕获噪声,而一个小模型可能没法很好的拟合。这就是为何为了使这个损失函数很好地适合蒸馏状态,学生模型须要更大一点。
Tang等人探索了在两个损失之间插值的想法:扩展softmax和MSE损失。数学上,它看起来是这样的:
根据经验,他们发现当α = 0时,(在NLP任务上)能够得到最佳的性能。
若是你在这一点上感到有点不知怎么办,不要担忧。但愿经过代码,事情会变得清楚。
一些训练方法
在本节中,我将向你提供一些在使用知识蒸馏时能够考虑的训练方法。
使用数据加强
他们在NLP数据集上展现了这个想法,但这也适用于其余领域。为了更好地指导学生模型训练,使用数据加强会有帮助,特别是当你处理的数据较少的时候。由于咱们一般保持学生模型比教师模型小得多,因此咱们但愿学生模型可以得到更多不一样的数据,从而更好地捕捉领域知识。
使用标记的和未标记的数据训练学生模型
在像Noisy Student Training和SimCLRV2这样的文章中,做者在训练学生模型时使用了额外的未标记数据。所以,你将使用你的teacher模型来生成未标记数据集上的ground-truth分布。这在很大程度上有助于提升模型的可泛化性。这种方法只有在你所处理的数据集中有未标记数据可用时才可行。有时,状况可能并不是如此(例如,医疗保健)。Xie等人探索了数据平衡和数据过滤等技术,以缓解在训练学生模型时合并未标记数据可能出现的问题。
在训练教师模型时不要使用标签平滑
标签平滑是一种技术,用来放松由模型产生的高可信度预测。它有助于减小过拟合,但不建议在训练教师模型时使用标签平滑,由于不管如何,它的logits是按必定的温度缩放的。所以,通常不推荐在知识蒸馏的状况下使用标签平滑。
使用更高的温度值
Hinton等人建议使用更高的温度值来soften教师模型预测的分布,这样软标签能够为学生模型提供更多的信息。这在处理小型数据集时特别有用。对于更大的数据集,信息能够经过训练样本的数量来得到。
实验结果
让咱们先回顾一下实验设置。我在实验中使用了Flowers数据集。除非另外指定,我使用如下配置:
-
我使用MobileNetV2做为基本模型进行微调,学习速度设置为 1e-5
,Adam做为优化器。 -
咱们将τ设置为5。 -
α = 0.9,β = 0.1。 -
对于学生模型,使用下面这个简单的结构:
Layer (type) Output Shape Param #
=================================================================
conv2d (Conv2D) (None, 222, 222, 64) 1792
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 55, 55, 64) 0
_________________________________________________________________
conv2d_1 (Conv2D) (None, 53, 53, 128) 73856
_________________________________________________________________
global_average_pooling2d_3 ( (None, 128) 0
_________________________________________________________________
dense_3 (Dense) (None, 512) 66048
_________________________________________________________________
dense_4 (Dense) (None, 5) 2565
=================================================================
-
在训练学生模型时,我使用Adam做为优化器,学习速度为 1e-2
。 -
在使用数据加强训练student模型的过程当中,我使用了与上面提到的相同的默认超参数的加权平均损失。
学生模型基线
为了使性能比较公平,咱们还从头开始训练浅的CNN并观察它的性能。注意,在本例中,我使用Adam做为优化器,学习速率为1e-3
。
训练循环
在看到结果以前,我想说明一下训练循环,以及如何在经典的model.fit()
调用中包装它。这就是训练循环的样子:
def train_step(self, data):
images, labels = data
teacher_logits = self.trained_teacher(images)
with tf.GradientTape() as tape:
student_logits = self.student(images)
loss = get_kd_loss(teacher_logits, student_logits)
gradients = tape.gradient(loss, self.student.trainable_variables)
self.optimizer.apply_gradients(zip(gradients, self.student.trainable_variables))
train_loss.update_state(loss)
train_acc.update_state(labels, tf.nn.softmax(student_logits))
t_loss, t_acc = train_loss.result(), train_acc.result()
train_loss.reset_states(), train_acc.reset_states()
return {"loss": t_loss, "accuracy": t_acc}
若是你已经熟悉了如何在TensorFlow 2中定制一个训练循环,那么train_step()函数应该是一个容易阅读的函数。注意get_kd_loss()
函数。这能够是咱们以前讨论过的任何损失函数。咱们在这里使用的是一个训练过的教师模型,这个模型咱们在前面进行了微调。经过这个训练循环,咱们能够建立一个能够经过.fit()
调用进行训练完整模型。
首先,建立一个扩展tf.keras.Model
的类。
class Student(tf.keras.Model):
def __init__(self, trained_teacher, student):
super(Student, self).__init__()
self.trained_teacher = trained_teacher
self.student = student
当你扩展tf.keras.Model
类的时候,能够将自定义的训练逻辑放到train_step()
函数中(由类提供)。因此,从总体上看,Student类应该是这样的:
class Student(tf.keras.Model):
def __init__(self, trained_teacher, student):
super(Student, self).__init__()
self.trained_teacher = trained_teacher
self.student = student
def train_step(self, data):
images, labels = data
teacher_logits = self.trained_teacher(images)
with tf.GradientTape() as tape:
student_logits = self.student(images)
loss = get_kd_loss(teacher_logits, student_logits)
gradients = tape.gradient(loss, self.student.trainable_variables)
self.optimizer.apply_gradients(zip(gradients, self.student.trainable_variables))
train_loss.update_state(loss)
train_acc.update_state(labels, tf.nn.softmax(student_logits))
t_loss, t_acc = train_loss.result(), train_acc.result()
train_loss.reset_states(), train_acc.reset_states()
return {"train_loss": t_loss, "train_accuracy": t_acc}
你甚至能够编写一个test_step
来自定义模型的评估行为。咱们的模型如今能够用如下方式训练:
student = Student(teacher_model, get_student_model())
optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)
student.compile(optimizer)
student.fit(train_ds,
validation_data=validation_ds,
epochs=10)
这种方法的一个潜在优点是能够很容易地合并其余功能,好比分布式训练、自定义回调、混合精度等等。
使用
训练学生模型
用这个损失函数训练咱们的浅层学生模型,咱们获得~74%的验证精度。咱们看到,在epochs 8以后,损失开始增长。这代表,增强正则化可能会有所帮助。另外,请注意,超参数调优过程在这里有重大影响。在个人实验中,我没有作严格的超参数调优。为了更快地进行实验,我缩短了训练时间。
使用
训练学生模型
如今让咱们看看在蒸馏训练目标中加入ground truth标签是否有帮助。在β = 0.1和α = 0.1的状况下,咱们获得了大约71%的验证准确性。再次代表,更强的正则化和更长的训练时间会有所帮助。
使用
训练学生模型
使用了MSE的损失,咱们能够看到验证精度大幅降低到~56%。一样的损失也出现了相似的状况,这代表须要进行正则化。
请注意,这个损失函数是无界的,咱们的浅学生模型可能没法处理随之而来的噪音。让咱们尝试一个更深刻的学生模型。
在训练学生模型的时候使用数据加强
如前所述,学生模式比教师模式的容量更小。在处理较少的数据时,数据加强能够帮助训练学生模型。咱们验证一下。
数据增长的好处是很是明显的:
-
咱们有一个更好的损失曲线。 -
验证精度提升到84%。
温度(τ)的影响
在这个实验中,咱们研究温度对学生模型的影响。在这个设置中,我使用了相同的浅层CNN。
从上面的结果能够看出,当τ为1时,训练损失和训练精度均优于其它方法。对于验证损失,咱们能够看到相似的行为,可是在全部不一样的温度下,验证的准确性彷佛几乎是相同的。
最后,我想研究下微调基线模是否对学生模型有显著影响。
基线模型调优的效果
在此次实验中,我选择了 EfficientNet B0做为基础模型。让咱们先来看看我用它获得的微调结果。注意,如前所述,全部其余超参数都保持其默认值。
咱们在微调步骤中没有看到任何显著的改进。我想再次强调,我没有进行严格的超参数调优实验。基于我从EfficientNet B0获得的边际改进,我决定在之后的某个时间点进行进一步的实验。
第一行对应的是用加权平均损失训练的默认student model,其余行分别对应EfficientNet B0和MobileNetV2。注意,我没有包括在训练student模型时经过使用数据加强而获得的结果。
知识蒸馏的一个好处是,它与其余模型优化技术(如量化和修剪)无缝集成。因此,做为一个有趣的实验,我鼓励大家本身尝试一下。
总结
知识蒸馏是一种很是有前途的技术,特别适合于用于部署的目的。它的一个优势是,它能够与量化和剪枝很是无缝地结合在一块儿,从而在不影响精度的前提下进一步减少生产模型的尺寸。

英文原文:https://wandb.ai/authors/knowledge-distillation/reports/Distilling-Knowledge-in-Neural-Networks--VmlldzoyMjkxODk
请长按或扫描二维码关注本公众号
喜欢的话,请给我个好看吧!
本文分享自微信公众号 - AI公园(AI_Paradise)。
若有侵权,请联系 support@oschina.cn 删除。
本文参与“OSC源创计划”,欢迎正在阅读的你也加入,一块儿分享。