AWD-LSTM为何这么棒?

摘要: AWD-LSTM为何这么棒,看完你就明白啦!

AWD-LSTM是目前最优秀的语言模型之一。在众多的顶会论文中,对字级模型的研究都采用了AWD-LSTMs,而且它在字符级模型中的表现也一样出色。算法

本文回顾了论文——Regularizing and Optimizing LSTM Language Models ,在介绍AWD-LSTM模型的同时并解释其中所涉及的各项策略。该论文提出了一系列基于词的语言模型的正则化和优化策略。这些策略不只行之有效,并且可以在不改变现有LSTM模型的基础上使用。网络

AWD-LSTM即ASGD Weight-Dropped LSTM。它使用了DropConnect及平均随机梯度降低的方法,除此以外还有包含一些其它的正则化策略。咱们将在后文详细讲解这些策略。本文将着重于介绍它们在语言模型中的成功应用。函数

实验代码获取:awd-lstm-lm GitHub repository性能

LSTM中的数学公式:学习

it = σ(Wixt + Uiht-1)优化

ft = σ(Wfxt + Ufht-1)spa

ot = σ(Woxt + Uoht-1)3d

c’t = tanh(Wcxt + Ucht-1)blog

ct = it ⊙ c’t + ft ⊙ c’t-1ip

ht = ot ⊙ tanh(ct)

其中, Wi, Wf, Wo, Wc, Ui, Uf, Uo, Uc都是权重矩阵,xt表示输入向量,ht表示隐藏单元向量,ct表示单元状态向量, ⊙表示element-wise乘法。

接下来咱们将逐一介绍做者提出的策略:

权重降低的LSTM

RNN的循环链接容易致使过拟合问题,如何解决这一问题也成了一个较为热门的研究领域。Dropouts的引入在前馈神经网络和卷积网络中取得了巨大的成功。但将Dropouts引入到RNN中却反响甚微,这是因为Dropouts的加入破坏了RNN长期依赖的能力。

研究学者们就此提出了许多解决方案,可是这些方法要么做用于隐藏状态向量ht-1,要么是对单元状态向量ct进行更新。上述操做可以解决高度优化的“黑盒”RNN,例如NVIDIA’s cuDNN LSTM中的过拟合问题。

但仅如此是不够的,为了更好的解决这个问题,研究学者们引入了DropConnect。DropConnect是在神经网络中对全链接层进行规范化处理。Dropout是指在模型训练时随机的将隐层节点的权重变成0,暂时认为这些节点不是网络结构的一部分,可是会把它们的权重保留下来。与Dropout不一样的是DropConnect在训练神经网络模型过程当中,并不随机的将隐层节点的输出变成0,而是将节点中的每一个与其相连的输入权值以1-p的几率变成0。

clipboard.png

DropConnect做用在hidden-to-hidden权重矩阵(Ui、Uf、Uo、Uc)上。在前向和后向遍历以前,只执行一次dropout操做,这对训练速度的影响较小,能够用于任何标准优化的“黑盒”RNN中。经过对hidden-to-hidden权重矩阵进行dropout操做,能够避免LSTM循环链接中的过分拟合问题。

你能够在 awd-lstm-lm 中找到weight_drop.py 模块用于实现。

做者表示,尽管DropConnect是经过做用在hidden-to-hidden权重矩阵以防止过拟合问题,但它也能够做用于LSTM的非循环权重。

使用非单调条件来肯定平均触发器

研究发现,对于特定的语言建模任务,传统的不带动量的SGD算法优于带动量的SGD、Adam、Adagrad及RMSProp等算法。所以,做者基于传统的SGD算法提出了ASGD(Average SGD)算法。

Average SGD

ASGD算法采用了与SGD算法相同的梯度更新步骤,不一样的是,ASGD没有返回当前迭代中计算出的权值,而是考虑的这一步和前一次迭代的平均值。

传统的SGD梯度更新:

clipboard.png

AGSD梯度更新:

clipboard.png

其中,k是在加权平均开始以前运行的最小迭代次数。在k次迭代开始以前,ASGD与传统的SGD相似。t是当前完成的迭代次数,sum(w_prevs)是迭代k到t的权重之和,lr_t是迭代次数t的学习效率,由学习率调度器决定。

你能够在这里找到AGSD的PyTorch实现。

但做者也强调,该方法有以下两个缺点:

• 学习率调度器的调优方案不明确

• 如何选取合适的迭代次数k。值过小会对方法的有效性产生负面影响,值太大可能须要额外的迭代才能收敛。

基于此,做者在论文中提出了使用非单调条件来肯定平均触发器,即NT-ASGD,其中:

• 当验证度量不能改善多个循环时,就会触发平均值。这是由非单调区间的超参数n保证的。所以,每当验证度量没有在n个周期内获得改进时,就会使用到ASGD算法。经过实验发现,当n=5的时候效果最好。

• 整个实验中使用一个恒定的学习速率,不须要进一步的调整。

正则化方法

除了上述说起的两种方法外,做者还使用了一些其它的正则化方法防止过拟合问题及提升数据效率。

长度可变的反向传播序列

做者指出,使用固定长度的基于时间的反向传播算法(BPTT)效率较低。试想,在一个时间窗口大小固定为10的BPTT算法中,有100个元素要进行反向传播操做。在这种状况下,任何能够被10整除的元素都不会有能够反向支撑的元素。这致使了1/10的数据没法以循环的方式进行自我改进,8/10的数据只能使用到部分的BPTT窗口。

为了解决这个问题,做者提出了使用可变长度的反向传播序列。首先选取长度为bptt的序列,几率为p以及长度为bptt/2的序列,几率为1-p。在PyTorch中,做者将p设为0.95。

clipboard.png

其中,base_bptt用于获取seq_len,即序列长度,在N(base_bptt, s)中,s表示标准差,N表示服从正态分布。代码以下:

clipboard.png

学习率会根据seq_length进行调整。因为当学习速率固定时,会更倾向于对段序列而非长序列进行采样,因此须要进行缩放。

clipboard.png

Variational Dropout

在标准的Dropout中,每次调用dropout链接时都会采样到一个新的dropout mask。而在Variational Dropout中,dropout mask在第一次调用时只采样一次,而后locked dropout mask将重复用于前向和后向传播中的全部链接。

虽然使用了DropConnect而非Variational Dropout以规范RNN中hidden-to-hidden的转换,可是对于其它的dropout操做均使用的Variational Dropout,特别是在特定的前向和后向传播中,对LSTM的全部输入和输出使用相同的dropout mask。

点击查看官方awd-lstm-lm GitHub存储库的Variational dropout实现。详情请参阅原文。

Embedding Dropout

论文中所提到的Embedding Dropout首次出如今——《A Theoretically Grounded Application of Dropout in Recurrent Neural Networks》一文中。该方法是指将dropout做用于嵌入矩阵中,且贯穿整个前向和反向传播过程。在该过程当中出现的全部特定单词均会消失。

Weight Tying(权重绑定)

权重绑定共享嵌入层和softmax层之间的权重,可以减小模型中大量的参数。

Reduction in Embedding Size

对于语言模型来讲,想要减小总参数的数量,最简单的方法是下降词向量的维数。即便这样没法帮助缓解过拟合问题,但它可以减小嵌入层的维度。对LSTM的第一层和最后一层进行修改,可使得输入和输出的尺寸等于减少后的嵌入尺寸。

Activation Regularization(激活正则化)

L2正则化是对权重施加范数约束以减小过拟合问题,它一样能够用于单个单元的激活,即激活正则化。激活正则化可做为一种调解网络的方法。

clipboard.png

Temporal Activation Regularization(时域激活正则化)

同时,L2正则化能对RNN在不一样时间步骤上的输出差值进行范数约束。它经过在隐藏层产生较大变化对模型进行惩罚。

clipboard.png

其中,alpha和beta是缩放系数,AR和TAR损失函数仅对RNN最后一层的输出起做用。

模型分析

做者就上述模型在不一样的数据集中进行了实验,为了对分分析,每次去掉一种策略。

clipboard.png

图中的每一行表示去掉特定策略的困惑度(perplexity)分值,从该图中咱们可以直观的看出各策略对结果的影响。

实验细节

数据——来自Penn Tree-bank(PTB)数据集和WikiText-2(WT2)数据集。

网络体系结构

——全部的实验均使用的是3层LSTM模型。

批尺寸——WT2数据集的批尺寸为80,PTB数据集的批尺寸为40。根据以往经验来看,较大批尺寸(40-80)的性能优于较小批尺寸(10-20)。

其它超参数的选择请参考原文。

总结

该论文很好的总结了现有的正则化及优化策略在语言模型中的应用,对于NLP初学者甚至研究者都大有裨益。论文中强调,虽然这些策略在语言建模中得到了成功,但它们一样适用于其余序列学习任务。

本文做者:【方向】

阅读原文

本文为云栖社区原创内容,未经容许不得转载。

相关文章
相关标签/搜索