模型的选择与调优:交叉验证与网格搜索

击上方
“蓝色字”
可关注咱们!


今日分享:交叉验证与网格搜索 web


一:交叉验证算法

交叉验证:为了让被评估的模型更加准确可信数组

交叉验证过程:将拿到的数据,分为训练集和验证集(注意这里的数据是在训练集中进行划分的,也就是将原始数据划分获得的训练集再次划分为训练集和验集),交叉验证通常结合网格搜索使用。微信

如下图为例:将数据分红5份,其中一份做为验证集。而后通过5次(组)的测试,每次都更换不一样的验证集,又称5折交叉验证经过对上面的五组数据分别进行建模,每一组都会获得一个模型的精确度,而后取平均值做为该模型最后的精确度,而后再进行后续的步骤。学习


(五折交叉验证)测试


二:超参数搜索-网格搜索spa

一般状况下,有不少参数是须要手动指定的(如k-近邻算法中的K值),这种叫超参数。可是手动过程繁杂,因此须要对模型预设几种超参数组合。每组超参数都采用交叉验证来进行评估。最后选出最优参数组合创建模型。.net



对于超参数较少的K-近邻来讲,也许能够经过 for 循环开找到较优的k值,可是对于其余的模型,若是须要调的参数较多,for循环就不太方便了,好比两个超参数时,每一个参数分别指定4个值,则下来就有16中参数组成的模型。3d


三:网格搜索APIcode


sklearn.model_selection.GridSearchCV


四:API参数介绍


sklearn.model_selection.GridSearchCV(estimator, param_grid=None,cv=None)
对估计器的指定参数值进行详尽搜索

estimator:估计器对象 就是哪个模型

param_grid:估计器参数(dict){“n_neighbors”:[1,3,5]}

cv:指定几折交叉验证

fit:输入训练数据

score:准确率

结果分析:
best_score_:在交叉验证中测试的最好结果
best_estimator_:最好的参数模型
cv_results_:每次交叉验证后的测试集准确率结果和训练集准确率结果


五:K-近邻网格搜索


使用鸢尾花数据集来进行网格搜索示例演示


#导入相关库
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import GridSearchCV


def knn_iris():
   '''K-近邻模型对鸢尾花进行分类'''
   
   #加载数据集
   iris = load_iris()
   
   #划分数据集
   #切记 x_train,x_test,y_train,y_test 顺序位置必定不能写错
   #括号中参数分别为 (特征值 目标值 测试集大小占比) 占比可自行设定 经常使用0.25
   x_train,x_test,y_train,y_test = train_test_split(iris.data,iris.target,test_size=0.25)
   
   '''特征工程(标准化)'''
   std = StandardScaler()
   
   #对测试集和训练集的特征值进行标准化
   x_train = std.fit_transform(x_train)

   x_test = std.transform(x_test)
   
   #Knn模型实例化
   knn = KNeighborsClassifier()
   
   # 以字典形式构造一些参数的值进行搜索,若存在别的参数时,只需添加相应的键值
   # 这里指定参数为 1,3,5,7,10
   param = {"n_neighbors": [1,3,5,7,10]}

   # 进行网格搜索 3折交叉验证
   gc = GridSearchCV(knn, param_grid=param, cv=3)

   gc.fit(x_train, y_train)

   print("每一个超参数每次交叉验证的结果:\n", gc.cv_results_)
   
   print("在测试集上准确率:\n", gc.score(x_test, y_test))

   print("在交叉验证当中最好的结果:\n", gc.best_score_)

   print("选择最好的模型是:\n", gc.best_estimator_)

   
if __name__ == '__main__':
   knn_iris()


输出结果

每一个超参数每次交叉验证的结果:
{'split1_train_score': array([1.        , 0.97333333, 0.97333333, 0.97333333, 0.94666667]), 'params': [{'n_neighbors': 1}, {'n_neighbors': 3}, {'n_neighbors': 5}, {'n_neighbors': 7}, {'n_neighbors': 10}], 'std_train_score': array([0.        , 0.01110803, 0.00637203, 0.00637203, 0.0108896 ]), 'mean_train_score': array([1.        , 0.97315315, 0.97765766, 0.97765766, 0.95981982]),
'split2_train_score': array([1.        , 0.98666667, 0.98666667, 0.98666667, 0.97333333]), 'mean_test_score': array([0.94642857, 0.95535714, 0.96428571, 0.95535714, 0.96428571]),
'split0_train_score': array([1.        , 0.95945946, 0.97297297, 0.97297297, 0.95945946]),
'split2_test_score': array([0.89189189, 0.94594595, 0.97297297, 0.91891892, 0.94594595]),
'split0_test_score': array([1., 1., 1., 1., 1.]), 'mean_fit_time': array([0.00100843, 0.00099985, 0.        , 0.        , 0.        ]), 'rank_test_score': array([5, 3, 1, 3, 1]), 'param_n_neighbors': masked_array(data=[1, 3, 5, 7, 10],
            mask=[False, False, False, False, False],
      fill_value='?',
           dtype=object), 'std_test_score': array([0.04423072, 0.03382427, 0.03372858, 0.03382427, 0.02559281]), 'mean_score_time': array([0.00100978, 0.00100025, 0.        , 0.        , 0.        ]), 'std_score_time': array([1.46122043e-05, 4.89903609e-07, 0.00000000e+00, 0.00000000e+00,
      0.00000000e+00]), 'std_fit_time': array([3.53483630e-05, 2.24783192e-07, 0.00000000e+00, 0.00000000e+00,
      0.00000000e+00]),
'split1_test_score': array([0.94594595, 0.91891892, 0.91891892, 0.94594595, 0.94594595])}

在测试集上准确率:
0.9210526315789473

在交叉验证当中最好的结果:
0.9642857142857143

选择最好的模型是:
KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',
          metric_params=None, n_jobs=1, n_neighbors=5, p=2,
          weights='uniform')


由输出结果可知,在指定的几个k值中,超参数k=5时,模型效果最好




Python基础知识集锦

爬虫专题文章整理篇!!!

Python数据分析干货整理篇

Matplotlib数据可视化专题集锦贴



公众号     QQ群

扫QQ群二维码进交流学习群

或在后台回复:加群

本文分享自微信公众号 - 数据指南(BigDataDT)。
若有侵权,请联系 support@oschina.cn 删除。
本文参与“OSC源创计划”,欢迎正在阅读的你也加入,一块儿分享。

相关文章
相关标签/搜索