Sklearn之datasets和训练

数据集的操做

以iris数据集为例,首先导入数据集python

iris = datasets.load_iris()

数据集是一个类词典的数据,其属性有git

data 数据集,类型是numpy的ndarray
target 数据对应的类标记,类型是一维的ndarray
target_name 类标记对应的名字,类型是一维的ndarray
DESCR 数据集的描述信息

拆分训练集和测试集

 

训练estimator

经过fit(X,y)和predict(X)分别进行训练和预测,这里以SVM为例训练一个estimator并预测测试

from sklearn import datasets
from sklearn import svm
from matplotlib import pyplot as plt
import numpy as np


digits = datasets.load_digits()
clf = svm.SVC(gamma=0.001, C=100.)
clf.fit(digits.data[:-1],digits.target[:-1])

plt.figure(1, figsize=(3, 3))
plt.imshow(digits.images[-1], cmap=plt.cm.gray_r, interpolation='nearest')
plt.show()

print(clf.predict(digits.data[-1:]))

保存模型

经过sklearn.externals下的joblib.dump()来保存训练好的模型,再次使用能够经过load进行导入spa

from sklearn import datasets
from sklearn import svm
from sklearn.externals import joblib

iris = datasets.load_iris()
clf = svm.SVC(gamma=0.001, C=100.)
clf.fit(iris.data[:-2], iris.target[:-2])
print('倒数第一条数据的类标记为', iris.target[-1], ', 预测结果为 ',
      clf.predict(iris.data[-1:]))


joblib.dump(clf, 'model.pkl')
print('倒数第二条数据的类标记为', iris.target[-2], ', 预测的结果为 ',
      joblib.load('model.pkl').predict(iris.data[-2:-1]))

相关文章
相关标签/搜索