一、下载数据集并展示

CIFAR-10 是由 Hinton 的学生 Alex Krizhevsky 和 Ilya Sutskever 整理的一个用于识别普适物体的小型数据集。一共包含 10 个类别的 RGB 彩色图 片:飞机( a叩lane )、汽车( automobile )、鸟类( bird )、猫( cat )、鹿( deer )、狗( dog )、蛙类( frog )、马( horse )、船( ship )和卡车( truck )。图片的尺寸为 32×32 ,数据集中一共有 50000 张训练圄片和 10000 张测试图片。 CIFAR-10 的图片样例如图所示。

与 MNIST 数据集中目比, CIFAR-10 具有以下不同点:

• CIFAR-10 是 3 通道的彩色 RGB 图像,而 MNIST 是灰度图像。

• CIFAR-10 的图片尺寸为 32×32, 而 MNIST 的图片尺寸为 28×28,比 MNIST 稍大。

• 相比于手写字符, CIFAR-10 含有的是现实世界中真实的物体,不仅噪声很大,而且物体的比例、 特征都不尽相同,这为识别带来很大困难。

import matplotlib.pyplot as plt
import tensorflow as tf
from keras import datasets, layers, models
from tensorflow.keras.preprocessing.image import ImageDataGenerator
%config Completer.use_jedi = False(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.cifar10.load_data()class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer','dog', 'frog', 'horse', 'ship', 'truck']plt.figure(figsize=(10, 10))
for i in range(10):plt.subplot(5, 5, i + 1)plt.xticks([])plt.yticks([])plt.grid(False)plt.imshow(train_images[i], cmap=plt.cm.binary)plt.xlabel(class_names[train_labels[i][0]])
plt.show()#查看图片信息
print('图片尺寸为:',train_images[0].shape)
print('训练集图片个数为:',len(train_images))
print('测试集图片个数为:',len(test_images))

如果使用代码下载失败,那么去到cifar10数据集下载地址:https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz,将下载后的文件存放在 ~./keras/datasets目录下,~表示当前用户路径。

二、构建模型

# 构造网络模型
model = models.Sequential([tf.keras.layers.Conv2D(32, (3,3), activation='relu', input_shape=(32, 32, 3)),tf.keras.layers.MaxPooling2D(2,2),tf.keras.layers.Conv2D(64, (3,3), activation='relu'),tf.keras.layers.MaxPooling2D(2,2),tf.keras.layers.Conv2D(128, (3,3), activation='relu'),tf.keras.layers.MaxPooling2D(2,2),#转换为一维tf.keras.layers.Flatten(),tf.keras.layers.Dense(128, activation='relu'),tf.keras.layers.Dropout(0.5),tf.keras.layers.Dense(10, activation='softmax'),
])# 查看网络结构
model.summary()

三、定义损失函数优化器

model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])

四、数据增强

注意此处只对训练数据集做随机翻转、随机裁剪、平移等,测试集只需归一化。

train_image = ImageDataGenerator(rescale=1/255,#随机翻转rotation_range=40,#平移width_shift_range=0.2,height_shift_range=0.2,#随机裁剪shear_range=0.2,#随机缩放zoom_range=0.2,horizontal_flip=True,fill_mode='nearest'
)test_image = ImageDataGenerator(rescale=1/255,
)

五、模型训练

history = model.fit(train_images, train_labels, epochs=20,validation_data=(test_images, test_labels))

六、绘制acc

# 测试模型并绘制loss图(history的使用)
plt.plot(history.history['accuracy'], label='accuracy')
plt.plot(history.history['val_accuracy'], label='val_accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.ylim([0.0, 1.0])
plt.legend(loc='lower right')
plt.show()

注:此处只做流程演示并未调整参数,可以自行优化。

Tensorflow2+训练CIFAR10相关推荐

  1. [深度学习-实践]Tensorflow 2.x应用ResNet SeNet网络训练cifar10数据集的模型在测试集上准确率 86%-87%,含完整代码

    环境 tensorflow 2.1 最好用GPU Cifar10数据集 CIFAR-10 数据集的分类是机器学习中一个公开的基准测试问题.任务的目标对一组32x32 RGB的图像进行分类,这个数据集涵 ...

  2. 深度学习训练的时候gpu占用0_26秒单GPU训练CIFAR10,Jeff Dean也点赞的深度学习优化技巧...

    选自myrtle.ai 机器之心编译机器之心编辑部 26 秒内用 ResNet 训练 CIFAR10?一块 GPU 也能这么干.近日,myrtle.ai 科学家 David Page 提出了一大堆针对 ...

  3. 【深度学习】训练CIFAR-10数据集实现分类加测试

    网上有很多博主写的训练CIFAR-10的代码,本次只是单纯记录一下自己调试的一个程序,对于初学深度学习的小白可以参考,如有不对,请多多见谅!!! 一.CIFAR-10数据集由10个类的60000个32 ...

  4. 图解半监督学习FixMatch,只用10张标注图片训练CIFAR10

    2020-05-25 11:20:08 作者:amitness 编译:ronghuaiyang 导读 仅使用10张带有标签的图像,它在CIFAR-10上的中位精度为78%,最大精度为84%,来看看是怎 ...

  5. tensorflow2 训练和预测使用不同的输出层、获取权重参数

    目标: youtubeNet通过训练tensorflow2时设置不同的激活函数,训练和预测采用不同的分支,然后可以在训练和测试时,把模型进行分离,得到训练和预测时,某些层的参数不同.可以通过类似迁移学 ...

  6. 使用caffe自带模型训练cifar10数据集

      前面训练了mnist数据集!但caffe自带的数据集还有cifar10数据集.同样cifar10数据集也是分类数据集,共分10类.cifar10数据集中包含60000张32x32的彩色图片.(其中 ...

  7. matlab训练cifar10,认识CIFAR-10数据集

    CIFAR-10是一个广泛使用的标准数据集,里面包含了各种阿猫阿狗阿汽车--为了在后续学习实验中用好它,首先需要认识了解一下. 把tensorflow官方model下的cifar10文件复制到工作区, ...

  8. 深度学习:使用pytorch训练cifar10数据集(基于Lenet网络)

    文档基于b站视频:https://www.bilibili.com/video/BV187411T7Ye 流程 model.py --定义LeNet网络模型 train.py --加载数据集并训练,训 ...

  9. 现代卷积神经网络(NiN),并使用NIN训练CIFAR10的分类

    专栏:神经网络复现目录 本章介绍的是现代神经网络的结构和复现,包括深度卷积神经网络(AlexNet),VGG,NiN,GoogleNet,残差网络(ResNet),稠密连接网络(DenseNet). ...

最新文章

  1. WPF中获取鼠标相对于桌面位置
  2. heroes 2 android,英雄出击2游戏下载-英雄出击2Heroes Strike2中文安卓版下载v0.0.5- 游侠下载站...
  3. c语言 switch语句大小,C语言switch语句(板式整齐)
  4. mysql audit log.so_Percona Audit Log Plugin(mysql 审计)
  5. php怎么克隆,利用php怎么对对象进行克隆
  6. vue 响应式ui_如何在Vue.js中设置响应式UI搜索
  7. 大数据相关端口号(hive hdfs spark)
  8. Codeforces 490F Treeland Tour(离散化 + 线段树合并)
  9. PHP+crontab 完美实现定时任务
  10. 修改XAMPP端口(2)
  11. siteservercms 缺点_SiteServer CMS 术语大全
  12. DTcms Core项目发布到IIS教程
  13. [Java web编程]第2章 HTML与css网页开发基础(动画)
  14. 【点云处理技术之PCL】点云配准算法之NDT
  15. 使用Hooks实现防抖节流 TS版本
  16. bin文件用cad打开_bin文件怎么用cad打开
  17. Ajax-服务器响应数据详解
  18. 微信小程序--云开发数据库操作之where()
  19. 罗斯柴尔德家族:“大道无形”世界首富
  20. oracle 如何备份.bak,Oracle备份如何到异机还原

热门文章

  1. 关于ContactsContract
  2. 阿里面试题之洗衣机问题
  3. iPad越狱是什么?iPad越狱有什么好处和坏处
  4. linux can测试程序,Linux CAN Shell 测试脚本程序
  5. 团队任务管理该怎么做才高效?管理者应该学会这些
  6. Arduino与Proteus仿真实例-AT24C256串行(I2C)EEPROM数据存取驱动仿真
  7. 同余方程、欧拉定理、乘法逆元、定义在Zm上的矩阵求逆
  8. 云计算演义(7)中国云计算离世界有多远?
  9. 【Visual C++】游戏开发笔记二十七 Direct3D 11入门级知识介绍
  10. 信息系统项目管理师论文范例5:成本管理