pytorch循环神经网络RNN从结构原理到应用实例

1、 RNN概述

人工神经网络和卷积神经网络的假设前提都是:元素之间是相互独立的 ,可是在生活中不少状况下这种假设并不成立,好比你写一段有意义的话 “碰见一我的只需1秒,喜欢一我的只需3,秒,爱上一我的只需1分钟,而我却用个人[?]在爱你。” ,做为正常人咱们知道这里应该填 “一辈子”,但之因此咱们会这样填是由于咱们读取了上下文,而普通的神经网络输入之间是相互独立的,网络没有记忆能力。扩展一下:训练样本是连续的序列且其长短不一,如一段连续的语音、一段连续的文本等,这些序列前面的输入与后面的输入有有必定的相关性,很难将其拆解为一个个单独的样原本进行DNN/CNN训练。html

循环神经网络(Recurrent Neural Networks,简称RNN)普遍应用于:算法

  • 语义分析(Semantic Analysis):按照语法分析器识别语法范畴进行语义检查和处理,产生相应的中间代码或者目标代码
  • 情感分析(Sentiment Classification)
  • 图像标注(Image Captioning):对图片进行文本描述
  • 语言翻译(Language Translation)

2、RNN网络结构及原理

图中各个参数意义:网络

1)x(t)表明在序列索引号t时训练样本的输入。一样的,x(t−1)x(t+1)表明在序列索引号t−1t+1时训练样本的输入。函数

2)h(t)表明在序列索引号t时模型的隐藏状态。h(t)x(t)h(t−1)共同决定。spa

3)o(t)表明在序列索引号t时模型的输出。o(t)只由模型当前的隐藏状态h(t)决定。.net

4)L(t)表明在序列索引号t时模型的损失函数。翻译

5)y(t)表明在序列索引号t时训练样本序列的真实输出。3d

6)U,W,V这三个矩阵是咱们的模型的线性关系参数,它在整个RNN网络中是共享的,这点和DNN很不相同。 也正由于是共享了,它体现了RNN的模型的“循环反馈”的思想。 [1]code

3、RNN前向传播原理

对于任何一个序列索引号t,隐藏状态\(h{(t)}\)\(h^{(t-1)}\)\(x^{(t)}\)获得:htm

\[h^{(t)} = \sigma(z^{(t)} = \sigma(Ux^{(t)}+Wh^{(t-1)}+b)) \]

其中σ为RNN的激活函数,b为偏置值(bias)

序列索引号为t的时候模型的输出\(o^{(t)}\)的表达式比较简单:

\[o^{(t)} = Vh^{(t)}+c \]

此时预测输出为:

\[\hat{y}^{(t)} = \sigma(o^{(t)}) \]

在上面这一过程当中使用了两次激活函数(第一次得到隐藏状态\(h^{(t)}\),第二次得到预测输出\(\hat{y}^{(t)}\))一般在第一次使用tanh激活函数,第二次使用softmax激活函数

4、RNN反向传播推导

RNN的法向传播经过梯度降低一次次迭代获得合适的参数U、W、V、b、c。在RNN中U、W、V、b、c参数在序列的各个位置都是相同的,反向传播咱们更新的是一样的参数。

对于RNN,咱们在序列的每个位置上都有损失,因此最终的损失L为:

\[L = \sum_{t=1}^{\tau}L^{(t)} \]

损失函数对更新的参数进行求偏导(注意咱们这里使用的两个激活函数分别为softmaxtanh,使用的偏差计算公式为交叉熵):

  • 首先考虑与损失函数直接相关的两个变量cV(即预测输出时的权值和偏置值),利用损失函数能够对这两个变量进行直接求偏导(即对softmax函数求导):

\[\frac{\partial{L}}{\partial{c}} = \sum_{t=1}^{\tau}\frac{\partial{L^{(t)}}}{\partial{c}} = \sum_{t =1}^{\tau}\hat{y}^{(t)}-y^{(t)} \]

\[\frac{\partial{L}}{\partial{V}} = \sum_{t=1}^{\tau}\frac{\partial{L^{(t)}}}{\partial{V}} = \sum_{t =1}^{\tau}(\hat{y}^{(t)}-y^{(t)})(h^{(t)})^T \]

  • 而损失函数对W、U、b的偏导数计算就比较复杂了:在反向传播时,某一序列位置t的梯度损失由当前位置的输出对应的梯度损失和序列索引位置t+1时的梯度损失两部分共同决定。
    从正向传播来看:

\[h^{(t+1)} = tanh(Ux^{(t+1)}+Wh^{(t)}+b)) \]

对于W、U、b在某一序列位置t的梯度损失须要反向传播一步步的计算。咱们定义序列索引t位置的隐藏状态的梯度为:

\[\delta^{(t)} = \frac{\partial{L}}{\partial{(h^{(t)})}} \]

\(\delta^{(\tau+1)}\)递推\(\delta^{(t)}\)

\[\delta^{(t)} = (\frac{\partial{\delta^{(t)}}}{\partial{h^{(t)}}})^T \frac{\partial{L}}{\partial{o^{(t)}}} + (\frac{\partial{h^{(t+1)}}}{\partial{h^{(t)}}})^T \frac{\partial{L}}{\partial{h^{(t+1)}}} = V^T(\hat{y}^{(t)}-y^{(t)}) +W^Tdiag(1-(h^{(t+1)})^2)\delta^{(t+1)} \]

对于\(\delta{(\tau)}\),其后面没有其余的索引(最后一个输入),所以:

\[\delta^{(\tau)} = (\frac{\partial{\delta^{(\tau)}}}{\partial{h^{(\tau)}}})^T \frac{\partial{L}}{\partial{o^{(\tau)}}} = V^T(\hat{y}^{(\tau)}-y^{(t)}) \]

根据\(\delta{(t)}\),咱们就能够计算W、U、b了:

\[\frac{\partial{L}}{\partial{W}} = \sum_{t=1}^{\tau}diag(1-(h^{(t)})^2)\delta^{(t)}(h^{(t-1)})^T \]

\[\frac{\partial{L}}{\partial{b}} = \sum_{t=1}^{\tau}diag(1-(h^{(t)})^2)\delta^{(t)} \]

\[\frac{\partial{L}}{\partial{V}} = \sum_{t=1}^{\tau}diag(1-(h^{(t)})^2)\delta^{(t)}(x^{(t)})^T \]

5、RNN梯度消失问题


假设时间序列只有三段,\(S_0\)为给定值,神经元没有激活函数,而RNN按照最简单的前向传播:

\[S_1 = W_xX_1 + W_sS_0+b_1 ; O_1 = W_0S_1 +b2 \]

\[S_2 = W_xX_2 + W_sS_1+b_1 ; O_2 = W_0S_2 +b2 \]

\[S_3 = W_xX_3 + W_sS_2+b_1 ; O_3 = W_0S_3 +b2 \]

假设在t=3时刻,损失函数为$$L_3 = \frac{1}{2}(Y_3-O_3)^2$$
对于一次训练,其损失函数值是累加的:$$L = \sum_{t = 0}{T}L_t$$
此处利用反向传播公式仅对Wx、Ws、W0求偏导数(Wx、Ws与输出Output相关,并不是直接求损失函数Loss的偏导,在第四部分也已经说明了:

\[\frac{\partial{L}_3}{\partial{W}_0} = \frac{\partial{L}_3}{\partial{O}_3} \frac{\partial{O}_3}{\partial{W}_0} \]

\[\frac{\partial{L}_3}{\partial{W}_x} = \frac{\partial{L}_3}{\partial{O}_3} \frac{\partial{O}_3}{\partial{S}_3} \frac{\partial{S}_3}{\partial{W}_x} + \frac{\partial{L}_3}{\partial{O}_3} \frac{\partial{O}_3}{\partial{S}_3} \frac{\partial{S}_3}{\partial{S}_2}\frac{\partial{S}_2}{\partial{W}_x}+ \frac{\partial{L}_3}{\partial{O}_3} \frac{\partial{O}_3}{\partial{S}_3} \frac{\partial{S}_3}{\partial{S}_2}\frac{\partial{S}_2}{\partial{S}_1}\frac{\partial{S}_1}{\partial{w}_x} \]

\[\frac{\partial{L}_3}{\partial{W}_s} = \frac{\partial{L}_3}{\partial{O}_3} \frac{\partial{O}_3}{\partial{S}_3} \frac{\partial{S}_3}{\partial{W}_s} + \frac{\partial{L}_3}{\partial{O}_3} \frac{\partial{O}_3}{\partial{S}_3} \frac{\partial{S}_3}{\partial{S}_2}\frac{\partial{S}_2}{\partial{W}_s}+ \frac{\partial{L}_3}{\partial{O}_3} \frac{\partial{O}_3}{\partial{S}_3} \frac{\partial{S}_3}{\partial{S}_2}\frac{\partial{S}_2}{\partial{S}_1}\frac{\partial{S}_1}{\partial{w}_s} \]

从这冗长的公式中能够看见用梯度降低法对损失函数求W0的偏导数其没有很长的依赖(就是公式很短、求解简单)可是对于WxWs的公式就很是长了,上面仅仅推到了三层网络结构就已经如此繁杂了,推导任意时刻损失函数关于WxWs的偏导数公式:

\[\frac{\partial{L}_t}{\partial{W}_x} = \sum_{k=0}^{t}\frac{\partial{L}_t}{\partial{O}_t}\frac{\partial{O}_t}{\partial{S}_t}(\prod_{j=k+1}^{t}\frac{\partial{S}_j}{\partial{S}_{j-1}})\frac{\partial{S}_k}{\partial{W}_x} \]

\[\frac{\partial{L}_t}{\partial{W}_s} = \sum_{k=0}^{t}\frac{\partial{L}_t}{\partial{O}_t}\frac{\partial{O}_t}{\partial{S}_t}(\prod_{j=k+1}^{t}\frac{\partial{S}_j}{\partial{S}_{j-1}})\frac{\partial{S}_k}{\partial{W}_s} \]

若是再加上激活函数:$$S_j = tanh(W_xX_j + W_sS_{j-1}+b_1)$$

则$$\prod_{j=k+1}^{t}\frac{\partial{S}j}{\partial{S}{j-1}} = \prod_{j=k+1}^{t}W_s tanh^{'}$$

激活函数tanh[2]:

\[f(x) = tanh(x) = \frac{e^x - e^{-x}}{e^x + e^{-x}} \]

tanh函数导数:

\[f(x)^{'} = 1 - (tanh(x))^2 \]

tanh函数及其导数

根据激活函数及其导数的图像可见 [3]

  • \[tanh^{'}(x) ≤ 1 \]

  • 绝大部分状况下,tanh的导数都是小于1的。不多状况出现:

\[W_xX_j + W_sS_{j-1} + b_1 = 0 \]

  • 若是Ws是一个大于0小于1的值,当t很大的时候

\[\prod_{j=k+1}^{t}W_s tanh^{'} --> 0 \]

  • 若是Ws是一个很大的值,当t很大的时候

\[\prod_{j=k+1}^{t}W_s tanh^{'} --> ∞ \]

6、消除梯度爆炸和梯度消失

在公式:

\[\frac{\partial{L}_t}{\partial{W}_x} = \sum_{k=0}^{t}\frac{\partial{L}_t}{\partial{O}_t}\frac{\partial{O}_t}{\partial{S}_t}(\prod_{j=k+1}^{t}\frac{\partial{S}_j}{\partial{S}_{j-1}})\frac{\partial{S}_k}{\partial{W}_x} \]

\[\frac{\partial{L}_t}{\partial{W}_s} = \sum_{k=0}^{t}\frac{\partial{L}_t}{\partial{O}_t}\frac{\partial{O}_t}{\partial{S}_t}(\prod_{j=k+1}^{t}\frac{\partial{S}_j}{\partial{S}_{j-1}})\frac{\partial{S}_k}{\partial{W}_s} \]

致使梯度消失和梯度爆炸的缘由在于:

\[\prod_{j=k+1}^{t}\frac{\partial{S}_j}{\partial{S}_{j-1}} \]

消除这个部分的影响一个考虑是使得

\[\frac{\partial{S}_j}{\partial{S}_{j-1}} ≈ 1 \]

另外一种是使得:

\[\frac{\partial{S}_j}{\partial{S}_{j-1}} ≈ 0 \]


  1. 循环神经网络(RNN)模型与前向反向传播算法 ↩︎

  2. Tanh激活函数及求导过程 ↩︎

  3. RNN梯度消失和爆炸的缘由 ↩︎

相关文章
相关标签/搜索