第10章 项目:多类花朵分类

本章我们使用Keras为多类分类开发并验证一个神经网络。本章包括:

  • 将CSV导入Keras
  • 为Keras预处理数据
  • 使用scikit-learn验证Keras模型

我们开始吧。

10.1 鸢尾花分类数据集

本章我们使用经典的鸢尾花数据集。这个数据集已经被充分研究过,4个输入变量都是数字,量纲都是厘米。每个数据代表花朵的不同参数,输出是分类结果。数据的属性是(厘米):

  1. 萼片长度
  2. 萼片宽度
  3. 花瓣长度
  4. 花瓣宽度
  5. 类别

这个问题是多类分类的:有两种以上的类别需要预测,确切的说,3种。这种问题需要对神经网络做出特殊调整。数据有150条:前5行是:

5.1,3.5,1.4,0.2,Iris-setosa
4.9,3.0,1.4,0.2,Iris-setosa
4.7,3.2,1.3,0.2,Iris-setosa
4.6,3.1,1.5,0.2,Iris-setosa
5.0,3.6,1.4,0.2,Iris-setosa

鸢尾花数据集已经被充分研究,模型的准确率可以达到95%到97%,作为目标很不错。本书的data目录下附带了示例代码和数据,也可以从UCI机器学习网站下载,重命名为iris.csv。数据集的详情请在UCI机器学习网站查询。

10.2 导入库和函数

我们导入所需要的库和函数,包括深度学习包Keras、数据处理包pandas和模型测试包scikit-learn。

import numpy
import pandas
from keras.models import Sequential
from keras.layers import Dense
from keras.wrappers.scikit_learn import KerasClassifier
from keras.utils import np_utils
from sklearn.cross_validation import cross_val_score
from sklearn.cross_validation import KFold
from sklearn.preprocessing import LabelEncoder
from sklearn.pipeline import Pipeline

10.3 指定随机数种子

我们指定一个随机数种子,这样重复运行的结果会一致,以便复现随机梯度下降的结果:

# fix random seed for reproducibility
seed = 7
numpy.random.seed(seed)

10.4 导入数据

数据可以直接导入。因为数据包含字符,用pandas更容易。然后可以将数据的属性(列)分成输入变量(X)和输出变量(Y):

# load dataset
dataframe = pandas.read_csv("iris.csv", header=None)
dataset = dataframe.values
X = dataset[:,0:4].astype(float)
Y = dataset[:,4]

10.5 输出变量编码

数据的类型是字符串:在使用神经网络时应该将类别编码成矩阵,每行每列代表所属类别。可以使用独热编码,或者加入一列。这个数据中有3个类别:Iris-setosaIris-versicolorIris-virginica。如果数据是

Iris-setosa
Iris-versicolor
Iris-virginica

用独热编码可以编码成这种矩阵:

Iris-setosa, Iris-versicolor, Iris-virginica 1, 0, 0
0, 1, 0
0, 0, 1

scikit-learn的LabelEncoder可以将类别变成数字,然后用Keras的to_categorical()函数编码:

# encode class values as integers
encoder = LabelEncoder()
encoder.fit(Y)
encoded_Y = encoder.transform(Y)
# convert integers to dummy variables (i.e. one hot encoded)
dummy_y = np_utils.to_categorical(encoded_Y)

10.6 设计神经网络

Keras提供了KerasClassifier,可以将网络封装,在scikit-learn上用。KerasClassifier的初始化变量是模型名称,返回供训练的神经网络模型。

我们写一个函数,为鸢尾花分类问题创建一个神经网络:这个全连接网络只有1个带有4个神经元的隐层,和输入的变量数相同。为了效果,隐层使用整流函数作为激活函数。因为我们用了独热编码,网络的输出必须是3个变量,每个变量代表一种花,最大的变量代表预测种类。网络的结构是:

4个神经元 输入层 -> [4个神经元 隐层] -> 3个神经元 输出层

输出层的函数是S型函数,把可能性映射到概率的0到1。优化算法选择ADAM随机梯度下降,损失函数是对数函数,在Keras中叫categorical_crossentropy

# define baseline model
def baseline_model():# create modelmodel = Sequential()model.add(Dense(4, input_dim=4, init='normal', activation='relu')) model.add(Dense(3, init='normal', activation='sigmoid'))# Compile modelmodel.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy']) return model

可以用这个模型创建KerasClassifier,也可以传入其他参数,这些参数会传递到fit()函数中。我们将训练次数nb_epoch设成150,批尺寸batch_size设成5,verbose设成0以关闭调试信息:

estimator = KerasClassifier(build_fn=baseline_model, nb_epoch=200, batch_size=5, verbose=0)

10.7 用K折交叉检验测试模型

现在可以测试模型效果了。scikit-learn有很多种办法可以测试模型,其中最重要的就是K折检验。我们先设定模型的测试方法:K设为10(默认值很好),在分割前随机重排数据:

kfold = KFold(n=len(X), n_folds=10, shuffle=True, random_state=seed)

这样我们就可以在数据集(Xdummy_y)上用10折交叉检验(kfold)测试性能了。模型需要10秒钟就可以跑完,每次检验输出结果:

results = cross_val_score(estimator, X, dummy_y, cv=kfold)
print("Accuracy: %.2f%% (%.2f%%)" % (results.mean()*100, results.std()*100))

输出结果的均值和标准差,这样可以验证模型的预测能力,效果拔群:

Baseline: 95.33% (4.27%)

10.8 总结

本章关于使用Keras开发深度学习项目。总结一下:

  • 如何导入数据
  • 如何使用独热编码处理多类分类数据
  • 如何与scikit-learn一同使用Keras
  • 如何用Keras定义多类分类神经网络
  • 如何用scikit-learn通过K折交叉检验测试Keras的模型

第10章 项目:多类花朵分类相关推荐

  1. Python深度学习实战:多类花朵分类

    Python深度学习实战:多类花朵分类 鸢尾花分类数据集 导入库和函数 指定随机数种子 导入数据 输出变量编码 设计神经网络 用K折交叉检验测试模型 总结 本章我们使用Keras为多类分类开发并验证一 ...

  2. 《软件项目管理(第二版)》第 10 章——项目收尾 重点部分总结

    文章目录 前言 一.讨论 二.简答题 总结 前言 学习了项目的开发与发布之后,我们就可以单独对一个项目进行开发了,但是在企业中开发中,除了编码之外,还需要项目管理.团队协作开发等,这就是软件项目管理板 ...

  3. 第10章 对象和类 -1

    待定 本章内容:  过程性编程和面向对象编程  类概念  如何定义和实现类  公有类访问和私有类访问  类的数据成员  类方法(类函数成员)  创建和使用类对象  类的构造函数和析构函 ...

  4. 第10章 项目质量管理

    目录 10.1 项目质量管理概论 10.2 规划质量管理 10.3 实施质量保证 10.4 质量控制 补充 上午考3分(质量管理的3个过程的定义和工具技术),下午案例分析考的概率大 10.1 项目质量 ...

  5. PMP-【第10章 项目沟通管理】-2021-2-16(220页-231页)

    1.项目沟通的管理的本质 2.沟通遵循的5C原则 3.项目经理是组织专家做事的人,而不是自己亲自做事 4.沟通的种类 5.交互式沟通,推式沟通,拉式沟通 6.项目经理与职能经理/外界/管理层 如何打交 ...

  6. 在项目中谨慎为系统类添加分类!!!!!

    结论: 1.坚决杜绝为系统类做方法交换(见到[class_replaceMethod]格杀勿论!) 2.为系统类添加分类时候,属性和方法名必须加上[世上独一无二]的前缀,避免冲突和混淆. 之所以让我对 ...

  7. Java黑皮书课后题第10章:10.2(BMI类)将下面的新构造方法加入BMI类中

    Java黑皮书课后题第10章:10.2(BMI类)将下面的新构造方法加入BMI类中 题目 程序说明 题目槽点 代码:Test2_BMI.java 运行实例 题目 程序说明 Test2_BMI.java ...

  8. Java黑皮书课后题第10章:*10.1(Time类)设计一个名为Time的类。编写一个测试程序,创建两个Time对象(使用new Time()和new Time(555550000))

    Java黑皮书课后题第10章:*10.1设计一个名为Time的类.编写一个测试程序,创建两个Time对象 题目 程序 代码 Test1.java Test1_Time.java 运行结果 UML 题目 ...

  9. 【信息系统项目管理师】第10章 下篇-项目干系人管理 知识点详细整理

    个人资料,仅供学习使用 教程:信息系统项目管理师(第3版) 修改时间--2021年10月4日 09:19:27 参考资料: 信息系统项目管理师(第3版) 题目书(2021下半年)--马军 本文包括: ...

最新文章

  1. Android程序完全退出的三种方法
  2. 【Java小工匠聊密码学】-密码学--综述
  3. python使用matplotlib可视化、自定义设置Y轴刻度标签字体的大小( setting axis ticks size in matplotlib y axis)
  4. cbow 和skip-gram比较
  5. AlphaBlend 使用方法
  6. BZOJ 3489: A simple rmq problem(K-D Tree)
  7. 尝试连接到服务器时出错请检查虚拟机管理器,Hyper-V尝试连接到服务器出错无效类的解决方法...
  8. Android studio ERROR: Software caused connection abort: recv failed 解决方法
  9. SQL2005 学记笔记(9)
  10. 12-Java读写CSV格式文件(opencsv)
  11. Python的操作符?
  12. Windows安装Oracle与PlSql教程
  13. CS61A 学习笔记Week1
  14. web编程1–用户注册之文本框应用,coon连接,存入mysql
  15. 如何在Flatter中以正确的方式存储登录凭证
  16. 数据库问题——合并表格
  17. Android 4.4Phone的变化(二)
  18. windows11及以下系统怎么修改账户名
  19. echarts实现3D地图,轮播功能、背景图片、鼠标悬浮展示数据,附源码!
  20. 华为机试真题 Python 实现【模拟商场优惠打折】【2022.11 Q4 新题】

热门文章

  1. 中国象棋程序的设计与实现(九)–棋子点,棋子的小窝 中国象棋程序的设计与实现(八)-如何构造一个棋子(車馬炮等)...
  2. 斜方肌(01):负重耸肩
  3. 树莓派从零开始快速入门第5讲——点亮LED
  4. 奇怪问题--二极管并联电阻分压
  5. 这款神器让3000多个失踪孩子回家,原来科技可以如此有情有义!
  6. 大数据测试回放视频-小强测试内部学员技术分享
  7. 数电笔记总结(三)(逻辑门电路)
  8. 【Python 基础教程】Python语言的自我介绍
  9. 防爆定位信标与防爆定位基站有什么区别?
  10. Unity成亮:我们一直在和开发者共建一个开放共赢的平台