实验目的

通过 Tensorflow 的基础类,构建卷积神经网络,用于花朵图片的分类。

实验环境

import tensorflow as tfprint(tf.__version__)

output:
2.3.0

实验步骤

(一) 数据获取和预处理

1.1 数据选择 TensorFlow 官方提供的花朵图片数据,经如下代码获取:

dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
img_dir= tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)

1.2 读取图片:这里,我们通过 tf.keras.preprocessing.image_dataset_from_directory 函数批量读入图片。

import pathlib
# 数据保存路径
data_dir = pathlib.Path(data_dir)BATCH_SIZE = 32  # BATCH size 设为32
img_height = 180  # 读取图片后,高度转换为180像素
img_width = 180  # 读取图片后,宽度转换为180像素# 读入images (training data)
train_ds = tf.keras.preprocessing.image_dataset_from_directory(img_dir,shuffle=True, validation_split=0.2, seed=123, subset='training', batch_size=BATCH_SIZE,image_size=(img_height, img_width))# 读入images(test data)
val_ds = tf.keras.preprocessing.image_dataset_from_directory(img_dir, shuffle=True,validation_split=0.2,seed=123,subset='validation',image_size=(img_height, img_width),batch_size=BATCH_SIZE)

1.3 查看训练数据的前9张图片.

plt.figure(figsize=(6, 6))
for imgs, labels in train_ds.take(1):for i in range(9):ax = plt.subplot(3,3,i+1)plt.imshow(imgs[i].numpy().astype("uint8"))plt.title(class_names[labels[i]])plt.axis("off")

(二) 通过 tf 的基础类,自定义模型

class Mymodel(tf.keras.Model):def __init__(self):super().__init__()# 定义normalization 层self.normalization_layer = tf.keras.layers.experimental.preprocessing.Rescaling(1.0 / 255)# 定义数据增强层self.aug1 = tf.keras.layers.experimental.preprocessing.RandomFlip('horizontal')self.aug2 = tf.keras.layers.experimental.preprocessing.RandomRotation(0.1)self.aug3 = tf.keras.layers.experimental.preprocessing.RandomZoom(0.1)# 定义cov1self.cov1 = tf.keras.layers.Conv2D(16, (3,3), padding='same', activation='relu', name='cov1')self.pool1 = tf.keras.layers.MaxPool2D(name='pool1')# 定义cov2self.cov2 = tf.keras.layers.Conv2D(32, (3,3), padding='same', activation='relu', name='cov2')self.pool2 = tf.keras.layers.MaxPool2D(name='pool2')# 定义cov3self.cov3 = tf.keras.layers.Conv2D(64, (3,3), padding='same', activation='relu', name='cov3')self.pool3 = tf.keras.layers.MaxPool2D(name='pool3')# 定义 Dropoutself.dropout = tf.keras.layers.Dropout(0.2)# 定义 flattenself.flatten = tf.keras.layers.Flatten()# 定义 Denseself.dense1 = tf.keras.layers.Dense(128, activation='relu', name='dense1')self.dense2 = tf.keras.layers.Dense(5)def call(self, img):# 执行normalizationX = self.normalization_layer(img)# 执行augX = self.aug1(X)X = self.aug2(X)X = self.aug3(X)X = self.cov1(X)X = self.pool1(X)X = self.cov2(X)X = self.pool2(X)X = self.cov3(X)X = self.pool3(X)X = self.flatten(X)X =  self.dense1(X)X = self.dense2(X)return X

(三) 定义损失函数

def loss(y_true, y_predict):return tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)(y_true, y_predict)

(四) 定义优化函数

optimizer = tf.keras.optimizers.Adam()

(五) 定义训练函数

def train_step(batch_inp, batch_targ, model):with tf.GradientTape() as tape:dense_ = model(batch_inp)batch_loss = loss(batch_targ, dense_)gradients = tape.gradient(batch_loss, model.trainable_variables)optimizer.apply_gradients(zip(gradients, model.trainable_variables))return batch_loss

(六) 训练模型

# 实例化模型
model = Mymodel()epochs = 50  # 训练50个epoch
els = [] # 存储每个epoch的损失函数,用于后续绘图
for epoch in range(epochs):epoch_loss = 0# 由于我的计算机显存太小,这里每个epoch只取前20个batch进行训练for batch, (inp, targ) in enumerate(train_ds.take(20)):batch_loss = train_step(inp, targ, model)epoch_loss += batch_loss.numpy()print('epoch {}: {:.4f}'.format(epoch, epoch_loss/10))els.append(epoch_loss/10)

训练过程如下:

epoch 0: 0.5867
epoch 1: 0.6709
epoch 2: 0.6393
epoch 3: 0.6831
epoch 4: 0.6870
epoch 5: 0.6461
epoch 6: 0.4888

Loss 随训练过程的变化情况:

(七) 通过模型进行预测

预测的代码来之 TensorFlow 官方社区。

sunflower_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/592px-Red_sunflower.jpg"
sunflower_path = tf.keras.utils.get_file('Red_sunflower', origin=sunflower_url)img = keras.preprocessing.image.load_img(sunflower_path, target_size=(img_height, img_width)
)
img_array = keras.preprocessing.image.img_to_array(img)
img_array = tf.expand_dims(img_array, 0) # Create a batchpredictions = model.predict(img_array)
score = tf.nn.softmax(predictions[0])print("This image most likely belongs to {} with a {:.2f} percent confidence.".format(class_names[np.argmax(score)], 100 * np.max(score))
)

图片为:

预测结果:

This image most likely belongs to sunflowers with a 97.69 percent confidence.

通过 Tensorflow 的基础类,构建卷积神经网络,用于花朵图片的分类相关推荐

  1. 吴裕雄--天生自然 Tensorflow卷积神经网络:花朵图片识别

    import os import numpy as np import matplotlib.pyplot as plt from PIL import Image, ImageChops from ...

  2. tensorflow 图像教程 の TF Layers 教程:构建卷积神经网络

    文章目录 TF Layers 教程:构建卷积神经网络 卷积神经网络的简介 构建基于卷积神经网络的 MNIST 分类器 输入层 第一个卷积层 第一个池化层 第二个卷积层和池化层 全连接层 Logits ...

  3. keras构建卷积神经网络_在python中使用tensorflow s keras api构建卷积神经网络的初学者指南...

    keras构建卷积神经网络 初学者的深度学习 (DEEP LEARNING FOR BEGINNERS) Welcome to Part 2 of the Neural Network series! ...

  4. TF之CNN:Tensorflow构建卷积神经网络CNN的简介、使用方法、应用之详细攻略

    TF之CNN:Tensorflow构建卷积神经网络CNN的简介.使用方法.应用之详细攻略 目录 TensorFlow 中的卷积有关函数入门 1.tf.nn.conv2d函数 案例应用 1.TF之CNN ...

  5. keras构建卷积神经网络_在Keras中构建,加载和保存卷积神经网络

    keras构建卷积神经网络 This article is aimed at people who want to learn or review how to build a basic Convo ...

  6. TensorFlow(7)卷积神经网络实战(1)(可视化)

    目录 基础理论 卷积 卷积核与过滤器的区别 一.获取数据集 二.设定数据集大小.归一化 三.构建卷积神经网络 四.编译&&训练 五.模型评估 六.可视化 1.创建plt图 2.获取各卷 ...

  7. PyTorch基础与简单应用:构建卷积神经网络实现MNIST手写数字分类

    文章目录 (一) 问题描述 (二) 设计简要描述 (三) 程序清单 (四) 结果分析 (五) 调试报告 (六) 实验小结 (七) 参考资料 (一) 问题描述 构建卷积神经网络实现MNIST手写数字分类 ...

  8. keras构建卷积神经网络_通过此简单教程学习在网络上构建卷积神经网络

    keras构建卷积神经网络 by John David Chibuk 约翰·大卫·奇布克(John David Chibuk) 通过此简单教程学习在网络上构建卷积神经网络 (Learn to buil ...

  9. tensorflow预定义经典卷积神经网络和数据集tf.keras.applications

    自己开发了一个股票软件,功能很强大,需要的点击下面的链接获取: https://www.cnblogs.com/bclshuai/p/11380657.html 1.1  tensorflow预定义经 ...

最新文章

  1. 谷歌 notification 测试 页面
  2. 在客户端调用MOSS的搜索服务,实现更加灵活的搜索控制
  3. 最小最大定理 java_Java基础50道经典练习题(35)——最大最小交换
  4. git rebase(变基)—— Git 学习笔记 19
  5. 演讲者模式投影到幕布也看到备注_家用投影幕布怎么选?(看这一篇就明白了)...
  6. java 池化_溯本求源: JAVA线程池工作原理
  7. 我的世界java版gamemode指令_【服务器相关】【求助!】关于服务器中使用gamemode等命令错误。...
  8. 火星地形地貌图,摄影:“祝融号”火星车
  9. java线程死锁_Java线程死锁实例及解决方法
  10. 使用SQLyog创建MySQL数据库
  11. Aspose.Word 操作word表格的行 插入行 添加行
  12. Linux网络技术学习(二)—— net_device数据结构解析
  13. 生活记录:压抑暂时解脱
  14. 容器高度或者宽度的获取方式
  15. 中山大学数学科学与计算机科学,中山大学数学与计算科学学院导师介绍:邹青松...
  16. OpenCV中图像特征提取与描述
  17. RHCSA 2022/10/14
  18. 使用BackTrack来增强电脑的安全
  19. 通过NFS服务器将设备目录挂载到Windows目录
  20. 中国涡轮盘拉床市场现状研究分析与发展前景预测报告(2022)

热门文章

  1. 小米3联通电信通刷_2013062 2013063_官方线刷包_救砖包_解账户锁
  2. 25、搞懂闭包、作用域、执行期上下文(VO、AO)、作用域链
  3. 十大疯狂营销的公司:苹果居首
  4. Floyd算法求解最短距离
  5. 视频直播点播EasyDSS互联网视频云平台虚拟直播Avfilter流阻塞情况的优化
  6. 事务(一)——什么是事务,为什么会有事务,事务是做什么的?
  7. Cifar10完整模型搭建
  8. 从别人库里拷贝的游戏如何再自己的库里显示
  9. 假设检验 python_假坏(喻言时)最新章节-假坏小说全文免费阅读-看书迷
  10. 华创期货:日内交易简单方法有效规避亏损