文章目录

  • 1.制作数据集
  • 2.搭建网络训练
  • 3.输入图片测试

1.制作数据集

(1)下载数据集。从网上下载kaggle猫狗分类的数据集,为缩短训练时间,选择2000张图片(猫狗各1000张)作为训练集,200张图片(猫狗各100张)作为测试集。在train文件夹选0-1999的猫和0-1999的狗作为训练集,选2000-2099的猫和2000-2099的狗作为测试集。

(2)调整图片的大小。图片大小不一,需要调整图片的大小,重新设定规格(244,244,3),从而在后续的网络模型输入时,保证输入到模型中的图片大小一致。

##### resize_data.py #####import os
import cv2dir_train = "/home/xiaobin/PycharmProjects/figure/train_1000"
dir_test = "/home/xiaobin/PycharmProjects/figure/test_200"for root, dirs, files in os.walk(dir_test):for file in files:filepath = os.path.join(root, file)image = cv2.imread(filepath)dim = (224, 224)resized = cv2.resize(image, dim)path = "/home/xiaobin/PycharmProjects/figure/test/" + filecv2.imwrite(path, resized)# os.walk() 方法是一个简单易用的文件、目录遍历器
# root 所指的是当前正在遍历的这个文件夹的本身的地址
# dirs 是一个list ,内容是该文件夹中所有的目录的名字(不包括子目录)
# files 同样是list , 内容是该文件夹中所有的文件(不包括子目录)

(3)制作标签文档。为让图片和标签匹配,制作训练集和测试集图片的索引文本。编写代码实现:

##### make_txt.py ########f1 = open("train.txt", 'w')
for i in range(1000):f1.write("cat.%d.jpg %d\n" % (i, 0))
for j in range(1000):f1.write("dog.%d.jpg %d\n" % (j, 1))
f1.close()f2 = open("test.txt", "w")
for i in range(100):f2.write("cat.%d.jpg %d\n" % (i+2000, 0))
for j in range(100):f2.write("dog.%d.jpg %d\n" % (j+2000, 1))
f2.close()

2.搭建网络训练

# 1.导入一些模块
import cv2
import tensorflow as tf
import numpy as np          # 用于数据格式转换
import os                   # 路径
from tensorflow.keras.preprocessing.image import ImageDataGenerator     # 数据增强
from matplotlib import pyplot as plt
from tensorflow.keras.layers import Dense,Activation,Dropout,Conv2D,BatchNormalization,MaxPool2D,Flatten# 设置GPU显存按需申请
gpu = tf.config.experimental.list_physical_devices(device_type='GPU')
assert len(gpu) == 1
tf.config.experimental.set_memory_growth(gpu[0], True)# 2.路径和存储文件
train_path = './train_2000/'             # 训练集图片路径
train_txt = 'train.txt'                  # 训练集标签文件test_path = './test_200/'                # 测试集图片路径
test_txt = 'test.txt'                    # 测试集标签文件# 3.制作数据集的函数
def generateds(path, txt):      # 图片路径,标签文件f = open(txt, 'r')          # 以只读的形式打开txtcontents = f.readlines()    # 读取文件中所有的行,每行为一个单位f.close()x, y_ = [], []for content in contents:     # 逐行读出value = content.split()  # 以空格分开img_path = path + value[0]img = cv2.imread(img_path)img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)      # 转为RGBimg = img / 255.0       # 归一化,有利于网络吸收x.append(img)y_.append(value[1])# print('load:' + content)x = np.array(x)y_ = np.array(y_)y_ = y_.astype(np.int64)return x, y_# 4.加载数据
print('----------------Generate Datasets--------------')
x_train, y_train = generateds(train_path, train_txt)
x_test, y_test = generateds(test_path, test_txt)# 5.打乱数据集样本顺序
index = [i for i in range(len(x_train))]
np.random.shuffle(index)
x_train = x_train[index]
y_train = y_train[index]index1 = [j for j in range(len(x_test))]
np.random.shuffle(index1)
x_test = x_test[index1]
y_test = y_test[index1]# 6.数据增强
image_gen_train = ImageDataGenerator(rescale=1. / 1.,         # 如果是图像,分母为255,可以归一化到0-1rotation_range=45,       # 随机45度旋转width_shift_range=.15,   # 宽度偏移height_shift_range=.15,  # 高度偏移horizontal_flip=True,    # 水平翻转zoom_range=0.5           # 将图像随机缩放阈量50%
)
image_gen_train.fit(x_train)# 7.搭建网络
model = tf.keras.models.Sequential([Conv2D(filters=32, kernel_size=(3, 3)),BatchNormalization(),Activation('relu'),MaxPool2D(pool_size=(2, 2), strides=2),Conv2D(filters=64, kernel_size=(3, 3)),BatchNormalization(),Activation('relu'),MaxPool2D(pool_size=(2, 2), strides=2),Conv2D(filters=128, kernel_size=(3, 3)),BatchNormalization(),Activation('relu'),MaxPool2D(pool_size=(2, 2), strides=2),Flatten(),                      # 把输入特征拉直为一维数组数值Dense(128, activation='relu'),Dense(64, activation='relu'),Dense(2, activation='softmax')
])# 8.配置参数
model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'])# 9.设置调用和保存模型
# 调用模型
checkpoint_save_path = "./checkpoint/cat_dag.ckpt"
if os.path.exists(checkpoint_save_path + '.index'):print("--------------load model--------------")model.load_weights(checkpoint_save_path)
# 保存模型
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,save_weights_only=True,save_best_only=True)# 10.训练
history = model.fit(image_gen_train.flow(x_train, y_train, batch_size=32),epochs=15, validation_data=(x_test, y_test), validation_freq=1,callbacks=[cp_callback])
## 提取acc和loss
acc = history.history['sparse_categorical_accuracy']
val_acc = history.history['val_sparse_categorical_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']# 11.打印和保存网络参数
model.summary()
file = open('./weights.txt', 'w')
for v in model.trainable_variables:file.write(str(v.name) + '\n')file.write(str(v.shape) + '\n')file.write(str(v.numpy()) + '\n')
file.close()# 12.绘制acc和loss曲线
plt.subplot(1, 2, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()plt.subplot(1, 2, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()

训练的正确率在72%左右,loss和acc图如下所示:

3.输入图片测试

import tensorflow as tf
from tensorflow.keras.layers import Dense,Activation,Dropout,Conv2D,BatchNormalization,MaxPool2D,Flatten
import cv2# 设置GPU显存按需申请
gpu = tf.config.experimental.list_physical_devices(device_type='GPU')
assert len(gpu) == 1
tf.config.experimental.set_memory_growth(gpu[0], True)# 1.复现模型(前向传播)
model = tf.keras.models.Sequential([Conv2D(filters=32, kernel_size=(3, 3)),BatchNormalization(),Activation('relu'),MaxPool2D(pool_size=(2, 2), strides=2),Conv2D(filters=64, kernel_size=(3, 3)),BatchNormalization(),Activation('relu'),MaxPool2D(pool_size=(2, 2), strides=2),Conv2D(filters=128, kernel_size=(3, 3)),BatchNormalization(),Activation('relu'),MaxPool2D(pool_size=(2, 2), strides=2),Flatten(),                      # 把输入特征拉直为一维数组数值Dense(128, activation='relu'),Dense(64, activation='relu'),Dense(2, activation='softmax')
])# 2.加载参数
model_save_path = "./checkpoint/cat_dag.ckpt"       # 模型参数存储的路径
model.load_weights(model_save_path)# 3.数据预处理
img = cv2.imread("./train/dog.228.jpg")
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)      # 转为RGB
img = cv2.resize(img, (224, 224))
img = img / 255.0       # 归一化
img = img[tf.newaxis, ...]# 4.预测结果
result = model.predict(img)
pred = tf.argmax(result, axis=1)
if pred.numpy() == 0:print('识别结果是:小猫')
else:print('识别结果是:小狗')

[tensorflow2笔记十] 搭建卷积网络实现猫狗图片分类相关推荐

  1. 使用预训练的卷积神经网络(猫狗图片分类)

    本次所用数据来自ImageNet,使用预训练好的数据来预测一个新的数据集:猫狗图片分类.这里,使用VGG模型,这个模型内置在Keras中,直接导入就可以了. from keras.applicatio ...

  2. 用卷积神经网络实现猫狗图片分类

    该例程使用数据集来源于 kaggle cat_VS _dog 数据集中的一部分, 用卷积神经网络实现猫狗图片二分类,例程序比较简单,就不多解释了,代码中会有相应的注释,直接上代码: import nu ...

  3. PyTorch搭建预训练AlexNet、DenseNet、ResNet、VGG实现猫狗图片分类

    目录 前言 AlexNet DensNet ResNet VGG 前言 在之前的文章中,利用一个简单的三层CNN猫狗图片分类,正确率不高,详见: CNN简单实战:PyTorch搭建CNN对猫狗图片进行 ...

  4. Top2:CNN 卷积神经网络实现猫狗图片识别二分类

    Top2:CNN 卷积神经网络实现猫狗图片识别二分类 系统:Windows10 Professional 环境:python=3.6 tensorflow-gpu=1.14 ```python &qu ...

  5. 11.CNN实现真实猫狗图片分类

    CNN实现真实猫狗图片分类 个人认为,和上一节的mnist数据集里面的手写数字图片不同之处就是,真实的图片更加复杂,像素点更多.因此在对应的图片预处理方面会稍微麻烦一些.但是这个例子能让我们可以处理自 ...

  6. 体验AI乐趣:基于AI Gallery的二分类猫狗图片分类小数据集自动学习

    摘要:直接使用AI Gallery里面现有的数据集进行自动学习训练,很简单和方便,节约时间,不用自己去训练了,AI Gallery 里面有很多类似的有趣数据集,也非常好玩,大家一起试试吧. 本文分享自 ...

  7. CNN之从头训练一个猫狗图片分类模型

    猫狗图片下载地址: 链接:https://pan.baidu.com/s/1l1AnBgkAAEhh0vI5_loWKw 提取码:2xq4 说明:大概有816M大小,分为train和test,trai ...

  8. [深度之眼]TensorFlow2.0项目班-猫狗图片分类

    猫狗数据集官网 猫示例: 狗示例: 训练集:猫狗各11500张图片 验证集:猫狗各1000张图片 难点:图片大小不统一,标签未配对 首先加载需要的包: import tensorflow as tf ...

  9. 【Keras】 计算机视觉 CNN 实现猫狗图片分类

    目录 综述 图像预览 数据预处理 验证集 模型训练 训练结果 综述 本项目旨在通过一个公开数据集,训练一个可以将图片中的猫和狗进行分类的模型. 数据集包括25,000 张训练数据.其中猫和狗的照片各 ...

最新文章

  1. [JavaScript] 日期时间戳的使用与计算
  2. SqlDataAdapter.Update批量数据更新
  3. context:annotation-config vs context:component-scan
  4. Python 包管理之 poetry
  5. c语言随机抽取小程序_C语言整人小程序,慎用,谨记!
  6. 计算机与工程建设项目结合,计算机科学与技术在工程建设项目管理中应用.doc...
  7. 好看的php验证码,一个漂亮的PHP验证码_PHP教程
  8. 在nodejs项目装一个库的多个版本
  9. OpenCV精进之路(零):core组件——绘制点、直线、几何图形
  10. bootstrap,layui,elementui vantui的区别
  11. 零基础学SQL(一、数据库与SQL简介)
  12. MEMS传感市场,美/日/德企占主导地位
  13. 用pytest实现POM模型
  14. element中file-upload组件的提示‘按delete键可删除’,怎么去掉
  15. 大数据到底怎么学:数据科学概论与大数据学习误区
  16. 最小公倍数用c语言,如何用C语言求最小公倍数。。。
  17. DxO PhotoLab 2.1.2 for Mac精华汉化版 DxO PhotoLab 2.1.2 for Mac中文版
  18. 数据同步工具之DataX实操
  19. 在BOSS直聘发现了一个前端小秘密
  20. 大数据常见错误解决方案(转载)

热门文章

  1. 将计算机引入课堂热潮 英语一,浅谈小学英语课堂的计算机辅助教学论文
  2. JDK1.8新增时间日期API
  3. 计算机房怎么读英语单词,教学机房,teaching computer lab,音标,读音,翻译,英文例句,英语词典...
  4. adnroid studio debug模式提示 Method breakpoints may dramatically slow down debugging
  5. android app技术亮点
  6. Codeforces Round #626 (Div. 2)
  7. R2S软路由+夸克网盘实现本地追剧
  8. 指数加权平均(EWA)
  9. 让我们拥抱DataV,感受数据可视化的魅力
  10. 第三章 似有隐情的谈话