【推荐系统】知识蒸馏概述

一、 知识蒸馏是什么

知识蒸馏主要处理的是模型的有效性和效率之间的平衡问题:

模型越来越深、越来越复杂,导致模型上线后相应速度太慢,无法满足系统的低延迟要求。
在这里插入图片描述
知识蒸馏就是目前一种比较流行的解决此类问题的技术方向。
一般为teacher-student模式,主要思想是用一个复杂的、较大的teacher model去指导简单的、较小的student model的学习。
线上使用的是student小模型。
在这里插入图片描述

二、Distilling the Knowledge in a Neural Network

论文地址:https://arxiv.org/pdf/1503.02531.pdf

Knowledge distillation最早来自于hinton 2015年的一篇论文,在文中hinton提到:可以将一个大的、复杂的或者ensemble的模型获得知识transfer压缩到一个单个小模型中。

其主要思想是将训练好的teacher model输出的class probability作为soft target,让student model去学习:

  • An obvious way to transfer the generalization ability of the cumbersome model to a small model is to use the class probabilities produced by the cumbersome model as “soft targets” for training the small model.

文中引入了 soft target 这一概念:
q i = e x p ( z i / T ) j e x p ( z j / T ) q_i=\frac{exp(z_i/T)}{\sum_jexp(z_j/T)}

  • T T 是超参: Temperature
  • T = 1 T=1 时,就是平常使用的 s o f t m a x softmax
  • T T 越大 , 类别之间的概率分布越 softer(减小不同类别归属概率的两极分化程度)
    • T = 1 , [ 3 , 10 ] [ 0.0009 , 0.9991 ] T=1, [3,10] \rightarrow [0.0009,0.9991]
    • T = 10 , [ 3 , 10 ] [ 0.33 , 0.67 ] T=10, [3,10] \rightarrow [0.33,0.67]
    • T = 100 , [ 3 , 10 ] [ 0.48 , 0.52 ] T=100, [3,10] \rightarrow [0.48,0.52]

那么为什么要让student去学习这个soft target呢?

  • One of our main claims about using soft targets instead of hard targets is that a lot of helpful information can be carried in soft targets that could not possibly be encoded with a single hard target.

因为 soft target 会包含更多的信息:

  • 比如对于一个分类问题,区分三个类: 1, 2, 3
  • Hard Target: [0, 0, 1]
    • 信息量低
    • 只有在3处的index为1,其余都为0
  • Soft Target: [0.001, 0.149, 0.85]
    • 信息量大,拥有不同类之间的关系信息,
    • 在非3的部分也有概率,表征类与类之间的相似性
  • Temperature T的作用:
    • 交叉熵: p t l o g ( q s ) -p_tlog(q_s)
    • T = 5 [ 0.001 , 0.149 , 0.85 ] [ 0.31 , 0.32 , 0.37 ] T=5:[0.001, 0.149, 0.85] \rightarrow [0.31, 0.32, 0.37] ,soften之后可以使概率较小的第一个位置在交叉熵损失函数中也有贡献
    • 使类与类之间的关联信息更明显
  • soft target的作用类似于pretrain,大模型帮忙提取类标空间的更多关联信息,小模型在训练前获得prior knowlege,降低了神经网络的搜索空间,从而获得了泛化(generalization)能力

下图是文中比较的soft target和hard-target的实验结果:

第一行是使用全部数据+hart target训练的baseline,第二行为使用3%数据+hart target的训练和测试accuracy,第三行为使用3%数据+soft target的训练和测试accuracy。

可以发现,仅使用3%的数据+soft target就可以达到和baseline相当的表现。并且,文中提到第二种方式很容易过拟合,必须使用early-stopping,第三种方式不需要使用ealy-stopping,可见 soft target有regularizer的作用。
在这里插入图片描述
模型的训练和测试
在这里插入图片描述
训练分为两个阶段( H H 为交叉熵):

  • 先使用hard target训练teacher model, L o s s t e a c h e r = H ( y , f t ( x ) ) Loss_{teacher}=H(y,f_t(x))
  • 利用训练好的 teacher model 生成 Temperature = T的 soft targets
  • 训练student model:
    L o s s s t u d e n t = H ( y , f s ( x ) ) + λ H ( s o f t m a x ( z t T ) , s o f t m a x ( z s T ) ) Loss_{student}=H(y,f_s(x)) + \lambda H(softmax(\frac{z_t}{T}),softmax(\frac{z_s}{T}))
    文中提到 λ \lambda 应该大一点,因为可以推导出第二项loss相较于直接去算logits的MSE,做了 1 / T 2 1/T^2 的scale

三、Rocket launching

Rocket Launching: A Universal and Efficient Framework for Training Well-performing Light Net

阿里的rocket launching在KD的基础上做了一些改进,模型结构如下:

在这里插入图片描述
训练的loss为:
L ( x ; W S , W L , W B ) = H ( y , p ( x ) ) + H ( y , q ( x ) ) + λ L o s s h i n t L(x;W_S,W_L,W_B) = H(y, p(x)) + H(y, q(x))+\lambda Loss_{hint}
其中, L o s s h i n t Loss_{hint} 可以有如下几种形式:
Hint loss:

  • MSE of final softmax: L M S E ( x ) = p ( x ) q ( x ) 2 2 L_{MSE}(x)=||p(x)-q(x)||_2^2

  • MSE of logits before softmax activation: L M I M I C ( x ) = l ( x ) z ( x ) 2 2 L_{MIMIC}(x)=||l(x)-z(x)||_2^2

  • knowledge distillation: L K D ( x ) = H ( s o f t m a x ( l ( x ) T ) , s o f t m a x ( z ( x ) / T ) ) L_{KD}(x)=H(softmax(\frac{l(x)}{T}), softmax(\frac{z(x)}{/T}))

Rocket launching主要有如下几处改进:

  • 参数共享(Parameters sharing)
    student和teacher共享底层参数 W S W_S ,使得student model更好地训练
  • Teacher model和student model联合训练(Simultaneous training)
    student和teacher model同时训练,相较于两阶段的KD,可以极大缩短训练时间
  • 梯度阻隔(Gradient block)
    hint loss,也就是为了让student model学习teacher model的soft target所引入的loss,其梯度不会影响 W B W_B ,也就是teacher model自己参数的更新,因为:
    • 如果用hint loss的梯度去更新teacher model自己的参数的话,会使得teacher model的学习受到student model的极大影响,使得teacher model没有办法去直接学习任务本身
    • 而student model本身由于很简单,学习能力有限,因此会阻碍teacher model的学习
    • 进一步地,student model又会去学习teacher model输出的soft target, teacher model的学习被破坏,student model的学习也学不好
    • 因此,如下图所示,在训练过程中,hint loss并不会影响 W B W_B 的更新:
      在这里插入图片描述

自己用rocket launching结构,测试的cvr预测任务的测试auc如下(hint loss lambda选的 T 2 T^2 ):增大T,确实能够使得student的测试表现有所提升。

在这里插入图片描述

四、FITNETS: HINTS FOR THIN DEEP NETS

论文地址:https://arxiv.org/pdf/1412.6550.pdf

核心思想是利用teacher model中间层输出指导student model中间层的输出,获得一个thin 但是 deeper 的student model,因为一般深层的神经网络表达力更强,可以获得更加抽象的特征表征。

  • All previous work focuses on compressing a teacher network or an ensemble of networks into either networks of similar width and depth or into shallower and wider ones; not taking advantage of depth

  • allow the training of a student that is deeper and thinner than the teacher, using not only the outputs but also the intermediate representations learned by the teacher as hints to improve the training process and final performance of the student.
    在这里插入图片描述
    Hint layer和guided layer的定义:

  • A hint is defined as the output of a teacher’s hidden layer responsible for guiding the student’s learning process.

  • Analogously, we choose a hidden layer of the FitNet, the guided layer, to learn from the teacher’s hint layer.

  • We want the guided layer to be able to predict the output of the hint layer.

注意在student训练过程汇中引入hint是一种正则化的方式,hint/guided layer选层数越深,student网络训练的灵活性就越少,就会容易over-regularization。文中的hint/guided layer选的都是网络的中间层。

训练分为三个阶段:
在这里插入图片描述
在这里插入图片描述

  1. 训练teacher network, finet随机初始化
  2. hints training: 选取teacher的中间层作为guidance,对student的中间层进行监督学习,通常两者的维度不一样,所以需要一个额外的线性矩阵或卷积层 W r W_r 去进行维度变换,达到维度一致,然后使用L2 Loss进行监督学习(训练的是到guided layer为止的参数 W g u i d e d W_{guided} ):
    在这里插入图片描述
  3. 利用knowledge distillation去训练student的整个网络参数 W S W_S

五、知识蒸馏在推荐系统中的应用

该部分详细可见张俊林博士的文章: https://zhuanlan.zhihu.com/p/143155437

精排、粗排以及模型召回环节都可以采用知识蒸馏技术来优化现有推荐系统的性能和效果。

5.1 精排环节的知识蒸馏

对于精排,知识蒸馏适用于如下两种技术转换场景:

  • 排序模型从非DNN模型初次向DNN模型进行模型升级
    • 小的student model,对于非DNN可能效果也是有提升的
    • 在线服务速度快,可以明显降低模型升级的成本
  • 尽管线上已经采用了DNN 排序模型,但是模型还非常简单
    • 可以离线训练一个复杂但是效果明显优于线上简单DNN排序模块的模型作为Teacher
    • 可能在响应速度不降的前提下,在模型效果上有所提升
  • 线上模型已经很复杂,需要权衡:推荐质量和响应速度?

在这里插入图片描述

5.2 召回/粗排环节的知识蒸馏

知识蒸馏应用在召回/粗排环节是比较“合算”的,因为这两个环节本身,并不追求最高的推荐精度,而模型小,速度快则是模型召回及粗排的重要目标之一,这与知识蒸馏的特点正好相符合。

召回/粗排环节的知识蒸馏可以使用两阶段的方式训练,因为teacher model可以直接用精排环节已经训练好的模型。

此外,student model可以去学习模型的logits,也可以去学习精排模型的排序偏好。

联合训练召回、粗排及精排模型的设想
在这里插入图片描述

参考资料

  1. 如何理解soft target这一做法:https://www.zhihu.com/question/50519680/answer/136406661
  2. 知识蒸馏在推荐系统的应用: https://zhuanlan.zhihu.com/p/143155437