需要的第三方库:

pytorch、matplotlib、json、os、tqdm

一、model.py的编写

(1)准备工作

1.参照vgg网络结构图(如下图1),定义一个字典,用于存放各种vgg网络,字典如下图2(M表示最大池化层)


2.定义一个获取特征的函数,此处命名为make_features,参数为模型名字,再遍历字典中键对应的值列表,向layers中加入对应的卷积层和池化层,最后返回打包完成的feature(非关键字参数),用于后续操作

(2)VGG类的定义

创建一个VGG类,父类为nn.Module,初始化函数的参数中设置:feature(包含网络中卷积与池化各层,并用nn.Sequential打包完成)、num_classes(对应类别个数,即全连接层最后的节点个数,设为1000);
再在初始化函数中编写三个全连接层,并打包,如下图:
全连接层之间先使用relu函数激活,再使用dropout,使一半的神经元随机失活,防止过拟合

然后定义其正向传播过程,将features从维度为1展平,再放入classifier,然后返回值

初始化权重
利用for循环,对卷积和池化层分别进行权重初始化,并对偏置量bias置0

vgg函数的编写

定义函数名为vgg的函数,参数为model_name和可变长度参数,用于直接实例化VGG类对象,返回对应的model

二、train.py的编写

定义main函数,自动调用,完成训练与验证

验证是否使用Gpu:

(1)数据预处理

定义一个字典,存放对训练集和验证集的处理;
训练的图像经过随机裁剪,随机水平反转,数据转换为tensor格式,以及对数据进行标准化处理;验证集则需要强行变为224*224的图像,再转换为tensor格式,最后标准化


(2)相关数据集位置读取


以上一行代码可指定当前py文件上两级目录的绝对位置(以下图为例,train.py存放在VGG_ pytorch,则data_root即为projects文件夹的绝对位置)

根据数据集的存放位置,利用os.path.join拼接得到图片的路径,再分别对应训练和验证集进行打开数据集位置并进行预处理,再加载数据集,以下以训练集为例,测试集代码与之类似,不做展示,需要注意的是,num_workers在Windows平台一般不能置为非0数字,若是Linux等平台可进行修改

利用.class_to_idx方法生成以类别为键,数字为值的字典,再将其键值交换,写入json文件,效果如下:


(3)模型实例化

利用model.py中的vgg函数实例化一个net对象,使之成为训练所使用的model,损失函数选择交叉熵,优化器选择Adam,并将learning rate(学习率)设为0.0001(下图以vgg16为例,init_weights是model.py中定义的初始化权重),再定义存放权重文件的路径

利用tqdm对加载的训练集进行处理,再进行遍历,对其进行梯度置0,输出置于GPU,计算预测值与真实值的loss,再将loss反向传播置每一节点,最后根据loss更新参数(注意要先利用net.train()启用dropout);如下图:

对于验证集的处理有部分不同:
1.先使用.eval()禁用dropout,同时禁止计算损失梯度
2.循环中预测值输出为每行最大值,即可能性最高的预测值;
对预测值和真实值进行判断,若相等,则acc+1,不相等则不加

再用acc/验证集总个数得到验证集准确率,与之前迭代产生的验证集准确率作比较,在结尾处书写如下代码,则使得最终保存acc最高的权重数据

三、 predict.py的编写

首先对需要的数据进行如train.py中验证集一样的数据预处理,然后直接打开展示待检测图像,再为其添加batch维度(如下图)

然后读取之前写入的json文件,初始化网络,载入权重文件,完成网络模型的载入;

再禁用dropout、禁止计算损失梯度,将图片放入模型再压缩掉batch维度,最终得到输出,再通过softmax得到其概率分布,最后通过 torch.argmax()得到概率最大处的索引值,打印出结果

四、过程中的一些问题

1.运行train.py时,出现报错

OSError: [WinError 1455] 页面文件太小,无法完成操作。

解决方案:
https://blog.csdn.net/qq_17755303/article/details/112564030

2.运行train.py时,出现报错

RuntimeError: CUDA out of memory. Tried to allocate 50.00 MiB (GPU 0; 6.00 G

其原因为batch_size设置过大,初始值设置为32,出现上述报错,改为16之后,便正常运行

附上效果图:

关于用pytorch构建vgg网络实现花卉分类的学习笔记相关推荐

  1. 《基于张量网络的机器学习入门》学习笔记7

    <基于张量网络的机器学习入门>学习笔记7 量子算法 什么是量子算法 三个经典量子算法 Grover算法 背景 基本原理 例题 量子算法 什么是量子算法 例如我们求解一个问题,一个111千克 ...

  2. 《基于张量网络的机器学习入门》学习笔记6

    <基于张量网络的机器学习入门>学习笔记6 密度算符(密度矩阵) 具体到坐标表象 在纯态上 在混合态上 纯态下的密度算符 混合态下的密度算符 密度算符的性质 量子力学性质的密度算符描述 第一 ...

  3. 《基于张量网络的机器学习入门》学习笔记5

    <基于张量网络的机器学习入门>学习笔记5 量子概率体系 事件 互斥事件 概率与测量 不相容属性对 相容属性对 量子概率与经典概率的区别 量子测量 量子概率体系 我们将经典的实数概率扩展到复 ...

  4. 《基于张量网络的机器学习入门》学习笔记4

    <基于张量网络的机器学习入门>学习笔记4 量子概率 将概率复数化 分布与向量的表示 事件与Hilbert空间 不兼容属性及其复数概率表示 为什么一定要复数概率 量子概率 将概率复数化 在经 ...

  5. Java 3D编程实践_Java 3D编程实践——网络上的三维动画[学习笔记]

    评论 # re: Java 3D编程实践--网络上的三维动画[学习笔记] 2006-08-24 23:41 gy # re: Java 3D编程实践--网络上的三维动画[学习笔记] 2007-03-2 ...

  6. 《基于张量网络的机器学习入门》学习笔记8(Shor算法)

    <基于张量网络的机器学习入门>学习笔记8 Shor算法 来源 Shor算法的大致流程 因数分解 周期求取与量子傅里叶变换(QFT) Shor算法 来源 1994 1994 1994年,应用 ...

  7. pytorch 搭建 VGG 网络

    目录 1. VGG 网络介绍 2. 搭建VGG 网络 3. code 1. VGG 网络介绍 VGG16 的网络结构如图: VGG 网络是由卷积层和池化层构成基础的CNN 它的CONV卷积层的参数全部 ...

  8. pytorch实现VGG网络

    这里写目录标题 1. VGG网络结构 代码 1. VGG网络结构   VGG16相比AlexNet的一个改进是采用连续的几个3x3的卷积核代替AlexNet中的较大卷积核(11x11,7x7,5x5) ...

  9. 《基于张量网络的机器学习入门》学习笔记2

    <基于张量网络的学习入门>学习笔记2 量子逻辑门 单量子逻辑门 恒等操作 泡利-X门(Pauli-X gate) 泡利-Y门(Pauli-Y gate) 泡利-Z门(Pauli-Z gat ...

最新文章

  1. Spring Cloud Alibaba 之 RPC 消息:Dubbo 与 Nacos 体系如何协同作业
  2. 【必看】Linux 系统的备份恢复
  3. jQuery使用详解
  4. 青海师范大学云上健身计算机学院,尹君-欢迎光临青海师范大学计算机学院
  5. mybatis学习(3):映射文件的配置和接口创建
  6. SQL 死锁分析(转贴)
  7. Mondrian and OLAP
  8. Elasticsearch 读时分词、写时分词
  9. 03. 二维数组中的查找(C++实现)
  10. 让驰骋工作流程引擎 ccbpm使用自定义表单来实现自己的业务逻辑.
  11. 多维度积分管理系统java_Java毕业设计——超市积分管理系统项目设计
  12. C语言画爱心代码分析
  13. 如何使用 Web Speech API 在浏览器中识别语音
  14. 读王小波先生的《黄金时代》、《青铜时代》
  15. vr手柄设置_最佳无线VR设置,最新和即将推出
  16. 一页纸项目管理pdf_项目管理,一页纸就够了
  17. H5 实现横向滚动的方法及需要注意的地方
  18. 扫描版模糊pdf优化方法
  19. 杂七杂八的网络安全知识
  20. iOS---学习研究大牛Git高星项目YYCategories(二)

热门文章

  1. 北京突然宣布,元宇宙重大消息
  2. android炫彩文字和滚动的彩色背景
  3. MFC:CCheckListBox使用教程
  4. 民安智库开展老人体检消费者调查
  5. c语言debug小窗口怎么移动,大家指点下VS中调试的监视、内存窗口的技巧
  6. pygame网络游戏_5_4:网络编程_设计通信协议
  7. 近代数学史上的最大冤案
  8. Mixly(米思齐)的安装以及基于Arduino开发板实现电容触摸控制灯
  9. 复习笔记(函数的极值)
  10. 双鉴探测器是哪两种探测方式结合_双鉴红外探测器型号有哪些?