import matplotlib.pyplot as plt
import numpy as np
import os
import tensorflow as tf
from tensorflow.keras.preprocessing import image_dataset_from_directory#加载数据
PATH = 'D:\data\hymenoptera_data'
train_dir = os.path.join(PATH, 'train')
validation_dir = os.path.join(PATH, 'val')BATCH_SIZE = 32
IMG_SIZE = (160, 160)train_dataset = image_dataset_from_directory(train_dir,shuffle=True,batch_size=BATCH_SIZE,image_size=IMG_SIZE)validation_dataset = image_dataset_from_directory(validation_dir,shuffle=True,batch_size=BATCH_SIZE,image_size=IMG_SIZE)#显示训练集中的前九个图像和标签:
class_names = train_dataset.class_names
%config InlineBackend.figure_format = 'retina'
plt.figure(figsize=(10, 10))
for images, labels in train_dataset.take(1):for i in range(9):ax = plt.subplot(3, 3, i + 1)plt.imshow(images[i].numpy().astype("uint8"))plt.title(class_names[labels[i]])plt.axis("off")#使用缓冲预提取从磁盘加载图像,以免造成 I/O 阻塞
AUTOTUNE = tf.data.AUTOTUNEtrain_dataset = train_dataset.prefetch(buffer_size=AUTOTUNE)
validation_dataset = validation_dataset.prefetch(buffer_size=AUTOTUNE)"""
当您没有较大的图像数据集时,最好将随机但现实的转换应用于训练图像(例如旋转或水平翻转)
来人为引入样本多样性。这有助于使模型暴露于训练数据的不同方面并减少过拟合
"""
data_augmentation = tf.keras.Sequential([tf.keras.layers.experimental.preprocessing.RandomFlip('horizontal'),tf.keras.layers.experimental.preprocessing.RandomRotation(0.2),
])"""
注:当您调用 model.fit 时,这些层仅在训练过程中才会处于有效状态。
在 model.evaulate 或 model.fit 中的推断模式下使用模型时,它们处于停用状态。我们将这些层重复应用于同一个图像,然后查看结果。
"""
for image, _ in train_dataset.take(1):plt.figure(figsize=(10, 10))first_image = image[0]for i in range(9):ax = plt.subplot(3, 3, i + 1)augmented_image = data_augmentation(tf.expand_dims(first_image, 0))plt.imshow(augmented_image[0] / 255)plt.axis('off')preprocess_input = tf.keras.applications.mobilenet_v2.preprocess_input
prediction_layer = tf.keras.layers.Dense(2)# Create the base model from the pre-trained model MobileNet V2
base_model = tf.keras.applications.MobileNetV2(input_shape=(160, 160, 3),include_top=False,weights='imagenet')
base_model.trainable = False"""
通过使用 Keras 函数式 API 将数据扩充、重新缩放、
base_model 和特征提取程序层链接在一起来构建模型。
如前面所述,由于我们的模型包含 BatchNormalization 层,
因此请使用 training = False。
"""
inputs = tf.keras.Input(shape=(160, 160, 3))
x = data_augmentation(inputs)
x = preprocess_input(x)
x = base_model(x, training=False)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
x = tf.keras.layers.Dropout(0.2)(x)
outputs = prediction_layer(x)
model = tf.keras.Model(inputs, outputs)base_learning_rate = 0.0001
model.compile(optimizer=tf.keras.optimizers.Adam(),loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])model.summary()#经过 10 个周期的训练后,您应该在验证集上看到约 94% 的准确率。
initial_epochs = 10
history = model.fit(train_dataset,epochs=initial_epochs,validation_data=validation_dataset)#我们看一下使用 MobileNet V2 基础模型作为固定特征提取程序时训练和验证准确率/损失的学习曲线。
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']loss = history.history['loss']
val_loss = history.history['val_loss']plt.figure(figsize=(8, 8))
plt.subplot(2, 1, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.ylabel('Accuracy')
plt.ylim([min(plt.ylim()),1])
plt.title('Training and Validation Accuracy')plt.subplot(2, 1, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.ylabel('Cross Entropy')
plt.ylim([0,1.0])
plt.title('Training and Validation Loss')
plt.xlabel('epoch')
plt.show()"""
您需要做的是解冻 base_model 并将底层设置为不可训练。
随后,您应该重新编译模型(使这些更改生效的必需操作),然后恢复训练。
"""
base_model.trainable = True
model.compile(loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),optimizer = tf.keras.optimizers.Adam(lr=base_learning_rate/10),metrics=['accuracy'])"""
当您正在训练一个大得多的模型并且想要重新调整预训练权重时,
请务必在此阶段使用较低的学习率。否则,您的模型可能会很快过拟合。
"""
fine_tune_epochs = 10
total_epochs =  initial_epochs + fine_tune_epochshistory_fine = model.fit(train_dataset,epochs=total_epochs,initial_epoch=history.epoch[-1],validation_data=validation_dataset)acc += history_fine.history['accuracy']
val_acc += history_fine.history['val_accuracy']loss += history_fine.history['loss']
val_loss += history_fine.history['val_loss']plt.figure(figsize=(8, 8))
plt.subplot(2, 1, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.ylim([0.8, 1])
plt.plot([initial_epochs-1,initial_epochs-1],plt.ylim(), label='Start Fine Tuning')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')plt.subplot(2, 1, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.ylim([0, 1.0])
plt.plot([initial_epochs-1,initial_epochs-1],plt.ylim(), label='Start Fine Tuning')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.xlabel('epoch')
plt.show()
Model: "model"
_________________________________________________________________Layer (type)                Output Shape              Param #
=================================================================input_2 (InputLayer)        [(None, 160, 160, 3)]     0         sequential (Sequential)     (None, 160, 160, 3)       0         tf.math.truediv (TFOpLambda  (None, 160, 160, 3)      0         )                                                               tf.math.subtract (TFOpLambd  (None, 160, 160, 3)      0         a)                                                              mobilenetv2_1.00_160 (Funct  (None, 5, 5, 1280)       2257984   ional)                                                          global_average_pooling2d (G  (None, 1280)             0         lobalAveragePooling2D)                                          dropout (Dropout)           (None, 1280)              0         dense (Dense)               (None, 2)                 2562      =================================================================
Total params: 2,260,546
Trainable params: 2,562
Non-trainable params: 2,257,984
_________________________________________________________________
Epoch 1/10
8/8 [==============================] - 7s 580ms/step - loss: 0.5505 - accuracy: 0.7224 - val_loss: 0.2987 - val_accuracy: 0.9020
Epoch 2/10
8/8 [==============================] - 4s 490ms/step - loss: 0.3317 - accuracy: 0.8653 - val_loss: 0.2101 - val_accuracy: 0.9281
Epoch 3/10
8/8 [==============================] - 4s 504ms/step - loss: 0.2429 - accuracy: 0.9184 - val_loss: 0.1722 - val_accuracy: 0.9281
Epoch 4/10
8/8 [==============================] - 4s 502ms/step - loss: 0.1721 - accuracy: 0.9306 - val_loss: 0.1540 - val_accuracy: 0.9346
Epoch 5/10
8/8 [==============================] - 4s 516ms/step - loss: 0.1451 - accuracy: 0.9347 - val_loss: 0.1452 - val_accuracy: 0.9412
Epoch 6/10
8/8 [==============================] - 4s 492ms/step - loss: 0.1365 - accuracy: 0.9469 - val_loss: 0.1409 - val_accuracy: 0.9412
Epoch 7/10
8/8 [==============================] - 4s 504ms/step - loss: 0.1333 - accuracy: 0.9429 - val_loss: 0.1329 - val_accuracy: 0.9477
Epoch 8/10
8/8 [==============================] - 4s 497ms/step - loss: 0.1015 - accuracy: 0.9755 - val_loss: 0.1300 - val_accuracy: 0.9477
Epoch 9/10
8/8 [==============================] - 4s 502ms/step - loss: 0.1148 - accuracy: 0.9633 - val_loss: 0.1227 - val_accuracy: 0.9477
Epoch 10/10
8/8 [==============================] - 4s 520ms/step - loss: 0.0964 - accuracy: 0.9673 - val_loss: 0.1228 - val_accuracy: 0.9477
8/8 [==============================] - 15s 1s/step - loss: 0.0929 - accuracy: 0.9796 - val_loss: 0.1369 - val_accuracy: 0.9412
Epoch 11/20
8/8 [==============================] - 11s 1s/step - loss: 0.0687 - accuracy: 0.9714 - val_loss: 0.1502 - val_accuracy: 0.9412
Epoch 12/20
8/8 [==============================] - 10s 1s/step - loss: 0.0662 - accuracy: 0.9714 - val_loss: 0.1573 - val_accuracy: 0.9477
Epoch 13/20
8/8 [==============================] - 10s 1s/step - loss: 0.0662 - accuracy: 0.9796 - val_loss: 0.1709 - val_accuracy: 0.9412
Epoch 14/20
8/8 [==============================] - 10s 1s/step - loss: 0.0650 - accuracy: 0.9714 - val_loss: 0.1629 - val_accuracy: 0.9477
Epoch 15/20
8/8 [==============================] - 10s 1s/step - loss: 0.0526 - accuracy: 0.9755 - val_loss: 0.1609 - val_accuracy: 0.9477
Epoch 16/20
8/8 [==============================] - 10s 1s/step - loss: 0.0448 - accuracy: 0.9878 - val_loss: 0.1526 - val_accuracy: 0.9281
Epoch 17/20
8/8 [==============================] - 11s 1s/step - loss: 0.0283 - accuracy: 0.9959 - val_loss: 0.1495 - val_accuracy: 0.9281
Epoch 18/20
8/8 [==============================] - 11s 1s/step - loss: 0.0329 - accuracy: 0.9918 - val_loss: 0.1819 - val_accuracy: 0.9346
Epoch 19/20
8/8 [==============================] - 11s 1s/step - loss: 0.0293 - accuracy: 0.9959 - val_loss: 0.1484 - val_accuracy: 0.9346
Epoch 20/20
8/8 [==============================] - 11s 1s/step - loss: 0.0159 - accuracy: 1.0000 - val_loss: 0.1410 - val_accuracy: 0.9346

蜜蜂蚂蚁数据集分类(tensorflow)相关推荐

  1. pytorch蜜蜂蚂蚁数据集处理python代码

    代码实现效果: 代码: import os root_dir = r"C:\Users\ninan\Pytorch_study\torch_study\dataset\train" ...

  2. 二隐层的神经网络实现MNIST数据集分类

    二隐层的神经网络实现MNIST数据集分类 传统的人工神经网络包含三部分,输入层.隐藏层和输出层.对于一个神经网络模型的确定需要考虑以下几个方面: 隐藏层的层数以及各层的神经元数量 各层激活函数的选择 ...

  3. 一层神经网络实现鸢尾花数据集分类

    一层神经网络实现鸢尾花数据集分类 1.数据集介绍 2.程序实现 2.1 数据集导入 2.2 数据集乱序 2.3 数据集划分成永不相见的训练集和测试集 3.4 配成[输入特征,标签]对,每次喂入一小撮( ...

  4. tensorflow2.0莺尾花iris数据集分类|超详细

    tensorflow2.0莺尾花iris数据集分类 超详细 直接上代码 #导入模块 import tensorflow as tf #导入tensorflow模块from sklearn import ...

  5. 如何制作自己的图片数据集-基于tensorflow

    写在开始 自己最开始接触python的时候,第一个学会使用的库就是tensorflow,在经历了everyone 都会经历的mnist数据集训练后,就开始想自己做一个图片分类的深度学习,期间也是一波三 ...

  6. 使用ResNet18网络实现对Cifar-100数据集分类

    使用ResNet18网络实现对Cifar-100数据集分类 简介 本次作业旨在利用ResNet18实现对于Cifar-100数据集进行图像识别按照精细类进行分类. Cifar-100数据集由20个粗类 ...

  7. [Python图像识别] 五十.Keras构建AlexNet和CNN实现自定义数据集分类详解

    该系列文章是讲解Python OpenCV图像处理知识,前期主要讲解图像入门.OpenCV基础用法,中期讲解图像处理的各种算法,包括图像锐化算子.图像增强技术.图像分割等,后期结合深度学习研究图像识别 ...

  8. (决策树,朴素贝叶斯,人工神经网络)实现鸢尾花数据集分类

    from sklearn.datasets import load_iris # 导入方法类iris = load_iris() #导入数据集iris iris_feature = iris.data ...

  9. caffe学习笔记18-image1000test200数据集分类与检索完整过程

    image1000test200数据集分类与检索完整过程: 1.准备数据:数据集图片分10个类,每个类有100个train图片(train文件夹下,一共1000),20个test图片(val文件夹下, ...

最新文章

  1. ssm中怎么使tomcat一起动就执行一个controller_【200期】面试官:你能简单说说 SpringMVC 的执行原理吗?...
  2. java实验 输入输出流_java实验七 输入输出流
  3. 男女薪酬差异扩大 2018年女性薪酬不及男性8成
  4. Squid反向代理加速缓存+负载均衡实验架构
  5. python基础(18)之 异常处理
  6. eclipse离线安装Activiti Designer插件
  7. Matlab之矩阵行列式、秩、迹的求解
  8. 计算机辅助初中数学教学,初中数学教学论文 计算机辅助农村初中数学教学的几点想法...
  9. AJAX学习笔记 一:简单的XMLHTTPRequest示例和asp.net异步更新。
  10. 一个简单的form表单登录界面
  11. 可视化工具Netron介绍
  12. linux运行崩溃怎么定位,Linux 程序崩溃定位
  13. Illegal character: U+00A0
  14. 高斯消元法的c语言编程,用C语言编程高斯全主元消元法
  15. 程序员常用的经典算法和OJ网站
  16. 三、pytest接口自动化之pytest中setup/teardown,setup_class/teardown_class讲解
  17. 中水处理设备:中水处理工艺流程的选择
  18. 转发--2022新型冠状病毒肺炎诊疗方案(试行第九版)-中医治疗部分
  19. 阿里2020.4.1实习笔试题——攻击怪兽
  20. 巧妙解决百度云管家下载速度慢

热门文章

  1. python调用Matlab函数
  2. 硬件信息获取--DMI
  3. php中的switch判断妙用
  4. AD18设置默认线宽
  5. 进行性肌营养不良研究又有新的发现
  6. [答疑]充值卡的状态图
  7. 关于车联网系统设计思路(一)
  8. Java socket详解,看这一篇就够了
  9. 史上嫁不出去的公主都有谁?
  10. 19.Docker技术入门与实战 --- 安全防护与配置