一、结构分布

先介绍一下代码的结构分布吧

1、tain.py文件是训练的时候首先执行的文件,里面的函数有eval()评估函数,train()训练函数

2、trainer.py文件是网络的流图,关于如何forward,如何计算loss,如何反向计算,如何保存模型,如何控制权重更新等等,这个里面的函数会在train.py中的train()函数开始的时候调用,先构建fasterrcnn的网络,然后将网络作为参数传给trainer的构造函数

3、data文件夹,下面都是数据读取的方法

(1)dataset.py文件是批量加载数据,这个也会在train()函数开始初始化了一个dataset对象。

(2)voc_dataset.py文件是针对VOC数据集格式准备的,批量加载VOC数据集,解析XML文件都在这个类中,而且他是在dataset.py文件中调用,加载VOC数据集用的。

(3)util.py文件是定义了一些图像预处理工具,包括read_image,resize_box、crop_bbox这些会在其他py文件中调用,比如voc_dataset.py文件中调用了read_image数据进行读取数据。

(4)__init__.py文件好像是python要求类下必须有这个文件,需要确定一下???

4、model文件夹,下面都是网络构建的一些py。

(1)faster_rcnn_vgg16.py文件,是用来构建FasterRCNN-vgg16网络的,该网络分三部分创建,extractor特征提取网络,是利用torchvision.model模块创建的VGG16,然后是RPN网络创建,再就是ROIHeader网络创建,这个网络构建对象在train()函数开始的时候就创建了,用于先构建网络,然后再传入trainer。

(2)faster_rcnn.py文件,是一个base-class,faster-rcnn-vgg16类继承了这个类

(3)region_proposal_network.py文件,用于构建rpn模块,在faster_rcnn_vgg16.py文件中调用生成网络结构。

(4)roi_module.py文件,这个暂时没研究,在faster_rcnn_vgg16.py中调用了,初始化的时候ROIPooling。??

(5)utils文件夹,下面是一些工具,nms文件夹是非极大值抑制,其他的没仔细看,后期研究,主要是在faster_rcnn.py文件中调用了。

6、utils文件夹,这里面是工具

二、训练流程

1、首先调用train.py文件,输入相关参数进行训练。输入的控制台参数用**kwargs来表示,学习了python的控制台参数知道这是个接受字典形参数。https://www.cnblogs.com/zhangzhuozheng/p/8053045.html可参见这个地址有详细说明。

2、进行参数解析,利用了utils文件夹下的config.py文件进行了参数获取,这个文件中自定义一些默认参数,主要返回的是学习率学习策略,数据集地址等。

3、构造数据集对象,包括标签、图像名称列表。

4、根据batch_size,num_workers进行数据加载对象声明。数据是不是这个时候加载的还待定,感觉这个loader就像一个占位符,先占个坑,等运行网络的时候就开始读入了。具体DataLoader的用法需要查看pytorch

5、然后构建网络faster_rcnn_vgg16,构建方式见前面说的faster_rcnn_vgg16.py。

6、构建traner对象,将网络输入进行。

7、判断opt.load_path是否存在,这个load_path是在config.py中定义的,是model的地址,默认值是none,也就是如果在控制台输入没有指定--model这个参数,那么就没有了。如果有model即预训练模型,则调用trainer中的load函数加载预训练模型。

(1)trainer.load()函数解析,首先利用torch.load()函数加载模型,然后判断‘model’字符串是否存在,来判断是单纯加载参数还是加载带模型的参数(这是我个人理解的,具体要看pytorch的load_state_dict函数),最后判断参数是否修改,默认没改,最后判断优化器是否在加载的网络里,是的话加载预训练模型中的优化器。

8、可视化训练数据的label,调用trainer.vis.text函数,函数解析待会。

9、best_map参数干什么用的不知道待定,lr_是学习率获取。

10、下面就是循环训练啦,循环条件是epoch数,这个是opt超参数规定的。

(1)trainer.reset_meters()先重置界面上所有的数据,相当于一个epoch更新一次显示数据。

(2)开启一个for循环,枚举数据啦,从dataloader中按照batch-size循环读取数据,循环条件是把数据取完,tqdm模块是进度条模块具体可以百度。

(3)然后调用array_tool.py文件(在utils文件夹下)中的scalar()函数,传入参数是scale,这个参数是什么意思

(4)把数据传到cuda中,用来加速计算,返回转换后的cuda版本的数据,下面调用的都是cuda版的。

(5)利用trainer.train_step函数进行计算,前面介绍过这个函数,是用来更新一次权重的。

(6)图像进行归一化处理.

(7)然后是显示

(8)预测bboxes,label,这个predict函数是哪来的呢,首先是trainer,而trainer中调用的网络是fasterrcnnvgg16,这个网络继承 的是fasterrcnn的类,fasterrcnn类中有一个predict函数。

(9)下面就是一些可视化操作了,然后跳出了枚举数据的循环。这就是1个epoch完成了

(10)模型评估,利用测试集来做,前面已经加载了测试集,test_dataloader

(11)得到优化器中学习率的数值,并显示日志相关内容,包括lr,map,loss

(12)根据评测结果判断map是否是大于阈值best_map,如果是保存模型

(13)判断当前的epoch是否=9,如果是就加载最好的map和改变学习率

(14)判断如果epoch=13就跳出迭代循环???这个是这个实验里设计的具体原因不清楚。待定

FasterRCNN-pytorch的代码解析相关推荐

  1. Temporal Fusion Transformer (TFT) 各模块功能和代码解析(pytorch)

    Temporal Fusion Transformer (TFT) 各模块功能和代码解析(pytorch) 文章目录 Temporal Fusion Transformer (TFT) 各模块功能和代 ...

  2. pytorch代码解析:loss = y_hat - y.view(y_hat.size())

    pytorch代码解析:pytorch中loss = y_hat - y.view(y_hat.size()) import torchy_hat = torch.tensor([[-0.0044], ...

  3. Hugging Face实战(NLP实战/Transformer实战/预训练模型/分词器/模型微调/模型自动选择/PyTorch版本/代码逐行解析)下篇之模型训练

    模型训练的流程代码是不是特别特别多啊?有的童鞋看过Bert那个源码写的特别特别详细,参数贼多,运行一个模型百八十个参数的. Transformer对NLP的理解是一个大道至简的感觉,Hugging F ...

  4. 单目标跟踪算法:Siamese RPN论文解读和代码解析

    点击上方"3D视觉工坊",选择"星标" 干货第一时间送达 作者:周威 | 来源:知乎 https://zhuanlan.zhihu.com/p/16198364 ...

  5. Faster-RCNN.pytorch的搭建、使用过程详解(适配PyTorch 1.0以上版本)

    Faster-RCNN.pytorch的搭建.使用过程详解 引言 faster-rcnn pytorch代码下载 faster-rcnn pytorch配置过程 faster-rcnn pytorch ...

  6. Positional Encodings in ViTs 近期各视觉Transformer中的位置编码方法总结及代码解析 1

    Positional Encodings in ViTs 近期各视觉Transformer中的位置编码方法总结及代码解析 最近CV领域的Vision Transformer将在NLP领域的Transo ...

  7. [GCN] 代码解析 of GitHub:Semi-supervised classification with graph convolutional networks

    本文解析的代码是论文Semi-Supervised Classification with Graph Convolutional Networks作者提供的实现代码. 原GitHub:Graph C ...

  8. 目标检测算法之常见评价指标的详细计算方法及代码解析

    前言 之前简单介绍过目标检测算法的一些评价标准,地址为目标检测算法之评价标准和常见数据集盘点.然而这篇文章仅仅只是从概念性的角度来阐述了常见的评价标准如Acc,Precision,Recall,AP等 ...

  9. YOLO系列 --- YOLOV7算法(二):YOLO V7算法detect.py代码解析

    YOLO系列 - YOLOV7算法(二):YOLO V7算法detect.py代码解析 parser = argparse.ArgumentParser()parser.add_argument('- ...

  10. YOLO-V5 算法和代码解析系列 —— 学习路线规划综述

    目录标题 为什么学习 YOLO-V5 ? 博客文章列表 面向对象 开源项目学习方法 预备知识 项目目录结构 为什么学习 YOLO-V5 ? 算法性能:与YOLO系列(V1,V2,V3,V4)相比,YO ...

最新文章

  1. 模型的可解释性:部分依赖图PDP和个体条件期望图ICE
  2. hal库开启中断关中断_「正点原子NANO STM32开发板资料连载」第十章 外部中断实验...
  3. excel中如何筛选重复数据
  4. 微信扫码支付官方配置(一)
  5. rest_framework05:GenericAPIView用法/扩展类5个/子类9个/ViewSetMixin 自定义方法名字
  6. CentOS 更改MySQL数据库目录位置
  7. 飞鸽传书绿色版 部分数据库被陆续公开了
  8. 【英语学习】【WOTD】animadversion 释义/词源/示例
  9. kali linux引导文件修复,Kali+Windows引导修复
  10. PTA—考试座位号(C语言)
  11. 如何让gitbook与github仓库关联
  12. 习题4.5 顺序存储的二叉树的最近的公共祖先问题 (25 分)
  13. PHP接入芝麻信用续。
  14. php文件显示代码行数,php统计文件中的代码行数
  15. 参数化CFAR的FPGA实现
  16. Java的BIO和NIO很难懂?用代码实践给你看,再不懂我转行!
  17. Unity3D 学习笔记 —— Tween对象的实现与动作管理
  18. pdo_mysql扩展库_MySQL数据库之PDO扩展
  19. 从输入url到页面返回到底发生了什么
  20. 【解决方案】快递代收点部署视频监控,EasyCVR视频融合平台来助力

热门文章

  1. eclipse中解决svn连接时数字证书问题
  2. 免费分享thinkphp框架开发周易八字起名网宝宝起名在线下单网站源码自适应可二开
  3. 关于outlook2019/2016发送到outlook2010/2013的约会/日历中图片不显示的问题
  4. 集成电路设计与集成系统和计算机科学与技术,2019年集成电路设计与集成系统本科专业怎么样?...
  5. 数字集成电路设计入门书籍
  6. 有效提升办公生产力,咪鼠语音智能鼠标,我目前使用过最好用的
  7. DELL笔记本安装Ubuntu 14.04
  8. 【全国大学英语四、六级考试(CET)成绩单补办】
  9. 小米id锁状态查询_Mysql中的三类锁,你知道吗?
  10. pythonocr训练模型_cnocr: cnocr是用来做中文OCR的Python 3包。cnocr自带了训练好的识别模型,安装后即可直接使用...