最近在使用sklearn作分类时候,用到metrics中的评价函数,其中有一个很是重要的评价函数是F1值,(关于这个值的原理自行google或者百度)函数
在sklearn中的计算F1的函数为 f1_score ,其中有一个参数average用来控制F1的计算方式,今天咱们就说说当参数取micro和macro时候的区别测试
一、F1公式描述:google
F1-score: 2*(P*R)/(P+R)spa
二、 f1_score中关于参数average的用法描述:code
'micro'
:Calculate metrics globally by counting the total true positives, false negatives and false positives.blog
'micro':经过先计算整体的TP,FN和FP的数量,再计算F1ci
'macro'
:Calculate metrics for each label, and find their unweighted mean. This does not take label imbalance into account.it
'macro':分布计算每一个类别的F1,而后作平均(各种别F1的权重相同)io
三、初步理解table
经过参数用法描述,想必你们从字面层次也能理解他是什么意思,micro就是先计算全部的TP,FN , FP的个数后,而后再利上文提到公式计算出F1
macro其实就是先计算出每一个类别的F1值,而后去平均,好比下面多分类问题,总共有1,2,3,4这4个类别,咱们能够先算出1的F1,2的F1,3的F1,4的F1,而后再取平均(F1+F2+F3+4)/4
y_true = [1, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 4, 4]
y_pred = [1, 1, 1, 0, 0, 2, 2, 3, 3, 3, 4, 3, 4, 3]
四、进一步理解
咱们仍是以上面的例子为例说明sklearn中是如何计算micro 和 macro的:
micro计算原理
首先计算总TP值,这个很好就算,就是数一数上面有多少个类别被正确分类,好比1这个类别有3个分正确,2有2个,3有2个,4有1个,那TP=3+2+2+1=8
其次计算总FP值,简单的说就是不属于某一个类别的元数被分到这个类别的数量,好比上面不属于4类的元素被分到4的有1个
若是还比较迷糊,咱们在计算时候能够把4保留,其余全改为0,就能够更加清楚地看出4类别下面的FP数量了,其实这个原理就是 One-vs-All (OvA),把4当作正类,其余看出负类
同理咱们能够再计算FN的数量
1类 | 2类 | 3类 | 4类 | 总数 | |
TP | 3 | 2 | 2 | 1 | 8 |
FP | 0 | 0 | 3 | 1 | 4 |
FN | 2 | 2 | 1 | 1 | 6 |
因此micro的 精确度P 为 TP/(TP+FP)=8/(8+4)=0.666 召回率R TP/(TP+FN)=8/(8+6)=0.571 因此F1-micro的值为:0.6153
能够用sklearn来核对,把average设置成micro
y_true = [1, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 4, 4] y_pred = [1, 1, 1, 0, 0, 2, 2, 3, 3, 3, 4, 3, 4, 3] print(f1_score(y_true,y_pred,labels=[1,2,3,4],average='micro'))
#>>> 0.615384615385
计算macro
macro先要计算每个类的F1,有了上面那个表,计算各个类的F1就很容易了,好比1类,它的精确率P=3/(3+0)=1 召回率R=3/(3+2)=0.6 F1=2*(1*0.5)/1.5=0.75
能够sklearn,来计算核对,把average设置成macro
#average=None,取出每一类的P,R,F1值
p_class, r_class, f_class, support_micro=precision_recall_fscore_support(y_true=y_true, y_pred=y_pred, labels=[1, 2, 3, 4], average=None) print('各种单独F1:',f_class) print('各种F1取平均:',f_class.mean()) print(f1_score(y_true,y_pred,labels=[1,2,3,4],average='macro')) #>>>各种单独F1: [ 0.75 0.66666667 0.5 0.5 ] #>>>各种F1取平均: 0.604166666667 #>>>0.604166666667
若有装载,请注明出处,谢谢