蜜蜂蚂蚁数据集分类(tensorflow)
import matplotlib.pyplot as plt
import numpy as np
import os
import tensorflow as tf
from tensorflow.keras.preprocessing import image_dataset_from_directory#加载数据
PATH = 'D:\data\hymenoptera_data'
train_dir = os.path.join(PATH, 'train')
validation_dir = os.path.join(PATH, 'val')BATCH_SIZE = 32
IMG_SIZE = (160, 160)train_dataset = image_dataset_from_directory(train_dir,shuffle=True,batch_size=BATCH_SIZE,image_size=IMG_SIZE)validation_dataset = image_dataset_from_directory(validation_dir,shuffle=True,batch_size=BATCH_SIZE,image_size=IMG_SIZE)#显示训练集中的前九个图像和标签:
class_names = train_dataset.class_names
%config InlineBackend.figure_format = 'retina'
plt.figure(figsize=(10, 10))
for images, labels in train_dataset.take(1):for i in range(9):ax = plt.subplot(3, 3, i + 1)plt.imshow(images[i].numpy().astype("uint8"))plt.title(class_names[labels[i]])plt.axis("off")#使用缓冲预提取从磁盘加载图像,以免造成 I/O 阻塞
AUTOTUNE = tf.data.AUTOTUNEtrain_dataset = train_dataset.prefetch(buffer_size=AUTOTUNE)
validation_dataset = validation_dataset.prefetch(buffer_size=AUTOTUNE)"""
当您没有较大的图像数据集时,最好将随机但现实的转换应用于训练图像(例如旋转或水平翻转)
来人为引入样本多样性。这有助于使模型暴露于训练数据的不同方面并减少过拟合
"""
data_augmentation = tf.keras.Sequential([tf.keras.layers.experimental.preprocessing.RandomFlip('horizontal'),tf.keras.layers.experimental.preprocessing.RandomRotation(0.2),
])"""
注:当您调用 model.fit 时,这些层仅在训练过程中才会处于有效状态。
在 model.evaulate 或 model.fit 中的推断模式下使用模型时,它们处于停用状态。我们将这些层重复应用于同一个图像,然后查看结果。
"""
for image, _ in train_dataset.take(1):plt.figure(figsize=(10, 10))first_image = image[0]for i in range(9):ax = plt.subplot(3, 3, i + 1)augmented_image = data_augmentation(tf.expand_dims(first_image, 0))plt.imshow(augmented_image[0] / 255)plt.axis('off')preprocess_input = tf.keras.applications.mobilenet_v2.preprocess_input
prediction_layer = tf.keras.layers.Dense(2)# Create the base model from the pre-trained model MobileNet V2
base_model = tf.keras.applications.MobileNetV2(input_shape=(160, 160, 3),include_top=False,weights='imagenet')
base_model.trainable = False"""
通过使用 Keras 函数式 API 将数据扩充、重新缩放、
base_model 和特征提取程序层链接在一起来构建模型。
如前面所述,由于我们的模型包含 BatchNormalization 层,
因此请使用 training = False。
"""
inputs = tf.keras.Input(shape=(160, 160, 3))
x = data_augmentation(inputs)
x = preprocess_input(x)
x = base_model(x, training=False)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
x = tf.keras.layers.Dropout(0.2)(x)
outputs = prediction_layer(x)
model = tf.keras.Model(inputs, outputs)base_learning_rate = 0.0001
model.compile(optimizer=tf.keras.optimizers.Adam(),loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])model.summary()#经过 10 个周期的训练后,您应该在验证集上看到约 94% 的准确率。
initial_epochs = 10
history = model.fit(train_dataset,epochs=initial_epochs,validation_data=validation_dataset)#我们看一下使用 MobileNet V2 基础模型作为固定特征提取程序时训练和验证准确率/损失的学习曲线。
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']loss = history.history['loss']
val_loss = history.history['val_loss']plt.figure(figsize=(8, 8))
plt.subplot(2, 1, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.ylabel('Accuracy')
plt.ylim([min(plt.ylim()),1])
plt.title('Training and Validation Accuracy')plt.subplot(2, 1, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.ylabel('Cross Entropy')
plt.ylim([0,1.0])
plt.title('Training and Validation Loss')
plt.xlabel('epoch')
plt.show()"""
您需要做的是解冻 base_model 并将底层设置为不可训练。
随后,您应该重新编译模型(使这些更改生效的必需操作),然后恢复训练。
"""
base_model.trainable = True
model.compile(loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),optimizer = tf.keras.optimizers.Adam(lr=base_learning_rate/10),metrics=['accuracy'])"""
当您正在训练一个大得多的模型并且想要重新调整预训练权重时,
请务必在此阶段使用较低的学习率。否则,您的模型可能会很快过拟合。
"""
fine_tune_epochs = 10
total_epochs = initial_epochs + fine_tune_epochshistory_fine = model.fit(train_dataset,epochs=total_epochs,initial_epoch=history.epoch[-1],validation_data=validation_dataset)acc += history_fine.history['accuracy']
val_acc += history_fine.history['val_accuracy']loss += history_fine.history['loss']
val_loss += history_fine.history['val_loss']plt.figure(figsize=(8, 8))
plt.subplot(2, 1, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.ylim([0.8, 1])
plt.plot([initial_epochs-1,initial_epochs-1],plt.ylim(), label='Start Fine Tuning')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')plt.subplot(2, 1, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.ylim([0, 1.0])
plt.plot([initial_epochs-1,initial_epochs-1],plt.ylim(), label='Start Fine Tuning')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.xlabel('epoch')
plt.show()
Model: "model" _________________________________________________________________Layer (type) Output Shape Param # =================================================================input_2 (InputLayer) [(None, 160, 160, 3)] 0 sequential (Sequential) (None, 160, 160, 3) 0 tf.math.truediv (TFOpLambda (None, 160, 160, 3) 0 ) tf.math.subtract (TFOpLambd (None, 160, 160, 3) 0 a) mobilenetv2_1.00_160 (Funct (None, 5, 5, 1280) 2257984 ional) global_average_pooling2d (G (None, 1280) 0 lobalAveragePooling2D) dropout (Dropout) (None, 1280) 0 dense (Dense) (None, 2) 2562 ================================================================= Total params: 2,260,546 Trainable params: 2,562 Non-trainable params: 2,257,984 _________________________________________________________________
Epoch 1/10 8/8 [==============================] - 7s 580ms/step - loss: 0.5505 - accuracy: 0.7224 - val_loss: 0.2987 - val_accuracy: 0.9020 Epoch 2/10 8/8 [==============================] - 4s 490ms/step - loss: 0.3317 - accuracy: 0.8653 - val_loss: 0.2101 - val_accuracy: 0.9281 Epoch 3/10 8/8 [==============================] - 4s 504ms/step - loss: 0.2429 - accuracy: 0.9184 - val_loss: 0.1722 - val_accuracy: 0.9281 Epoch 4/10 8/8 [==============================] - 4s 502ms/step - loss: 0.1721 - accuracy: 0.9306 - val_loss: 0.1540 - val_accuracy: 0.9346 Epoch 5/10 8/8 [==============================] - 4s 516ms/step - loss: 0.1451 - accuracy: 0.9347 - val_loss: 0.1452 - val_accuracy: 0.9412 Epoch 6/10 8/8 [==============================] - 4s 492ms/step - loss: 0.1365 - accuracy: 0.9469 - val_loss: 0.1409 - val_accuracy: 0.9412 Epoch 7/10 8/8 [==============================] - 4s 504ms/step - loss: 0.1333 - accuracy: 0.9429 - val_loss: 0.1329 - val_accuracy: 0.9477 Epoch 8/10 8/8 [==============================] - 4s 497ms/step - loss: 0.1015 - accuracy: 0.9755 - val_loss: 0.1300 - val_accuracy: 0.9477 Epoch 9/10 8/8 [==============================] - 4s 502ms/step - loss: 0.1148 - accuracy: 0.9633 - val_loss: 0.1227 - val_accuracy: 0.9477 Epoch 10/10 8/8 [==============================] - 4s 520ms/step - loss: 0.0964 - accuracy: 0.9673 - val_loss: 0.1228 - val_accuracy: 0.9477
8/8 [==============================] - 15s 1s/step - loss: 0.0929 - accuracy: 0.9796 - val_loss: 0.1369 - val_accuracy: 0.9412 Epoch 11/20 8/8 [==============================] - 11s 1s/step - loss: 0.0687 - accuracy: 0.9714 - val_loss: 0.1502 - val_accuracy: 0.9412 Epoch 12/20 8/8 [==============================] - 10s 1s/step - loss: 0.0662 - accuracy: 0.9714 - val_loss: 0.1573 - val_accuracy: 0.9477 Epoch 13/20 8/8 [==============================] - 10s 1s/step - loss: 0.0662 - accuracy: 0.9796 - val_loss: 0.1709 - val_accuracy: 0.9412 Epoch 14/20 8/8 [==============================] - 10s 1s/step - loss: 0.0650 - accuracy: 0.9714 - val_loss: 0.1629 - val_accuracy: 0.9477 Epoch 15/20 8/8 [==============================] - 10s 1s/step - loss: 0.0526 - accuracy: 0.9755 - val_loss: 0.1609 - val_accuracy: 0.9477 Epoch 16/20 8/8 [==============================] - 10s 1s/step - loss: 0.0448 - accuracy: 0.9878 - val_loss: 0.1526 - val_accuracy: 0.9281 Epoch 17/20 8/8 [==============================] - 11s 1s/step - loss: 0.0283 - accuracy: 0.9959 - val_loss: 0.1495 - val_accuracy: 0.9281 Epoch 18/20 8/8 [==============================] - 11s 1s/step - loss: 0.0329 - accuracy: 0.9918 - val_loss: 0.1819 - val_accuracy: 0.9346 Epoch 19/20 8/8 [==============================] - 11s 1s/step - loss: 0.0293 - accuracy: 0.9959 - val_loss: 0.1484 - val_accuracy: 0.9346 Epoch 20/20 8/8 [==============================] - 11s 1s/step - loss: 0.0159 - accuracy: 1.0000 - val_loss: 0.1410 - val_accuracy: 0.9346
蜜蜂蚂蚁数据集分类(tensorflow)相关推荐
- pytorch蜜蜂蚂蚁数据集处理python代码
代码实现效果: 代码: import os root_dir = r"C:\Users\ninan\Pytorch_study\torch_study\dataset\train" ...
- 二隐层的神经网络实现MNIST数据集分类
二隐层的神经网络实现MNIST数据集分类 传统的人工神经网络包含三部分,输入层.隐藏层和输出层.对于一个神经网络模型的确定需要考虑以下几个方面: 隐藏层的层数以及各层的神经元数量 各层激活函数的选择 ...
- 一层神经网络实现鸢尾花数据集分类
一层神经网络实现鸢尾花数据集分类 1.数据集介绍 2.程序实现 2.1 数据集导入 2.2 数据集乱序 2.3 数据集划分成永不相见的训练集和测试集 3.4 配成[输入特征,标签]对,每次喂入一小撮( ...
- tensorflow2.0莺尾花iris数据集分类|超详细
tensorflow2.0莺尾花iris数据集分类 超详细 直接上代码 #导入模块 import tensorflow as tf #导入tensorflow模块from sklearn import ...
- 如何制作自己的图片数据集-基于tensorflow
写在开始 自己最开始接触python的时候,第一个学会使用的库就是tensorflow,在经历了everyone 都会经历的mnist数据集训练后,就开始想自己做一个图片分类的深度学习,期间也是一波三 ...
- 使用ResNet18网络实现对Cifar-100数据集分类
使用ResNet18网络实现对Cifar-100数据集分类 简介 本次作业旨在利用ResNet18实现对于Cifar-100数据集进行图像识别按照精细类进行分类. Cifar-100数据集由20个粗类 ...
- [Python图像识别] 五十.Keras构建AlexNet和CNN实现自定义数据集分类详解
该系列文章是讲解Python OpenCV图像处理知识,前期主要讲解图像入门.OpenCV基础用法,中期讲解图像处理的各种算法,包括图像锐化算子.图像增强技术.图像分割等,后期结合深度学习研究图像识别 ...
- (决策树,朴素贝叶斯,人工神经网络)实现鸢尾花数据集分类
from sklearn.datasets import load_iris # 导入方法类iris = load_iris() #导入数据集iris iris_feature = iris.data ...
- caffe学习笔记18-image1000test200数据集分类与检索完整过程
image1000test200数据集分类与检索完整过程: 1.准备数据:数据集图片分10个类,每个类有100个train图片(train文件夹下,一共1000),20个test图片(val文件夹下, ...
最新文章
- ssm中怎么使tomcat一起动就执行一个controller_【200期】面试官:你能简单说说 SpringMVC 的执行原理吗?...
- java实验 输入输出流_java实验七 输入输出流
- 男女薪酬差异扩大 2018年女性薪酬不及男性8成
- Squid反向代理加速缓存+负载均衡实验架构
- python基础(18)之 异常处理
- eclipse离线安装Activiti Designer插件
- Matlab之矩阵行列式、秩、迹的求解
- 计算机辅助初中数学教学,初中数学教学论文 计算机辅助农村初中数学教学的几点想法...
- AJAX学习笔记 一:简单的XMLHTTPRequest示例和asp.net异步更新。
- 一个简单的form表单登录界面
- 可视化工具Netron介绍
- linux运行崩溃怎么定位,Linux 程序崩溃定位
- Illegal character: U+00A0
- 高斯消元法的c语言编程,用C语言编程高斯全主元消元法
- 程序员常用的经典算法和OJ网站
- 三、pytest接口自动化之pytest中setup/teardown,setup_class/teardown_class讲解
- 中水处理设备:中水处理工艺流程的选择
- 转发--2022新型冠状病毒肺炎诊疗方案(试行第九版)-中医治疗部分
- 阿里2020.4.1实习笔试题——攻击怪兽
- 巧妙解决百度云管家下载速度慢