纯Python实现鸢尾属植物数据集神经网络模型

摘要: 本文以Python代码完成整个鸾尾花图像分类任务,没有调用任何的数据包,适合新手阅读理解,并动手实践体验下机器学习方法的大体流程。

尝试使用过各大公司推出的植物识别APP吗?好比微软识花、花伴侣等这些APP。当你看到一朵不知道学名的花时,只须要打开植物识别APP,拍摄一张你所想辨认的植物照片并上传,APP会自动识别出该花的品种及详细介绍,感受手机中装了一个知识渊博的生物学家,是否是很神奇?其实,背后的原理很简单,是一个图像分类的过程,将上传的图像与手机中预存的数据集或联网数据进行匹配,将其分类到对应的类别便可。随着深度学习方法的应用,图像分类的精度愈来愈高,在部分数据集上已经超越了人眼的能力。python

相对于传统神经网络的方法而言,深度学习方法通常对数据集规模、硬件平台有着比较高的要求,若是只是单纯的想尝试了解图像分类任务的基本流程,建议采用小数据集样本及传统的神经网络方法实现。本文将带领读者采用鸢尾属植物数据集(Iris Data Set)来实现一个分类任务,整个鸢尾属植物数据集是机器学习中历史悠久的数据集,比如今经常使用的数字手写体数据集(Mnist Data Set)数据集还要早得多,该数据集来源于英国著名的统计学家、生物学家Ronald Fiser。本文在不使用相关软件库的状况下,从头开始构建针对鸢尾属植物数据的神经网络模型,对其进行训练并得到好的结果。算法

clipboard.png

鸢尾属植物数据集是用于测试机器学习算法的最经常使用数据集。该数据包含四种特征,萼片长度、萼片宽度、花瓣长度和花瓣宽度,用于鸢尾属植物的不一样物种(versicolor, virginica和setosa)。此外,每一个物种有50个实例(数据行),下面让咱们看看样本数据分布状况。数组

clipboard.png

咱们将在这个数据集上使用神经网络构建分类模型。为了简单起见,使用花瓣长度和花瓣宽度做为特征,且只有两类物种:versicolor和virginica。下面就让咱们在Python中逐步训练针对该样本数据集的神经网络:网络

步骤1:准备鸢尾属植物数据集

将Iris数据集导入python并对数据进行子集划分以保留行之间的相关性:机器学习

clipboard.png

clipboard.png

蓝色点表明Versicolor物种,红色点表明Virginica物种。本文构建的神经网络将在这些数据上进行训练,以期最后能正确地分类物种。函数

步骤2:初始化参数(权重和偏置)

下面构建一个具备单个隐藏层的神经网络。此外,将隐藏图层的大小设置为6:学习

clipboard.png

步骤3:前向传播(forward propagation)

在前向传播过程当中,使用tanh激活函数做为第一层的激活函数,使用sigmoid激活函数做为第二层的激活函数:测试

clipboard.png

步骤4:计算代价函数(cost function)

目标是使得计算的代价函数小化,本文采用交叉熵(cross-entropy)做为代价函数:spa

clipboard.png

步骤5:反向传播(back propagation)

计算反向传播过程,主要是计算代价函数的导数:设计

clipboard.png

步骤6:更新参数

使用反向传播过程当中计算的梯度来更新权重和偏置:

clipboard.png

步骤7:创建神经网络

将以上全部函数组合起来以建立设计的神经网络模型。总而言之,下面是模型函数的总体顺序:

一、初始化参数

二、前向传播

三、计算代价函数

四、反向传播

五、更新参数

clipboard.png

步骤8:跑动模型

将隐藏层节点设置为6,最大迭代次数设置为10,000次,并每隔1000次打印出训练的结果:

clipboard.png

clipboard.png

步骤9:画出分类边界

clipboard.png

clipboard.png

从图中能够观察到,只有四个点被错误分类。虽然咱们能够调整模型来进一步地提升模型训练精度,但该些操做显然会致使过拟合现象的出现。

资源

本文做者:【方向】

阅读原文

本文为云栖社区原创内容,未经容许不得转载。

相关文章
相关标签/搜索