【转载】深度学习中softmax交叉熵损失函数的理解

深度学习中softmax交叉熵损失函数的理解

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处连接和本声明。
本文连接: https://blog.csdn.net/lilong117194/article/details/81542667

1. softmax层的做用

经过神经网络解决多分类问题时,最经常使用的一种方式就是在最后一层设置n个输出节点,不管在浅层神经网络仍是在CNN中都是如此,好比,在AlexNet中最后的输出层有1000个节点,即使是ResNet取消了全链接层,但1000个节点的输出层还在。git

通常状况下,最后一个输出层的节点个数与分类任务的目标数相等。 
假设最后的节点数为N,那么对于每个样例,神经网络能够获得一个N维的数组做为输出结果,数组中每个维度会对应一个类别。在最理想的状况下,若是一个样本属于k,那么这个类别所对应的的输出节点的输出值应该为1,而其余节点的输出都为0,即 [0,0,1,0,.0,0][0,0,1,0,….0,0],这个数组也就是样本的Label,是神经网络最指望的输出结果,但实际是这样的输出[0.01,0.01,0.6,....0.02,0.01][0.01,0.01,0.6,....0.02,0.01],这实际上是在原始输出的基础上加入了softmax的结果,原始的输出是输入的数值作了复杂的加权和与非线性处理以后的一个值而已,这个值能够是任意的值,可是通过softmax层后就成了一个几率值,并且几率和为1。 
假设神经网络的原始输出为y_1,y_2,….,y_n,那么通过Softmax回归处理以后的输出为 : 
数组

 
y=softmax(yi)=eyinj=1eyjy′=softmax(yi)=eyi∑j=1neyj

以上能够看出: y=1∑y′=1 
这也是为何softmax层的每一个节点的输出值成为了几率和为1的几率分布。

 

2. 交叉熵损失函数的数学原理

上面说过实际的指望输出,也就是标签是[0,0,1,0,.0,0][0,0,1,0,….0,0]这种形式,而实际的输出是[0.01,0.01,0.6,....0.02,0.01][0.01,0.01,0.6,....0.02,0.01]这种形式,这时按照常理就须要有一个损失函数来断定实际输出和指望输出的差距,交叉熵就是用来断定实际的输出与指望的输出的接近程度!下面就简单介绍下交叉熵的原理。markdown

交叉熵刻画的是实际输出(几率)与指望输出(几率)的距离,也就是交叉熵的值越小,两个几率分布就越接近。假设几率分布p为指望输出(标签),几率分布q为实际输出,H(p,q)为交叉熵。网络

  • 第一种交叉熵损失函数的形式: 
     
    H(p,q)=xp(x)logq(x)H(p,q)=−∑xp(x)logq(x)

举个例子: 
假设N=3,指望输出为p=(1,0,0),实际输出q1=(0.5,0.2,0.3)q2=(0.8,0.1,0.1)q1=(0.5,0.2,0.3),q2=(0.8,0.1,0.1),这里的q1,q2两个输出分别表明在不一样的神经网络参数下的实际输出,经过计算其对应的交叉熵来优化神经网络参数,计算过程: 
H(p,q1)=1(1×log0.5+0×log0.2+0×log0.3)H(p,q1)=−1(1×log0.5+0×log0.2+0×log0.3) 
假设结果:H(p,q1)=0.3H(p,q1)=0.3 
H(p,q2)=1(1×log0.8+0×log0.1+0×log0.1)H(p,q2)=−1(1×log0.8+0×log0.1+0×log0.1) 
假设结果:H(p,q2)=0.1H(p,q2)=0.1 
这时获得了q2q2是相对正确的分类结果。
session

  • 第二种交叉熵损失函数形式: 
     
    H(p,q)=x(p(x)logq(x)+(1p(x))log(1q(x)))H(p,q)=−∑x(p(x)logq(x)+(1−p(x))log(1−q(x)))

    下面简单推到其过程: 
    咱们知道,在二分类问题模型:例如逻辑回归「Logistic Regression」、神经网络「Neural Network」等,真实样本的标签为 [0,1],分别表示负类和正类。模型的最后一般会通过一个 Sigmoid 函数,输出一个几率值,这个几率值反映了预测为正类的可能性:几率越大,可能性越大。 
    Sigmoid 函数的表达式和图形以下所示:g(s)=11+esg(s)=11+e−s 
    其中 s 是模型上一层的输出,Sigmoid 函数有这样的特色:s = 0 时,g(s) = 0.5;s >> 0 时, g ≈ 1,s << 0 时,g ≈ 0。显然,g(s) 将前一级的线性输出映射到 [0,1] 之间的数值几率上。 
    其中预测输出即 Sigmoid 函数的输出g(s)表征了当前样本标签为 1 的几率: 
    P(y=1|x)=y^P(y=1|x)=y^ 
    p(y=0|x)=1y^p(y=0|x)=1−y^ 
    这个时候从极大似然性的角度出发,把上面两种状况整合到一块儿: 
    p(y|x)=y^y(1y^)(1y)p(y|x)=y^y(1−y^)(1−y) 
    这个函数式表征的是: 
    当真实样本标签 y = 1 时,上面式子第二项就为 1,几率等式转化为: 
    P(y=1|x)=y^P(y=1|x)=y^ 
    当真实样本标签 y = 0 时,上面式子第一项就为 1,几率等式转化为: 
    P(y=0|x)=1y^P(y=0|x)=1−y^ 
    两种状况下几率表达式跟以前的彻底一致,只不过咱们把两种状况整合在一块儿了。那这个时候应用极大似然估计应该获得的是全部的几率值乘积应该最大,即: 
    L=Ni=1y^yii(1y^i)(1yi)L=∑i=1Ny^iyi(1−y^i)(1−yi) 
    引入log函数后获得: 
    L=log(L)=Ni=1yilogy^i+(1yi)log(1y^i)L′=log(L)=∑i=1Nyilogy^i+(1−yi)log(1−y^i) 
    这时令loss=-log(L)=-L',也就是损失函数越小越好,而此时也就是 L'越大越好。

而在实际的使用训练过程当中,数据每每是组合成为一个batch来使用,因此对用的神经网络的输出应该是一个m*n的二维矩阵,其中m为batch的个数,n为分类数目,而对应的Label也是一个二维矩阵,仍是拿上面的数据,组合成一个batch=2的矩阵 函数

 
q=[0.50.80.20.10.30.1]q=[0.50.20.30.80.10.1]

 
p=[110000]p=[100100]

根据第一种交叉熵的形式获得: 
 
H(p,q)=[0.30.1]H(p,q)=[0.30.1]

而对于一个batch,最后取平均为0.2。

 

3. 在TensorFlow中实现交叉熵

在TensorFlow能够采用这种形式:学习

cross_entropy = -tf.reduce_mean(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0))) 
  • 1

其中y_表示指望的输出,y表示实际的输出(几率值),*为矩阵元素间相乘,而不是矩阵乘。 
而且经过tf.clip_by_value函数能够将一个张量中的数值限制在一个范围以内,这样能够避免一些运算错误(好比log0是无效的),tf.clip_by_value函数是为了限制输出的大小,为了不log0为负无穷的状况,将输出的值限定在(1e-10, 1.0)之间,其实1.0的限制是没有意义的,由于几率怎么会超过1呢。好比:优化

import tensorflow as tf

v=tf.constant([[1.0,2.0,3.0],[4.0,5.0,6.0]]) with tf.Session() as sess: print(tf.clip_by_value(v,2.5,4.5).eval(session=sess))
  • 1
  • 2
  • 3
  • 4
  • 5

结果:ui

[[2.5 2.5 3. ] [4. 4.5 4.5]]
  • 1
  • 2

上述代码实现了第一种形式的交叉熵计算,须要说明的是,计算的过程其实和上面提到的公式有些区别,按照上面的步骤,平均交叉熵应该是先计算batch中每个样本的交叉熵后取平均计算获得的,而利用tf.reduce_mean函数其实计算的是整个矩阵的平均值,这样作的结果会有差别,可是并不改变实际意义。atom

import tensorflow as tf

v=tf.constant([[1.0,2.0,3.0],[4.0,5.0,6.0]]) with tf.Session() as sess: # 输出3.5 print(tf.reduce_mean(v).eval())
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

因为在神经网络中,交叉熵经常与Sorfmax函数组合使用,因此TensorFlow对其进行了封装,即:

cross_entropy = tf.nn.sorfmax_cross_entropy_with_logits(y_ ,y)
  • 1

与第一个代码的区别在于,这里的y用神经网络最后一层的原始输出就行了,而不是通过softmax层的几率值。

参考:http://www.javashuo.com/article/p-qrfavtho-ev.html 
https://blog.csdn.net/chaipp0607/article/details/73392175

相关文章
相关标签/搜索