[tensorflow2笔记十] 搭建卷积网络实现猫狗图片分类
文章目录
- 1.制作数据集
- 2.搭建网络训练
- 3.输入图片测试
1.制作数据集
(1)下载数据集。从网上下载kaggle猫狗分类的数据集,为缩短训练时间,选择2000张图片(猫狗各1000张)作为训练集,200张图片(猫狗各100张)作为测试集。在train文件夹选0-1999的猫和0-1999的狗作为训练集,选2000-2099的猫和2000-2099的狗作为测试集。
(2)调整图片的大小。图片大小不一,需要调整图片的大小,重新设定规格(244,244,3),从而在后续的网络模型输入时,保证输入到模型中的图片大小一致。
##### resize_data.py #####import os
import cv2dir_train = "/home/xiaobin/PycharmProjects/figure/train_1000"
dir_test = "/home/xiaobin/PycharmProjects/figure/test_200"for root, dirs, files in os.walk(dir_test):for file in files:filepath = os.path.join(root, file)image = cv2.imread(filepath)dim = (224, 224)resized = cv2.resize(image, dim)path = "/home/xiaobin/PycharmProjects/figure/test/" + filecv2.imwrite(path, resized)# os.walk() 方法是一个简单易用的文件、目录遍历器
# root 所指的是当前正在遍历的这个文件夹的本身的地址
# dirs 是一个list ,内容是该文件夹中所有的目录的名字(不包括子目录)
# files 同样是list , 内容是该文件夹中所有的文件(不包括子目录)
(3)制作标签文档。为让图片和标签匹配,制作训练集和测试集图片的索引文本。编写代码实现:
##### make_txt.py ########f1 = open("train.txt", 'w')
for i in range(1000):f1.write("cat.%d.jpg %d\n" % (i, 0))
for j in range(1000):f1.write("dog.%d.jpg %d\n" % (j, 1))
f1.close()f2 = open("test.txt", "w")
for i in range(100):f2.write("cat.%d.jpg %d\n" % (i+2000, 0))
for j in range(100):f2.write("dog.%d.jpg %d\n" % (j+2000, 1))
f2.close()
2.搭建网络训练
# 1.导入一些模块
import cv2
import tensorflow as tf
import numpy as np # 用于数据格式转换
import os # 路径
from tensorflow.keras.preprocessing.image import ImageDataGenerator # 数据增强
from matplotlib import pyplot as plt
from tensorflow.keras.layers import Dense,Activation,Dropout,Conv2D,BatchNormalization,MaxPool2D,Flatten# 设置GPU显存按需申请
gpu = tf.config.experimental.list_physical_devices(device_type='GPU')
assert len(gpu) == 1
tf.config.experimental.set_memory_growth(gpu[0], True)# 2.路径和存储文件
train_path = './train_2000/' # 训练集图片路径
train_txt = 'train.txt' # 训练集标签文件test_path = './test_200/' # 测试集图片路径
test_txt = 'test.txt' # 测试集标签文件# 3.制作数据集的函数
def generateds(path, txt): # 图片路径,标签文件f = open(txt, 'r') # 以只读的形式打开txtcontents = f.readlines() # 读取文件中所有的行,每行为一个单位f.close()x, y_ = [], []for content in contents: # 逐行读出value = content.split() # 以空格分开img_path = path + value[0]img = cv2.imread(img_path)img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # 转为RGBimg = img / 255.0 # 归一化,有利于网络吸收x.append(img)y_.append(value[1])# print('load:' + content)x = np.array(x)y_ = np.array(y_)y_ = y_.astype(np.int64)return x, y_# 4.加载数据
print('----------------Generate Datasets--------------')
x_train, y_train = generateds(train_path, train_txt)
x_test, y_test = generateds(test_path, test_txt)# 5.打乱数据集样本顺序
index = [i for i in range(len(x_train))]
np.random.shuffle(index)
x_train = x_train[index]
y_train = y_train[index]index1 = [j for j in range(len(x_test))]
np.random.shuffle(index1)
x_test = x_test[index1]
y_test = y_test[index1]# 6.数据增强
image_gen_train = ImageDataGenerator(rescale=1. / 1., # 如果是图像,分母为255,可以归一化到0-1rotation_range=45, # 随机45度旋转width_shift_range=.15, # 宽度偏移height_shift_range=.15, # 高度偏移horizontal_flip=True, # 水平翻转zoom_range=0.5 # 将图像随机缩放阈量50%
)
image_gen_train.fit(x_train)# 7.搭建网络
model = tf.keras.models.Sequential([Conv2D(filters=32, kernel_size=(3, 3)),BatchNormalization(),Activation('relu'),MaxPool2D(pool_size=(2, 2), strides=2),Conv2D(filters=64, kernel_size=(3, 3)),BatchNormalization(),Activation('relu'),MaxPool2D(pool_size=(2, 2), strides=2),Conv2D(filters=128, kernel_size=(3, 3)),BatchNormalization(),Activation('relu'),MaxPool2D(pool_size=(2, 2), strides=2),Flatten(), # 把输入特征拉直为一维数组数值Dense(128, activation='relu'),Dense(64, activation='relu'),Dense(2, activation='softmax')
])# 8.配置参数
model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'])# 9.设置调用和保存模型
# 调用模型
checkpoint_save_path = "./checkpoint/cat_dag.ckpt"
if os.path.exists(checkpoint_save_path + '.index'):print("--------------load model--------------")model.load_weights(checkpoint_save_path)
# 保存模型
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,save_weights_only=True,save_best_only=True)# 10.训练
history = model.fit(image_gen_train.flow(x_train, y_train, batch_size=32),epochs=15, validation_data=(x_test, y_test), validation_freq=1,callbacks=[cp_callback])
## 提取acc和loss
acc = history.history['sparse_categorical_accuracy']
val_acc = history.history['val_sparse_categorical_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']# 11.打印和保存网络参数
model.summary()
file = open('./weights.txt', 'w')
for v in model.trainable_variables:file.write(str(v.name) + '\n')file.write(str(v.shape) + '\n')file.write(str(v.numpy()) + '\n')
file.close()# 12.绘制acc和loss曲线
plt.subplot(1, 2, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()plt.subplot(1, 2, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()
训练的正确率在72%左右,loss和acc图如下所示:
3.输入图片测试
import tensorflow as tf
from tensorflow.keras.layers import Dense,Activation,Dropout,Conv2D,BatchNormalization,MaxPool2D,Flatten
import cv2# 设置GPU显存按需申请
gpu = tf.config.experimental.list_physical_devices(device_type='GPU')
assert len(gpu) == 1
tf.config.experimental.set_memory_growth(gpu[0], True)# 1.复现模型(前向传播)
model = tf.keras.models.Sequential([Conv2D(filters=32, kernel_size=(3, 3)),BatchNormalization(),Activation('relu'),MaxPool2D(pool_size=(2, 2), strides=2),Conv2D(filters=64, kernel_size=(3, 3)),BatchNormalization(),Activation('relu'),MaxPool2D(pool_size=(2, 2), strides=2),Conv2D(filters=128, kernel_size=(3, 3)),BatchNormalization(),Activation('relu'),MaxPool2D(pool_size=(2, 2), strides=2),Flatten(), # 把输入特征拉直为一维数组数值Dense(128, activation='relu'),Dense(64, activation='relu'),Dense(2, activation='softmax')
])# 2.加载参数
model_save_path = "./checkpoint/cat_dag.ckpt" # 模型参数存储的路径
model.load_weights(model_save_path)# 3.数据预处理
img = cv2.imread("./train/dog.228.jpg")
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # 转为RGB
img = cv2.resize(img, (224, 224))
img = img / 255.0 # 归一化
img = img[tf.newaxis, ...]# 4.预测结果
result = model.predict(img)
pred = tf.argmax(result, axis=1)
if pred.numpy() == 0:print('识别结果是:小猫')
else:print('识别结果是:小狗')
[tensorflow2笔记十] 搭建卷积网络实现猫狗图片分类相关推荐
- 使用预训练的卷积神经网络(猫狗图片分类)
本次所用数据来自ImageNet,使用预训练好的数据来预测一个新的数据集:猫狗图片分类.这里,使用VGG模型,这个模型内置在Keras中,直接导入就可以了. from keras.applicatio ...
- 用卷积神经网络实现猫狗图片分类
该例程使用数据集来源于 kaggle cat_VS _dog 数据集中的一部分, 用卷积神经网络实现猫狗图片二分类,例程序比较简单,就不多解释了,代码中会有相应的注释,直接上代码: import nu ...
- PyTorch搭建预训练AlexNet、DenseNet、ResNet、VGG实现猫狗图片分类
目录 前言 AlexNet DensNet ResNet VGG 前言 在之前的文章中,利用一个简单的三层CNN猫狗图片分类,正确率不高,详见: CNN简单实战:PyTorch搭建CNN对猫狗图片进行 ...
- Top2:CNN 卷积神经网络实现猫狗图片识别二分类
Top2:CNN 卷积神经网络实现猫狗图片识别二分类 系统:Windows10 Professional 环境:python=3.6 tensorflow-gpu=1.14 ```python &qu ...
- 11.CNN实现真实猫狗图片分类
CNN实现真实猫狗图片分类 个人认为,和上一节的mnist数据集里面的手写数字图片不同之处就是,真实的图片更加复杂,像素点更多.因此在对应的图片预处理方面会稍微麻烦一些.但是这个例子能让我们可以处理自 ...
- 体验AI乐趣:基于AI Gallery的二分类猫狗图片分类小数据集自动学习
摘要:直接使用AI Gallery里面现有的数据集进行自动学习训练,很简单和方便,节约时间,不用自己去训练了,AI Gallery 里面有很多类似的有趣数据集,也非常好玩,大家一起试试吧. 本文分享自 ...
- CNN之从头训练一个猫狗图片分类模型
猫狗图片下载地址: 链接:https://pan.baidu.com/s/1l1AnBgkAAEhh0vI5_loWKw 提取码:2xq4 说明:大概有816M大小,分为train和test,trai ...
- [深度之眼]TensorFlow2.0项目班-猫狗图片分类
猫狗数据集官网 猫示例: 狗示例: 训练集:猫狗各11500张图片 验证集:猫狗各1000张图片 难点:图片大小不统一,标签未配对 首先加载需要的包: import tensorflow as tf ...
- 【Keras】 计算机视觉 CNN 实现猫狗图片分类
目录 综述 图像预览 数据预处理 验证集 模型训练 训练结果 综述 本项目旨在通过一个公开数据集,训练一个可以将图片中的猫和狗进行分类的模型. 数据集包括25,000 张训练数据.其中猫和狗的照片各 ...
最新文章
- [JavaScript] 日期时间戳的使用与计算
- SqlDataAdapter.Update批量数据更新
- context:annotation-config vs context:component-scan
- Python 包管理之 poetry
- c语言随机抽取小程序_C语言整人小程序,慎用,谨记!
- 计算机与工程建设项目结合,计算机科学与技术在工程建设项目管理中应用.doc...
- 好看的php验证码,一个漂亮的PHP验证码_PHP教程
- 在nodejs项目装一个库的多个版本
- OpenCV精进之路(零):core组件——绘制点、直线、几何图形
- bootstrap,layui,elementui vantui的区别
- 零基础学SQL(一、数据库与SQL简介)
- MEMS传感市场,美/日/德企占主导地位
- 用pytest实现POM模型
- element中file-upload组件的提示‘按delete键可删除’,怎么去掉
- 大数据到底怎么学:数据科学概论与大数据学习误区
- 最小公倍数用c语言,如何用C语言求最小公倍数。。。
- DxO PhotoLab 2.1.2 for Mac精华汉化版 DxO PhotoLab 2.1.2 for Mac中文版
- 数据同步工具之DataX实操
- 在BOSS直聘发现了一个前端小秘密
- 大数据常见错误解决方案(转载)
热门文章
- 将计算机引入课堂热潮 英语一,浅谈小学英语课堂的计算机辅助教学论文
- JDK1.8新增时间日期API
- 计算机房怎么读英语单词,教学机房,teaching computer lab,音标,读音,翻译,英文例句,英语词典...
- adnroid studio debug模式提示 Method breakpoints may dramatically slow down debugging
- android app技术亮点
- Codeforces Round #626 (Div. 2)
- R2S软路由+夸克网盘实现本地追剧
- 指数加权平均(EWA)
- 让我们拥抱DataV,感受数据可视化的魅力
- 第三章 似有隐情的谈话