活动地址:CSDN21天学习挑战赛

目录

  • 前言
  • 了解Fashion-MNIST数据集
  • 下载数据集
    • 使用tensorflow下载(推荐)
    • 数据集分类
    • 数据集格式
  • 采用CPU训练还是GPU训练
    • 区别
    • 使用CPU训练
    • 使用GPU训练
  • 预处理
    • 最值归一化(normalization)
    • 升级图片维度
  • 显示部分图片
    • 建立CNN模型
      • 网络结构
      • 参数量
  • 训练模型
  • 模型评估

前言

关于环境这里不再赘述,与【深度学习】从LeNet-5识别手写数字入门深度学习一文的环境一致。

了解Fashion-MNIST数据集

Fashion-MNIST数据集与MNIST手写数字数据集不一样。但他们都有共同点就是都是灰度图片。
Fashion-MNIST数据集是各类的服装图片总共10类。下面列出了中英文对应表,方便接下来的学习。

中文 英文
t-shirt T恤
trouser 牛仔裤
pullover 套衫
dress 裙子
coat 外套
sandal 凉鞋
shirt 衬衫
sneaker 运动鞋
bag
ankle boot 短靴

下载数据集

使用tensorflow下载(推荐)

默认下载在C:\Users\用户\.keras\datasets路径下。

from tensorflow.keras import datasets# 下载数据集
(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()

数据集分类

这里对从网上下载的数据集进行一个说明。

文件名 数据说明
train-images-idx3-ubyte 训练数据图片集
train-labels-idx1-ubyte 训练数据标签集
t10k-images-idx3-ubyte 测试数据图片集
t10k-labels-idx1-ubyte 测试数据标签集

数据集格式

训练数据集共60k张图片,各个服装类型的数据量一致也就是说每种6k。
测试数据集共10k张图片,各个服装类型的数据量一致也就是说每种100。

数据集均采用28281的灰度照片。

采用CPU训练还是GPU训练

一般来说有好的显卡(GPU)就使用GPU训练因为,那么对应的你就要下载tensorflow-gpu包。如果你的显卡较差或者没有足够资金入手一款好的显卡就可以使用CUP训练。

区别

(1)CPU主要用于串行运算;而GPU则是大规模并行运算。由于深度学习中样本量巨大,参数量也很大,所以GPU的作用就是加速网络运算。

(2)CPU计算神经网络也是可以的,算出来的神经网络放到实际应用中效果也很好,只不过速度会很慢罢了。而目前GPU运算主要集中在矩阵乘法和卷积上,其他的逻辑运算速度并没有CPU快。

使用CPU训练

# 使用cpu训练
import osos.environ["CUDA_VISIBLE_DEVICES"] = "-1"

使用CPU训练时不会显示CPU型号。

使用GPU训练

gpus = tf.config.list_physical_devices("GPU")if gpus:gpu0 = gpus[0]  # 如果有多个GPU,仅使用第0个GPUtf.config.experimental.set_memory_growth(gpu0, True)  # 设置GPU显存用量按需使用tf.config.set_visible_devices([gpu0], "GPU")

使用GPU训练时会显示对应的GPU型号。

预处理

最值归一化(normalization)

关于归一化相关的介绍在前文中有相关介绍。 最值归一化与均值方差归一化

# 将像素的值标准化至0到1的区间内。train_images, test_images = train_images / 255.0, test_images / 255.0return train_images, test_images

升级图片维度

因为数据集是灰度照片,所以我们需要将[28,28]的数据格式转换为[28,28,1]

# 调整数据到我们需要的格式train_images = train_images.reshape((60000, 28, 28, 1))test_images = test_images.reshape((10000, 28, 28, 1))

显示部分图片

首先需要建立一个标签数组,然后绘制前20张,每行5个共四行
注意:如果你执行下面这段代码报这个错误:TypeError: Invalid shape (28, 28, 1) for image data。那么你就使用我下面注释掉的那句话。

from matplotlib import pyplot as pltclass_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat','Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']plt.figure(figsize=(20, 10))
for i in range(20):plt.subplot(4, 5, i + 1)plt.xticks([])plt.yticks([])plt.grid(False)plt.imshow(train_images[i], cmap=plt.cm.binary)#plt.imshow(train_images[i].squeeze(), cmap=plt.cm.binary)plt.xlabel(class_names[train_labels[i]])
plt.show()

绘制结果:

建立CNN模型

from tensorflow_core.python.keras import Input, Sequential
from tensorflow_core.python.keras.layers import Conv2D, Activation, MaxPooling2D, Flatten, Densedef simple_CNN(input_shape=(32, 32, 3), num_classes=10):# 构建一个空的网络模型,它是一个线性堆叠模型,各神经网络层会被顺序添加,专业名称为序贯模型或线性堆叠模型model = Sequential()# 卷积层1 model.add(Conv2D(filters=32, kernel_size=(3, 3), activation='relu', input_shape=input_shape))# 最大池化层1model.add(MaxPooling2D((2, 2), strides=(2, 2), padding='same'))# 卷积层2model.add(Conv2D(filters=64, kernel_size=(3, 3), padding='same', activation='relu'))# 最大池化层2model.add(MaxPooling2D((2, 2), strides=(2, 2), padding='same'))# 卷积层3model.add(Conv2D(filters=64, kernel_size=(3, 3), padding='same', activation='relu'))# flatten层常用来将输入“压平”,即把多维的输入一维化,常用在从卷积层到全连接层的过渡。model.add(Flatten())# 全连接层 对特征进行提取model.add(Dense(units=64, activation='relu'))# 输出层model.add(Dense(10))return model

网络结构

包含输入层的话总共9层。其中有三个卷积层,俩个最大池化层,一个flatten层,俩个全连接层。

参数量

总共参数为319k,训练时间比LeNet-5较长。建议采用GPU训练。

Total params: 257,162
Trainable params: 257,162
Non-trainable params: 0

训练模型

训练模型,进行10轮,将模型保存到1.h5文件中。后期可以直接加载模型继续训练。

from tensorflow_core.python.keras.models import load_model
from Cnn import simple_CNN
import tensorflow as tfmodel = simple_CNN(train_images, train_labels)
model.summary()  # 打印网络结构model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])
model.save("1.h5")
history = model.fit(train_images, train_labels, epochs=10, validation_data=(test_images, test_labels))

训练结果:测试集acc为91.64%。从效果来说该模型还是不错的。

模型评估

对训练完模型的数据制作成曲线表,方便之后对模型的优化,看是过拟合还是欠拟合还是需要扩充数据等等。

acc = history.history['accuracy']val_acc = history.history['val_accuracy']loss = history.history['loss']val_loss = history.history['val_loss']epochs_range = range(10)plt.figure(figsize=(12, 4))plt.subplot(1, 2, 1)plt.plot(epochs_range, acc, label='Training Accuracy')plt.plot(epochs_range, val_acc, label='Validation Accuracy')plt.legend(loc='lower right')plt.title('Training and Validation Accuracy')plt.subplot(1, 2, 2)plt.plot(epochs_range, loss, label='Training Loss')plt.plot(epochs_range, val_loss, label='Validation Loss')plt.legend(loc='upper right')plt.title('Training and Validation Loss')plt.show()

运行结果:

【深度学习】基于tensorflow的服装图像分类训练(数据集:Fashion-MNIST)相关推荐

  1. 《深度学习之TensorFlow》reading notes(3)—— MNIST手写数字识别之二

    文章目录 模型保存 模型读取 测试模型 搭建测试模型 使用模型 模型可视化 本文是在上一篇文章 <深度学习之TensorFlow>reading notes(2)-- MNIST手写数字识 ...

  2. [深度学习]-基于tensorflow的CNN和RNN-LSTM文本情感分析对比

    基于tensorflow的CNN和LSTM文本情感分析对比 1. 背景介绍 2. 数据集介绍 2.0 wordsList.npy 2.1 wordVectors.npy 2.2 idsMatrix.n ...

  3. python做神经网络有什么框架_神经网络与深度学习——基于TensorFlow框架和Python技术实现...

    目 录 第1章 绪论1 1.1 人工智能2 1.2 机器学习3 1.2.1 监督学习3 1.2.2 非监督学习3 1.2.3 半监督学习4 1.3 深度学习4 1.3.1 卷积神经网络4 1.3.2 ...

  4. 深度学习框架tensorflow二实战(训练一个简单二分类模型)

    导入工具包 import os import warnings warnings.filterwarnings("ignore") import tensorflow as tf ...

  5. 【深度学习】使用transformer进行图像分类

    文章目录 1.导入模型 2.定义加载函数 3.定义批量加载函数 4.加载数据 5.定义数据预处理及训练模型的一些超参数 6.定义数据增强模型 7.构建模型 7.1 构建多层感知器(MLP) 7.2 创 ...

  6. 在浏览器中进行深度学习:TensorFlow.js (十二)异常检测算法

    2019独角兽企业重金招聘Python工程师标准>>> 异常检测是机器学习领域常见的应用场景,例如金融领域里的信用卡欺诈,企业安全领域里的非法入侵,IT运维里预测设备的维护时间点等. ...

  7. 深度学习与TensorFlow

    深度学习与TensorFlow DNN(深度神经网络算法)现在是AI社区的流行词.最近,DNN 在许多数据科学竞赛/Kaggle 竞赛中获得了多次冠军. 自从 1962 年 Rosenblat 提出感 ...

  8. 深度学习必备书籍——《Python深度学习 基于Pytorch》

    作为一名机器学习|深度学习的博主,想和大家分享几本深度学习的书籍,让大家更快的入手深度学习,成为AI达人!今天给大家介绍的是:<Python深度学习 基于Pytorch> 文章目录 一.背 ...

  9. 4.1 深度学习框架-TensorFlow

    4.1 深度学习框架-TensorFlow 学习目标 目标 了解Tensorflow框架的组成.接口 了解TensorFlow框架的安装 知道tf.keras的特点和使用 应用 无 4.1.1 常见深 ...

最新文章

  1. C++ JsonCpp 使用(含源码下载)
  2. 浅谈ASP.NET中render方法
  3. 用Red5搭建支持WEB播放的实时监控视频
  4. Java正则表达式代码案例
  5. 全民超神VS王者荣耀:从角色养成到账户养成
  6. linux文件大小和目录,查看Linux目录和文件大小
  7. Java学习日报—注解、Hash、Lombok—2021/12/02
  8. Spring Boot(7)---构建系统和依赖管理
  9. db2数据库服务器时间怎么修改,DB2数据库中,肿么修改数据的创建时间,求SQL语句。...
  10. vscode 连接服务器jupyter_VScode中使用jupyter notebook
  11. ZooKeeper Web UI -- Shovel
  12. 测试人员的工作及介绍
  13. IDEA去掉SQL语句的黄色警告
  14. MySQL 2003报错解决方案
  15. php 磁盘配额,samba服务器安装+磁盘配额笔记
  16. volatile原理:happen before
  17. keil5中输入中文并且美化字体
  18. 基因测序的云计算平台可能带来的变革与进步
  19. 京东“鲸置”,“鲸吞”闲鱼?
  20. 网络设计与网络设备配置,网络设计需要哪些设备

热门文章

  1. svg中text标签字体、颜色、样式、大小、居中、旋转、垂直、text长度、tspan、textPath详解
  2. android 面试知识点
  3. 市场调研-全球与中国视频信号指示单元市场现状及未来发展趋势
  4. 16口工业级HDMI KVM切换器(MT-2116HL)
  5. UE 材质一 : 材质通道
  6. OpenLayers实例-Advanced Mapbox Vector Tiles-高级Mapbox矢量贴图
  7. 几个手机兼职做任务发布悬赏的app对比
  8. Android手机使用风灵网络优化软件设置教程
  9. 网易android开发工程师笔试心得
  10. ajax异步实现表单的无刷新验证