精准率和召回率是两个不一样的评价指标,不少时候它们之间存在着差别,具体在使用的时候如何解读精准率和召回率,应该视具体使用场景而定算法
有些场景,人们可能更注重精准率,如股票预测系统,咱们定义股票升为1,股票降为0,咱们更关心的是将来升的股票的比例,而在另一些场景中,人们更加注重召回率,如癌症预测系统,定义健康为1,患病为0,咱们更关心癌症患者检查的遗漏状况。编程
F1 Score 兼顾精准率和召回率,它是二者的调和平均值app
\[\frac{1}{F1} = \frac{1}{2}(\frac{1}{Precision} + \frac{1}{recall})\]
\[F1 = \frac{2\cdot precision\cdot recall}{precision+recall}\]
定义F1 Score测试
def f1_score(precision,recall): try: return 2*precision*recall/(precision+recall) except: return 0
由上看出,F1 Score更偏向于分数小的那个指标spa
精准率和召回率是两个互相矛盾的目标,提升一个指标,另外一个指标就会不可避免的降低。如何达到二者之间的一个平衡呢?code
回忆逻辑回归算法的原理:将一个结果发生的几率大于0.5,就把它分类为1,发生的几率小于0.5,就把它分类为0,决策边界为:\(\theta ^T \cdot X_b = 0\)blog
这条直线或曲线决定了分类的结果,平移决策边界,使\(\theta ^T \cdot X_b\)不等于0而是一个阈值:\(\theta ^T \cdot X_b = threshold\)ci
圆形表明分类结果为0,五角星表明分类结果为1,由上图能够看出,精准率和召回率是两个互相矛盾的指标,随着阈值的逐渐增大,召回率逐渐下降,精准率逐渐增大。it
编程实现不一样阀值下的预测结果及混淆矩阵io
from sklearn.linear_model import LogisticRegression # 数据使用前一节处理后的手写识别数据集 log_reg = LogisticRegression() log_reg.fit(x_train,y_train)
求每一个测试数据在逻辑回归算法中的score值:
decision_score = log_reg.decision_function(x_test)
不一样阀值下预测的结果
y_predict_1 = numpy.array(decision_score>=-5,dtype='int') y_predict_2 = numpy.array(decision_score>=0,dtype='int') y_predict_3 = numpy.array(decision_score>=5,dtype='int')
查看不一样阈值下的混淆矩阵:
求出0.1步长下,阈值在[min,max]区间下的精准率和召回率,查看其曲线特征:
threshold_scores = numpy.arange(numpy.min(decision_score),numpy.max(decision_score),0.1) precision_scores = [] recall_scores = [] # 求出每一个分类阈值下的预测值、精准率和召回率 for score in threshold_scores: y_predict = numpy.array(decision_score>=score,dtype='int') precision_scores.append(precision_score(y_test,y_predict)) recall_scores.append(recall_score(y_test,y_predict))
画出精准率和召回率随阈值变化的曲线
plt.plot(threshold_scores,precision_scores) plt.plot(threshold_scores,recall_scores) plt.show()
画出精准率-召回率曲线
plt.plot(precision_scores,recall_scores) plt.show()
from sklearn.metrics import precision_recall_curve precisions,recalls,thresholds = precision_recall_curve(y_test,decision_score) # sklearn中最后一个精准率为1,召回率为0,没有对应的threshold plt.plot(thresholds,precisions[:-1]) plt.plot(thresholds,recalls[:-1]) plt.show()