文章目录

  • 获取数据
  • 图片预处理
  • 构建神经网络
  • 输入神经网络
  • 模型参数的优化

该项目旨在教会机器识别马和人的图像


获取数据

训练数据集:horse-or-human
测试数据集:validation-horse-or-human

图片预处理

当图片数据过大,且图片的尺寸不一致时,就需要对图片进行预处理操作,将其裁剪成规定大小的图片,然后再生成器中要指定每个批次中要训练的图片的数量。代码如下:

from tensorflow.keras.preprocessing.image import ImageDataGenerator#创建两个数据生成器,指定scaling范围0-1
train_datagen = ImageDataGenerator(rescale=1/255)
validation_datagen = ImageDataGenerator(rescale=1/255)#指定训练数据文件夹
train_generator = train_datagen.flow_from_directory('F:\\ML_Data\\horse-or-human\\train', #训练数据所在文件夹target_size=(300,300),       #指定输出尺寸batch_size=32,               #单次传递给程序用以训练的参数个数class_mode='binary')         #指定二分类#指定测试数据文件夹
validation_generator = validation_datagen.flow_from_directory('F:\\ML_Data\\horse-or-human\\validation', #训练数据所在文件夹target_size=(300,300),       #指定输出尺寸batch_size=32,               #单次传递给程序用以训练的参数个数class_mode='binary')         #指定二分类

构建神经网络

import tensorflow as tf
from tensorflow.keras.optimizers import RMSprop
model = tf.keras.models.Sequential([#第一个卷积tf.keras.layers.Conv2D(16,(3,3),activation='relu',input_shape=(300,300,3)), #300x300和三字节颜色tf.keras.layers.MaxPooling2D(2,2),#第二个卷积tf.keras.layers.Conv2D(32,(3,3),activation='relu'), tf.keras.layers.MaxPooling2D(2,2),#第三个卷积tf.keras.layers.Conv2D(64,(3,3),activation='relu'), tf.keras.layers.MaxPooling2D(2,2),#展开输入神经网络tf.keras.layers.Flatten(),tf.keras.layers.Dense(512,activation='relu'), #512个神经元tf.keras.layers.Dense(1,activation='sigmoid') #1个神经元输出,0代表马,1代表人
])
model.compile(loss='binary_crossentropy',optimizer=RMSprop(lr=0.001),metrics=['acc'])

输入神经网络

history = model.fit(train_generator, #输入数据epochs = 15,     #训练轮数verbose = 1,     #日志显示,0为不在标准输出流输出日志信息,1为输出进度条记录,2为每个epoch输出一行记录validation_data = validation_generator, #指定的验证集validation_steps = 8 #表示将一个epoch的训练集数据分为多少个batch
)

运行结果:

模型参数的优化

这里相信大家跟我都有一个问题,我们怎么知道设定多少个过滤器、卷积层重复循环多少遍、定义多少个神经元?如何确定这些参数才能达到最好的效果。

这里我们可以调用Hyperband库和HyperParameters库来让机器自己判断如何定义这些参数合适。其实就是在一个区间内循环跑好几遍,选结果最优的参数然后将其保存下来。

# 人马识别项目参数优化
from tensorflow.keras.preprocessing.image import ImageDataGenerator#创建两个数据生成器,指定scaling范围0-1
train_datagen = ImageDataGenerator(rescale=1/255)
validation_datagen = ImageDataGenerator(rescale=1/255)#指定训练数据文件夹
train_generator = train_datagen.flow_from_directory('F:\\ML_Data\\horse-or-human\\train', #训练数据所在文件夹target_size=(150,150),       #指定输出尺寸batch_size=32,               #单次传递给程序用以训练的参数个数class_mode='binary')         #指定二分类#指定测试数据文件夹
validation_generator = validation_datagen.flow_from_directory('F:\\ML_Data\\horse-or-human\\validation', #训练数据所在文件夹target_size=(150,150),       #指定输出尺寸batch_size=32,               #单次传递给程序用以训练的参数个数class_mode='binary')         #指定二分类import tensorflow as tf
from tensorflow.keras.optimizers import RMSprop
#调用调参库
from kerastuner.tuners import Hyperband
from kerastuner.engine.hyperparameters import HyperParametershp = HyperParameters()
def build_model(hp):model = tf.keras.models.Sequential()# 寻找最优过滤器个数model.add(tf.keras.layers.Conv2D(hp.Choice('num_filters_layer0',values=[16,64],default=16),(3,3),activation='relu',input_shape=(150,150,3)))model.add(tf.keras.layers.MaxPooling2D(2,2))# 寻找最优卷积层处理遍数for i in range (hp.Int("num_conv_layers",1,3)):# 寻找每一层最优过滤器个数model.add(tf.keras.layers.Conv2D(hp.Choice(f'num_filters_layer{i+1}',values=[16,64],default=16),(3,3),activation='relu'))model.add(tf.keras.layers.MaxPooling2D(2,2))model.add(tf.keras.layers.Flatten())model.add(tf.keras.layers.Dense(hp.Int("hidden_units",128,512,step=32),activation='relu')) # 在[128,512]区间中每次递增32,寻找最优神经元个数model.add(tf.keras.layers.Dense(1,activation='sigmoid'))model.compile(loss='binary_crossentropy',optimizer=RMSprop(lr=0.001),metrics=['acc'])return modeltuner = Hyperband(build_model,objective='val_acc',max_epochs=15,directory='horse_human_params',hyperparameters=hp,project_name='my_horse_human_project'
)
tuner.search(train_generator,epochs=10,validation_data=validation_generator)
#model.fit(train_generator,epochs=10,validation_data=validation_generator)
best_hps = tuner.get_best_hyperparameters(1)[0]
print(best_hps.values) #查看最优参数

运行结果:

由此可见,第一个卷积层设置16个过滤器,之后再重复3次。每次过滤器个数分别为16,64,64.神经元个数为384。但缺点是运行时间太长了,这里花了59m19s才跑出循环。

呜呜呜,运行上面结果时f’num_filters_layer{i+1}'忘了+1。这导致两个layer0重合了,所以在运行结果里大家只看见三层layer,最开始那层被覆盖了。跑一次太久了运行结果图就不做更正了。

我们运用算出的最优参数构建模型看看。

model = tuner.hypermodel.build(best_hps) # 按照算出的最优参数构建模型
model.summary() # 查看模型详情

运行结果:

机器学习 —— 人马图像分类相关推荐

  1. [Python人工智能] 十.Tensorflow+Opencv实现CNN自定义图像分类案例及与机器学习KNN图像分类算法对比

    从本专栏开始,作者正式开始研究Python深度学习.神经网络及人工智能相关知识.前一篇详细讲解了gensim词向量Word2Vec安装.基础用法,并实现<庆余年>中文短文本相似度计算及多个 ...

  2. 机器学习实现图像分类(简单易上手) SVM KNN 决策树 朴素贝叶斯 机器学习作业

    机器学习实现图像分类 SVM KNN 决策树 朴素贝叶斯 重要提示:本文仅仅靠调用python的sklearn中的模型包实现机器学习方法,不喜勿喷 代码主要参考并改进 https://blog.csd ...

  3. TensorFlow(9)(项目)人马图像分类(卷积神经网络)

    目录 基础理论 1.sigmoid激活函数 2.聚类&&分类 一.准备数据 1.创建两个数据生成器 2.创建训练数据与测试数据生成器 训练数据生成器 测试数据生成器 二.构建神经网络 ...

  4. 机器学习花朵图像分类_在PyTorch中使用转移学习进行图像分类

    想了解更多好玩的人工智能应用,请关注公众号"机器AI学习 数据AI挖掘","智能应用"菜单中包括:颜值检测.植物花卉识别.文字识别.人脸美妆等有趣的智能应用.. ...

  5. 【机器学习】Tensorflow.js:在浏览器中使用机器学习实现图像分类

    ⭐️ 本文首发自 前端修罗场(点击加入),是一个由资深开发者独立运行的专业技术社区,我专注 Web 技术.答疑解惑.面试辅导以及职业发展.现在加入,私聊我即可获取一次免费的模拟面试机会,帮你评估知识点 ...

  6. 最新综述:图像分类中的对抗机器学习

    ©PaperWeekly 原创 · 作者|孙裕道 学校|北京邮电大学博士生 研究方向|GAN图像生成.人脸对抗样本生成 论文标题: Adversarial Machine Learning in Im ...

  7. 赠书 | 图像分类问题建模方案探索实践

    作者 | 中国农业银行 陆春晖 责编 | 晋兆雨 出品 | AI科技大本营 头图 | 付费下载于视觉中国 *文末有赠书福利 背景 图像分类,是计算机视觉领域的一个核心问题,顾名思义就是输入一张图像,根 ...

  8. 图像分类:13个Kaggle项目的经验总结

    ↑↑↑关注后"星标"Datawhale 每日干货 & 每月组队学习,不错过 Datawhale干货 方向:图像分类,编辑:数据派THU 本文约2800字,建议阅读9分钟 本 ...

  9. 图像分类:来自13个Kaggle项目的经验总结

    来源:机器学习实验室 本文约2800字,建议阅读9分钟 本文作者与你分享图像分类项目经验总结. 任何领域的成功都可以归结为一套小规则和基本原则,当它们结合在一起时会产生伟大的结果. 机器学习和图像分类 ...

最新文章

  1. Linux网络编程基础(一)
  2. js之数据类型及类型转换
  3. python下载安装搭建
  4. python爬取网站数据步骤_python怎么爬取数据
  5. SPA 单页Web应用
  6. android反射开启通知_作为Android开发者 你真的知道app从启动到主页显示的过程吗?...
  7. 双亲委派机制_史上三次破坏ClassLoader双亲委派机制
  8. javaSE----for,wile ,do while循环的应用
  9. linux学习作业-第八周
  10. PHP打印Excel表格并下载
  11. 联邦学习后门攻击代码阅读——backdoors101
  12. android支付后声音,支付宝到账声音生成器
  13. 等差乘等比数列求和公式
  14. 奇怪的吃播_快来围观那些奇怪的吃播!!
  15. IP网络摄像机安装注意事项
  16. 服务器修改host的ip,主机IP地址设置
  17. 网络安全篇 浅谈学习网络安全的看法-00
  18. html中只显示农历的完整代码,很全的显示阴历(农历)日期的js代码
  19. 【Struts2】一_idea快速搭建struts2框架
  20. NetLogo 初步认识

热门文章

  1. 浅谈新手入行前端自学到什么程度才能找工作?
  2. OA系统----考勤管理----JDBC,Ajax
  3. Pytorch复现STGCN:基于图卷积时空神经网络在交通速度中的预测
  4. **java 发送邮件**
  5. 物联网毕设 人体定位智能调速风扇系统
  6. 蘑菇街、滴滴、淘宝、微信的组件化架构解析,附Demo和PDF
  7. 学习笔记-Flutter 动画详解(一)
  8. 聊一聊工业和自动化之间的5种接近传感器
  9. 温故而知新--冒个泡、排个序
  10. 小编必看,教你如何使用微信公众号编辑器快速排版精美文章