一、序言

使用百度飞浆提供的paddle框架实现蝴蝶分类,环境:paddle 2.0.2,opencv 4.5.4.58,pycharm编译器。

目录结构:

  • Butterfly20里有20个文件夹,分别代表20种蝴蝶种类,每个文件夹内有多个同种类的蝴蝶照片
  • Butterfly20_test里有200张蝴蝶照片用于测试训练好的网络
  • visualdl_log里存放训练好的网络,使用log文件格式
  • species.txt里存放20种类别的名称和序号
  • train_set和validation_set在运行时随机分配

二、准备数据

随机查看一个蝴蝶图片及其类别

data_path= '.\Butterfly20\*\*.jpg'
but_files =glob.glob(data_path) #获取Butterfly20中的所有图片地址print('图片数据为',len(but_files))#随机显示一个样品的图片
index=random.choice(but_files)  # 随机获取一个图片
print(index)  # 查看地址name=index.split('\\')[-2]  # 获取标签,得到的是训练集中随机蝴蝶的类别
img = Image.open(index)  # 打开图片
img = cv2.imread(index)  # 图片处理
print(img.shape)  # 输出图片形状(441,600,3)
img = img[:,:,::-1] # 三通道,-1表示从右往左切片,opencv输入为BGR,故从右往左切片为RGB三通道
print(f'该样本标签为:{name}')
cv2.imshow("ran_img",img)
cv2.waitKey(0)

测试输出为:

写一个Reader类,其中定义三个函数,分别为初始化、处理图像、计算长度,使用Reader类加载训练集与数据集

### 查看数据类型
data_list = [] #用个列表保存每个样本的读取路径、标签
# 由于属种名称本身是字符串,而输入模型的是数字。需要构造一个字典,把某个数字代表该属种名称。键是属种名称,值是整数。
label_list=[]
with open("E:/Pycharm/workspace/OpenCV/butterfly/species.txt") as f:for line in f:a,b = line.strip("\n").split(" ") #a为1-20的序号,b为每个种类的namelabel_list.append([b, int(a)-1]) #将20种txt种的类别加入label_list数组种
label_dic = dict(label_list) #dict创建一个字典,字典中有20种蝴蝶类型butterfly_path = './Butterfly20/'
#若项目目录内已经有train_set与validation_set两个数据集,则删除,之后重新创建这两个数据集
if(os.path.exists('E:/Pycharm/workspace/OpenCV/butterfly/train_set.txt')):  # 判断有误文件os.remove('E:/Pycharm/workspace/OpenCV/butterfly/train_set.txt')  # 删除文件
if(os.path.exists('E:/Pycharm/workspace/OpenCV/butterfly/validation_set.txt')):os.remove('E:/Pycharm/workspace/OpenCV/butterfly/validation_set.txt')for i in os.listdir(butterfly_path): #得到Butterfly20里的所有文件夹if i not in '.DS_Store': #DB_Store里是20种蝴蝶类型的名字for j in os.listdir(os.path.join(butterfly_path, i)): #路径拼接,拼接后为./Butterfly20/20种名字,j从这个路径里提取序号.jpgdata_list.append(f'{os.path.join(butterfly_path, i, j)}\t{label_dic[i]}\n') #前一个大括号是每个图片具体路径,后一个是其种类的序号random.shuffle(data_list)  # 乱序
print(data_list[0]) #打印随机选出的第一个图片以及其属于的种类号
data_len = len(data_list)
count = 0for data in data_list:if count <= data_len*0.8:with open('E:/Pycharm/workspace/OpenCV/butterfly/train_set.txt', 'a')as f: # 80%写入训练集f.write(data)count += 1else:with open('E:/Pycharm/workspace/OpenCV/butterfly/validation_set.txt', 'a')as tf:  # 20%写入验证集tf.write(data)count += 1# 自定义数据读取器
class Reader(Dataset):def __init__(self, mode='train_set'):"""初始化函数"""self.data = []with open(f'{mode}_set.txt') as f: #train_set或validation_setfor line in f.readlines():info = line.strip().split('\t') #strip函数去掉首部等于参数值的字符,无参数表示删掉换行符if len(info) > 0:self.data.append([info[0].strip(), info[1].strip()])def __getitem__(self, index): #将图片转换为(224,224)像素大小"""读取图片,对图片进行归一化处理,返回图片和 标签"""image_file, label = self.data[index]  # 获取数据img = Image.open(image_file)  # 读取图片img = img.convert('RGB')img = img.resize((224, 224), Image.ANTIALIAS)  # 图片大小样式归一化img = np.array(img).astype('float32')  # 转换成数组类型浮点型32位img = img.transpose((2, 0, 1))  # 读出来的图像是rgb,rgb,rbg..., 转置为 rrr...,ggg...,bbb...img = img / 255.0  # 数据缩放到0-1的范围return img, np.array(label, dtype='int64')def __len__(self):"""获取样本总数"""return len(self.data)#调用Reader类,其中三个函数都会走
# 训练的数据提供器
train_dataset = Reader(mode='train')
# 测试的数据提供器
eval_dataset = Reader(mode='validation')# 查看训练和测试数据的大小
print('train大小:', train_dataset.__len__())
print('eval大小:', eval_dataset.__len__())# 随机查看图片数据、大小及标签
for data, label in eval_dataset:print(data)print(np.array(data).shape) #(3,224,224)print(label)break #只循环一次即可

三、构建网络

使用paddle框架构造神经网络,选用resnet152网络用于图像分类,最后分为20类

import paddle.nn.functional as F
#定义模型
class MyNet(paddle.nn.Layer):def __init__(self):super(MyNet,self).__init__()self.layer=paddle.vision.models.resnet152(pretrained=True) #152层的resnet模型,预训练模型只需要设定模型参数pretained=Trueself.dropout=paddle.nn.Dropout(p=0.5) #Dropout值设为0.5,self.fc1 = paddle.nn.Linear(1000, 512) #fc为全连接层,与模型训练后为1000个输出,要最后分20类self.fc2 = paddle.nn.Linear(512, 20) #两个全连接层实现1000-20#网络的前向计算过程def forward(self,x):x=self.layer(x) #resnet152模型x=self.dropout(x) #值为0.5的Dropoutx=self.fc1(x) #第一个全连接层x=F.relu(x) #使用relu函数激活x=self.fc2(x) #第二个全连接层得到20个分类特征return x

resnet网络结构如下:



















四、训练网络

用构建好的resnet152网络进行训练

model = paddle.Model(MyNet())
model.summary((1, 3, 224, 224)) #输出各层参数input_define = paddle.static.InputSpec(shape=[-1,3,224,224], dtype="float32", name="img")
label_define = paddle.static.InputSpec(shape=[-1,1], dtype="int64", name="label")#实例化网络对象并定义优化器等训练逻辑
model = MyNet()
model = paddle.Model(model,inputs=input_define,labels=label_define) #用Paddle.Model()对模型进行封装
optimizer = paddle.optimizer.Adam(learning_rate=0.0001, parameters=model.parameters())
#上述优化器中的学习率(learning_rate)参数很重要。要是训练过程中得到的准确率呈震荡状态,忽大忽小,可以试试进一步把学习率调低。model.prepare(optimizer=optimizer, #指定优化器loss=paddle.nn.CrossEntropyLoss(), #指定损失函数metrics=paddle.metric.Accuracy()) #指定评估方法callback = paddle.callbacks.VisualDL(log_dir='./visualdl_log')model.fit(train_data=train_dataset,     #训练数据集eval_data=eval_dataset,         #测试数据集batch_size=64,                  #一个批次的样本数量epochs=100,                      #迭代轮次save_dir="./visualdl_log", #把模型参数、优化器参数保存至自定义的文件夹save_freq=20,                    #设定每隔多少个epoch保存模型参数及优化器参数log_freq=100,                     #打印日志的频率verbose=1,                        # 日志展示模式shuffle=True,                     # 是否打乱数据集顺序callbacks=callback                # 回调函数使用)result = model.evaluate(eval_dataset, verbose=1)
print(result)model.save('E:/Pycharm/workspace/OpenCV/butterfly/butterfly_model')  # 保存模型

五、预测图片

随机使用一张图片,通过训练好的网络进行预测蝴蝶的种类,该蝴蝶属于第15类

def load_image(file): #加载测试图片并处理图片# 打开图片im = Image.open(file)# 将图片调整为跟训练数据一样的大小im = im.convert('RGB')im = im.resize((224, 224), Image.ANTIALIAS)# 建立图片矩阵 类型为float32im = np.array(im).astype(np.float32)# 矩阵转置im = im.transpose((2, 0, 1))# 将像素值从[0-255]转换为[0-1]im = im / 255.0# print(im)im = np.expand_dims(im, axis=0)# 保持和之前输入image维度一致print('im_shape的维度:', im.shape)return imfrom PIL import Image
# site = 255  # 读取图片位置
model_state_dict = paddle.load('E:/Pycharm/workspace/OpenCV/butterfly/butterfly_model.pdparams')  # 读取模型
model = MyNet()  # 实例化模型
model.set_state_dict(model_state_dict) #浅拷贝,读取模型
model.eval() #不进行BN与dropout,使用所有全职计算img = load_image(index)print(paddle.to_tensor(img).shape)
# print(paddle.reshape(paddle.to_tensor(img), (1, 3, 224, 224)))
ceshi = model(paddle.reshape(paddle.to_tensor(img), (1, 3, 224, 224)))  # 测试
print('预测的结果为:', list(label_dic.keys())[np.argmax(ceshi.numpy())])  # 获取值
with open("./work/result.txt", "w") as f:for r in result:f.write("{}\n".format(r))
Image.open(index)  # 显示图片

预测结果:

基于paddlepaddle构建resnet神经网络的蝴蝶分类相关推荐

  1. 基于PaddlePaddle构建ResNet18残差神经网络的食物图片分类问题

    基于PaddlePaddle构建ResNet18残差神经网络的食物图片分类问题 Introduction 本项目是在李宏毅机器学习课程的作业3进行的工作,任务是手动搭建一个CNN模型进行食物图片分类( ...

  2. 基于2维卷积神经网络的心电图分类

    在这里给大家分享一篇关于用深度学习进行心电图识别的论文,原文地址https://arxiv.org/abs/1804.06812,我翻译成了中文以便大家快速学习,中间难免有疏忽遗漏的地方,请大家谅解. ...

  3. 基于Pytorch全连接神经网络实现多分类

    (一)计算机视觉工具包的介绍 为了方便开发者应用,PyTorch专门开发了一个视觉工具包torchvision,主要包含以下三个部分: 1.models models提供了深度学习中各种经典的神经网络 ...

  4. 【人工智能 卷积神经网络】基础练习:基于torch构建卷积神经网络,测试集正确率达 百分之99

    声明:仅学习使用~ 这是一个关于卷积神经网络CNN的基础练习,也算是一个回顾.包含分解步骤,内容整合 以及最后的整体输出. 目录 一.步骤分解 1.0 系统环境.主要模块版本 1.1 相关模块的导入 ...

  5. 【项目实战】Python基于librosa和人工神经网络实现语音识别分类模型(ANN算法)项目实战

    说明:这是一个机器学习实战项目(附带数据+代码+文档+视频讲解),如需数据+代码+文档+视频讲解可以直接到文章最后获取. 1.项目背景 语音识别发展到现在作为人机交互的重要接口已经在很多方面改变了我们 ...

  6. python构建bp神经网络_鸢尾花分类(一个隐藏层)__1.数据集

    IDE:jupyter 目前我知道的数据集来源有两个,一个是csv数据集文件另一个是从sklearn.datasets导入 1.1 csv格式的数据集(下载地址已上传到博客园----数据集.rar) ...

  7. 基于pytorch搭建ResNet神经网络用于花类识别

  8. Python基于PyTorch实现BP神经网络ANN分类模型项目实战

    说明:这是一个机器学习实战项目(附带数据+代码+文档+视频讲解),如需数据+代码+文档+视频讲解可以直接到文章最后获取. 1.项目背景 在人工神经网络的发展历史上,感知机(Multilayer Per ...

  9. 【Pytorch(七)】基于 PyTorch 实现残差神经网络 ResNet

    基于 PyTorch 实现残差神经网络 ResNet 文章目录 基于 PyTorch 实现残差神经网络 ResNet 0. 概述 1. 数据集介绍 1.1 数据集准备 1.2 分析分类难度:CIFAR ...

最新文章

  1. Image deformation of AffineSimilarityRigidProjective
  2. 解决Vue用v-html、v-text渲染后台富文本框文本内容样式修改问题,用自定义css样式无法渲染出对应效果的问题
  3. JVM常用启动参数大全(附带解释)
  4. [Umbraco] 创建第一个页面
  5. when is extension component's resource bundle loaded
  6. neo4j安装_neo4j 社区版win10 下安装
  7. 请问运行py文件的时候怎么样可以不让那个黑框一闪...
  8. java实验1机动车实验目的_《Java程序设计》实验指导书.doc
  9. 一人网站所有的 ip地址_咸宁网站建设-网站的主要特征
  10. 小白学习使用gitee问题产生汇总(持续更新)
  11. Java ==和Equals方法的比较
  12. 善用GOOGLE–从入门到精通
  13. mount的挂载远程服务器文件夹
  14. H5中段落自动空两格
  15. 喜欢花,喜欢海,喜欢日出和日落
  16. Uva 12325 Zombie's Treasure Chest (贪心,分类讨论)
  17. C语言——运算符与表达式
  18. char与varchar的区别?
  19. 中国十大SNS交友网站排名
  20. 补偿 100 万? Oracle 裁员 900 人!

热门文章

  1. html让gif图片暂停,控制GIF动画暂停播放的代码
  2. 跟我学springboot(二十五)springboot-过滤器之拦截不需要走过滤器的链接使用方法
  3. 【Pandas-1】十分钟入门Pandas (上)
  4. 小学生台灯哪个品牌更护眼?学习专用的护眼台灯品牌
  5. 如何查看windows版本
  6. js动态添加带圆圈序号列表
  7. 西电研一人工智能复习随笔
  8. 美团运维面试官没想到jenkins我用得这么溜,人直接傻掉
  9. matlab中的sym
  10. JVM的内存区域划分(jdk7和jdk8)