判断模型是过拟合仍是欠拟合--学习曲线

转自 :http://blog.csdn.net/aliceyangxi1987/article/details/73598857

学习曲线是什么?

学习曲线就是经过画出不一样训练集大小时训练集和交叉验证的准确率,能够看到模型在新数据上的表现,进而来判断模型是否方差偏高或误差太高,以及增大训练集是否能够减少过拟合。python


怎么解读?

当训练集和测试集的偏差收敛但却很高时,为高误差。 
左上角的误差很高,训练集和验证集的准确率都很低,极可能是欠拟合。 
咱们能够增长模型参数,好比,构建更多的特征,减少正则项。 
此时经过增长数据量是不起做用的。git

当训练集和测试集的偏差之间有大的差距时,为高方差。 
当训练集的准确率比其余独立数据集上的测试结果的准确率要高时,通常都是过拟合。 
右上角方差很高,训练集和验证集的准确率相差太多,应该是过拟合。 
咱们能够增大训练集,下降模型复杂度,增大正则项,或者经过特征选择减小特征数。dom

理想状况是是找到误差和方差都很小的状况,即收敛且偏差较小。学习


怎么画?

在画学习曲线时,横轴为训练样本的数量,纵轴为准确率。测试

例如一样的问题,左图为咱们用 naive Bayes 分类器时,效果不太好,分数大约收敛在 0.85,此时增长数据对效果没有帮助。spa

右图为 SVM(RBF kernel),训练集的准确率很高,验证集的也随着数据量增长而增长,不过由于训练集的仍是高于验证集的,有点过拟合,因此仍是须要增长数据量,这时增长数据会对效果有帮助。.net


上图的代码以下:

模型这里用 GaussianNB 和 SVC 作比较, 
模型选择方法中须要用到 learning_curve 和交叉验证方法 ShuffleSplit。code

import numpy as np import matplotlib.pyplot as plt from sklearn.naive_bayes import GaussianNB from sklearn.svm import SVC from sklearn.datasets import load_digits from sklearn.model_selection import learning_curve from sklearn.model_selection import ShuffleSplit

首先定义画出学习曲线的方法, 
核心就是调用了 sklearn.model_selection 的 learning_curve, 
学习曲线返回的是 train_sizes, train_scores, test_scores, 
画训练集的曲线时,横轴为 train_sizes, 纵轴为 train_scores_mean, 
画测试集的曲线时,横轴为 train_sizes, 纵轴为 test_scores_mean:blog

def plot_learning_curve(estimator, title, X, y, ylim=None, cv=None, n_jobs=1, train_sizes=np.linspace(.1, 1.0, 5)): ~~~ train_sizes, train_scores, test_scores = learning_curve( estimator, X, y, cv=cv, n_jobs=n_jobs, train_sizes=train_sizes) train_scores_mean = np.mean(train_scores, axis=1) test_scores_mean = np.mean(test_scores, axis=1) ~~~ 

在调用 plot_learning_curve 时,首先定义交叉验证 cv 和学习模型 estimator。get

这里交叉验证用的是 ShuffleSplit, 它首先将样例打散,并随机取 20% 的数据做为测试集,这样取出 100 次,最后返回的是 train_index, test_index,就知道哪些数据是 train,哪些数据是 test。

estimator 用的是 GaussianNB,对应左图:

cv = ShuffleSplit(n_splits=100, test_size=0.2, random_state=0) estimator = GaussianNB() plot_learning_curve(estimator, title, X, y, ylim=(0.7, 1.01), cv=cv, n_jobs=4)

再看 estimator 是 SVC 的时候,对应右图:

cv = ShuffleSplit(n_splits=10, test_size=0.2, random_state=0) estimator = SVC(gamma=0.001) plot_learning_curve(estimator, title, X, y, (0.7, 1.01), cv=cv, n_jobs=4)
相关文章
相关标签/搜索