MMSegmentation 训练测试全流程
MMSegmentation 训练测试全流程
- 1.按照执行顺序的流程梳理
- Level 0: 运行 Shell 命令:
- Level 1: 在 tools/train.py 内:
- Level 2: 转进到 mmseg.apis 模块的 train_segmentor 函数内:
- Level 3: 转进到 mmcv/runner/iter_based_runner.py 内的 IterBasedRunner 类的 run 函数内部:
- Level 4: 转进到 IterBasedRunner 类的 train 函数内部
- Level 5: 转进到 EvalHook 类实例的 after_train_iter 函数内部:
- 4.函数说明:
- 5.疑问解答
- 参考链接:
括号的部分可以不看!是debug经过的内容,有些事调用了mmcv库的函数,只想看看流程不需要细看!
1.按照执行顺序的流程梳理
Level 0: 运行 Shell 命令:
python tools/train.py ${CONFIG_FILE [optional arguments]
Level 1: 在 tools/train.py 内:
- 读取各种
config: cfg = Config.fromfile(args.config)
- 创建
model: model = build_segmentor(cfg.model, train_cfg, test_cfg)
- 创建
training dataset: datasets = [build_dataset(cfg.data.train)]()
- 通过
Config
类的__getattr__
函数:value = super(ConfigDict, self).__getattr_
获取数据和数据增强信息并返回value - 转到
mmseg/datasets/builder.py
内的build_dataset函数,获取dataset
:dataset = build_from_cfg(cfg, DATASETS, default_args)
- 转到
/usr/local/lib/python3.8/dist-packages/mmcv/utils/registry.py
内的build_from_cfg
函数:args = cfg.copy()
- 获取数据格式类型:
obj_type = args.pop('type')
,比如obj_type
:ADE20KDatase
- 通过数据格式obj_type获得类
obj_cls = registry.get(obj_type)
,比如<class 'mmseg.datasets.ade.ADE20KDataset'>
- 获取return obj_cls(**args)
- (转到
/usr/lib/python3.8/typing.py
的Generic
类的__new__
函数:obj = super().__new__(cls)
) - 转到
mmseg/datasets/ade.py
中的ADE20KDataset
类的__init__
函数:super(ADE20KDataset, self).__init__(**
- 转到
mmseg/datasets/custom.py
中ADE20KDatase
类继承的CustomDataset
类- 调用loading.py中LoadAnnotations类进行初始化,获得image和mask的地址等信息,并获取image和mask名字的dict:
self.img_infos = self.load_annotations(self.img_dir, self.img_suffix,self.ann_dir,self.seg_map_suffix, self.split)
- 实例对象做运算时,就会调用
CustomDataset
类中的__getitem__()__
:self.prepare_train_img(idx)
- 调用prepare_train_img函数:self.pipeline(results),调用
mmseg/datasets/pipelines/loading.py
的LoadImageFromFile
类和其他数据增强
- 调用loading.py中LoadAnnotations类进行初始化,获得image和mask的地址等信息,并获取image和mask名字的dict:
- 通过
- 创建
validation dataset: datasets.append(build_dataset(val_dataset))
- 将
model
,data
,config
喂给训练函数:train_segmentor(model, datasets, cfg)
Level 2: 转进到 mmseg.apis 模块的 train_segmentor 函数内:
- 创建
dataloader
:data_loaders = [build]()_dataloader(dataset, config)]
- 将
model
搬到 GPU 上去:model = MMDataParallel(model.cuda(), cfg)
- 创建
optimizer
:optimizer = build_optimizer(model, cfg)
- 创建
runner
:runner = build_runner(model, cfg, optimizer)
- 给 runner 注册
training hooks
:runner.register_training_hooks(cfg)
- 给 runner 注册
validation hooks
:runner.register_hook(eval_hook(val_dataloader, eval_cfg))
- 这个 eval_hook 是 EvalHook 类实例, 其重写了
after_train_iter
和after_train_epoch
两个方法, 在IterBasedRunner
中用的是after_train_iter
。
- 这个 eval_hook 是 EvalHook 类实例, 其重写了
- 开始训练
runner.run(data_loaders, cfg.workflow)
Level 3: 转进到 mmcv/runner/iter_based_runner.py 内的 IterBasedRunner 类的 run 函数内部:
Training
模式,mode = 'train', i = 0
, 运行iter_runner(iter_loaders[i](), **kwargs)
- 实质上是在运行
IterBasedRunner
类的train
函数:train(iter_loaders[0](), **kwargs)
- 从
while self.iter < self._max_iters:
可以看到, 这个train
函数一共会被调用self._max_iters
次 - 从中也可以看到这个
train
函数其实只负责做一个batch
数据的forward
计算
- 实质上是在运行
Validation
模式, 此处其实没有运行- mmseg 的所有 setting 都是
workflow = [('train', 1)]
- 实际上的 validation 是通过在
after_train_epoch
节点调用EvalHook
对象的after_train_iter
方法实现的。
- mmseg 的所有 setting 都是
Level 4: 转进到 IterBasedRunner 类的 train 函数内部
- 读取一个 batch 的数据:
data_batch = next(data_loader)
- 调用 model 的
train_step
函数计算loss
:outputs = self.model.train_step(data_batch)
- 尝试选择性进行
validation
:self.call_hook('after_train_iter')
- 实质上是调用
EvalHook
类实例的after_train_iter
函数;
- 实质上是调用
Level 5: 转进到 EvalHook 类实例的 after_train_iter 函数内部:
- 如果当前迭代数不能够被 interval 整除, 就不做 validation:
if not self.every_n_iters(runner, self.interval): return
- 如果能被整除, 计算一下
validation set
上的结果:results = single_gpu_test(model, dataloader)
- 这一步就是 enumerate 一下
data_loader
, 对于每个 batch 都用model forward
一下, 把 result 都 append 起来得到一个list results
, 就不再展开了
- 这一步就是 enumerate 一下
- 对于分割结果再调用
dataset
的evaluate
函数计算一下mIoU
,mDice
,mFscore
等metric
数值- 其实就是通过调用下
mmseg.core
里面的eval_metrics
函数调用total_intersect_and_union
函数计算下上述数值
- 其实就是通过调用下
4.函数说明:
self.pipeline = Compose(pipeline)
Compose
:把函数组合起来,每个函数的返回值是下一个函数的参数
print_log(f’Loaded {len(img_infos)} images’, logger=get_root_logger())
print_log
:打印日志
target = torch.where(target == ignore_index, target.new_tensor(0), target)
torch.where
:查找 target 中值为ignore_index(255)的值转为0,new_tensor
:target.new_tensor是将target的值copy一份,不共享内存,new_tensor(0)指值为0同样size矩阵
5.疑问解答
- CustomDataset类中pre_eval函数的ignore_index=255是起什么作用的? 是不计算255的loss吗
- 在
mmseg/core/evaluation/metrics.py
函数中找到了答案 intersect_and_union
函数中计算IOU
的时候,将ignore_index=255
的值忽略掉:mask = (label != ignore_index)
,相当于不计算背景的准确率,获取到的相当于是召回率Recall
。- 需要注意的是,其中
reduce_zero_label=True
时,是将像素值为0的转为255:label[label == 0] = 255
,会在mask = (label != ignore_index)
处一并忽略
- 在
- 注意:255值在标注员工标注过程中代表不需要标注的区域,相当于背景,在需要标注的区域,背景值是0
参考链接:
【1】MMSegmentation 训练测试全流程及其关键节点
MMSegmentation 训练测试全流程相关推荐
- 记录一次Monkey测试全流程
记录一次Monkey测试全流程 1.检查设备连接 ZHR:~ zc$ adb devices List of devices attached JPF4C19123011893 device 2.查看 ...
- 基于Jenkins的开发测试全流程持续集成实践
今年上半年一直在公司实践CI,本文将上半年来的一些实践总结一下,可能不太完善或优美,但的确初步解决了我目前所在项目组的一些痛点.当然这仅是一家之言也不够完整,后续下半年还会深入实践和引入Kuberne ...
- Web网页测试全流程解析论Web自动化测试
1.功能测试 web网页测试中的功能测试,主要测试网页中的所有链接.数据库连接.用于在网页中提交或获取用户信息的表单.Cookie 测试等. (1)查看所有链接: ·测试从所有页面到被测特定域的传出链 ...
- (最全干货分享)渗透测试全流程归纳总结之二
进来先点个赞,评个论,关个注呗- 获取更多学习资料.想加入社群.深入学习,请扫我的二维码或加Memory20000427. 2.OSINT 公开情报收集 2.1社工技巧 查看注册的网站:0xreg r ...
- 【深度学习】深度学习模型训练全流程!
Datawhale干货 作者:黄星源.奉现,Datawhale优秀学习者 本文从构建数据验证集.模型训练.模型加载和模型调参四个部分对深度学习中模型训练的全流程进行讲解. 一个成熟合格的深度学习训练流 ...
- 加载tf模型 正确率很低_深度学习模型训练全流程!
↑↑↑关注后"星标"Datawhale 每日干货 & 每月组队学习,不错过 Datawhale干货 作者:黄星源.奉现,Datawhale优秀学习者 本文从构建数据验证集. ...
- 深度学习模型训练全流程!
↑↑↑关注后"星标"Datawhale 每日干货 & 每月组队学习,不错过 Datawhale干货 作者:黄星源.奉现,Datawhale优秀学习者 本文从构建数据验证集. ...
- 网易云容器服务微服务化实践—微服务测试及镜像化提测全流程实践
前言 近几年,互联网项目很多都有从单体服务转变成微服务化的趋势,尤其是一些架构复杂,业务比较广泛的项目,微服务化是大势所趋,可以解决独立构建.更新.运维等一系列问题,从而解放生产力,促进交付效率和质量 ...
- 渗透测试 ( 1 ) --- 相关术语、必备 工具、导航、全流程总结、入侵网站思路
From:https://zhuanlan.zhihu.com/p/401413938 渗透测试实战教学:https://www.zhihu.com/column/c_1334810805263515 ...
最新文章
- (转载)java工程师15本必读书籍推荐
- RHEAS 显示、输入中文
- C# 中的 in 参数和性能分析
- MyBatis中in的使用
- java.lang.IllegalArgumentException: MALFORMED jar解析中文报错问题
- java log4j 热部署_JAVA类加载器分析--热部署的缺陷(有代码示例,及分析)
- 这个世界是那样的似曾相识
- [原创] 若水新闻安卓客户端开发教程笔记
- 平面设计中有趣的词云图如何设计
- QT实现经纬度转换为图片像素坐标
- 抖音文案、声音、设计、视频、图片素材网站
- ai如何复制文字并对齐_AI文字怎么对齐? ai文字排版的方法
- nginx gzip
- Options error: In [CMD-LINE]:1: Error opening configuration file: xxxx.ovpn
- 加拿大滑铁卢大学计算机世界排名,滑铁卢大学世界排名
- 互联网盈利模式研习笔记 1:流量变现
- TypeScript进阶 之 重难点梳理
- 路由设置代理ip的作用
- ErrorCannot find module XXX 解决方法
- 电力监控组态软件FCPower下载,力控最新组态软件下载!