1. 预训练模型

关于预训练模型,一般的检测都是使用ImageNet预训练的backbone,这是基本配置,官方也支持这种加载方式。

高级一点的的就是针对数据集做一次预训练:即将所有的目标裁剪出来,然后训练一个不错的分类模型,这样的初始化相比ImageNet就要好很多。

最后就是使用coco预训练的完整检测模型权重,这样的效果就是模型收敛速度快,而且效果一般都比较好,也是大家最常用的方法。由于每个任务的类别不同,需要对权重进行微调,这里给出mmdetection修改coco预训练权重类别的脚本。

脚本以cascade rcnn为例,其他模型的修改与之类似。

# for cascade rcnn
import torch
num_classes = 21
model_coco = torch.load("cascade_rcnn_x101_32x4d_fpn_2x_20181218-28f73c4c.pth")# weight
model_coco["state_dict"]["bbox_head.0.fc_cls.weight"].resize_(num_classes,1024)
model_coco["state_dict"]["bbox_head.1.fc_cls.weight"].resize_(num_classes,1024)
model_coco["state_dict"]["bbox_head.2.fc_cls.weight"].resize_(num_classes,1024)
# bias
model_coco["state_dict"]["bbox_head.0.fc_cls.bias"].resize_(num_classes)
model_coco["state_dict"]["bbox_head.1.fc_cls.bias"].resize_(num_classes)
model_coco["state_dict"]["bbox_head.2.fc_cls.bias"].resize_(num_classes)
#save new model
torch.save(model_coco,"coco_pretrained_weights_classes_%d.pth"%num_classes)

2. Soft-NMS

Soft-NMS改进了之前比较暴力的NMS,当IOU超过某个阈值后,不再直接删除该框,而是降低它的置信度(得分),如果得分低到一个阈值,就会被排除;但是如果降低后任然较高,就会保留。

在mmdetection中的设置如下:

test_cfg = dict(rpn=dict(nms_across_levels=False,nms_pre=1000,nms_post=1000,max_num=1000,nms_thr=0.7,min_bbox_size=0),rcnn=dict(score_thr=0.05, nms=dict(type='soft_nms', iou_thr=0.5), max_per_img=100),keep_all_stages=False)

3. GIoULoss

一般情况下,用GIoULoss代替L1Loss后会涨点。

原版用的配置文件(使用L1Loss)如下:

    rpn_head=dict(type='RPNHead',in_channels=256,feat_channels=256,anchor_generator=dict(type='AnchorGenerator',scales=[8],ratios=[0.5, 1.0, 2.0],strides=[4, 8, 16, 32, 64]),bbox_coder=dict(type='DeltaXYWHBBoxCoder',target_means=[0.0, 0.0, 0.0, 0.0],target_stds=[1.0, 1.0, 1.0, 1.0]),loss_cls=dict(type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),loss_bbox=dict(type='L1Loss', loss_weight=1.0)),roi_head=dict(type='StandardRoIHead',bbox_roi_extractor=dict(type='SingleRoIExtractor',roi_layer=dict(type='RoIAlign', out_size=7, sample_num=0),out_channels=256,featmap_strides=[4, 8, 16, 32]),bbox_head=dict(type='Shared2FCBBoxHead',in_channels=256,fc_out_channels=1024,roi_feat_size=7,num_classes=10,bbox_coder=dict(type='DeltaXYWHBBoxCoder',target_means=[0.0, 0.0, 0.0, 0.0],target_stds=[0.1, 0.1, 0.2, 0.2]),reg_class_agnostic=False,loss_cls=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),loss_bbox=dict(type='L1Loss', loss_weight=1.0))))

添加GIoULoss后的配置文件如下:

    rpn_head=dict(type='RPNHead',in_channels=256,feat_channels=256,anchor_generator=dict(type='AnchorGenerator',scales=[8],ratios=[0.5, 1.0, 2.0],strides=[4, 8, 16, 32, 64]),bbox_coder=dict(type='DeltaXYWHBBoxCoder',target_means=[0.0, 0.0, 0.0, 0.0],target_stds=[1.0, 1.0, 1.0, 1.0]),loss_cls=dict(type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),reg_decoded_bbox=True,      # 使用GIoUI时注意添加loss_bbox=dict(type='GIoULoss', loss_weight=5.0)),roi_head=dict(type='StandardRoIHead',bbox_roi_extractor=dict(type='SingleRoIExtractor',roi_layer=dict(type='RoIAlign', out_size=7, sample_num=0),out_channels=256,featmap_strides=[4, 8, 16, 32]),bbox_head=dict(type='Shared2FCBBoxHead',in_channels=256,fc_out_channels=1024,roi_feat_size=7,num_classes=10,bbox_coder=dict(type='DeltaXYWHBBoxCoder',target_means=[0.0, 0.0, 0.0, 0.0],target_stds=[0.1, 0.1, 0.2, 0.2]),reg_class_agnostic=False,loss_cls=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),reg_decoded_bbox=True,     # 使用GIoUI时注意添加loss_bbox=dict(type='GIoULoss', loss_weight=5.0))))

4. 模型瘦身小技巧

mmdetection在保存模型时,除了保存权重,还保存了原始数据和优化参数。但是,模型在测试时,有些参数是没有用的,怎样去掉这些无用的参数使模型减小(大约减小50%)呢?见下面的代码:

import torchmodel_path = "epoch_30.pth"
checkpoint = torch.load(model_path)
checkpoint['meta'] = None
checkpoint['optimizer'] = Noneweights = checkpoint['state_dict']state_dict = {"state_dict":weights}torch.save(state_dict,  './epotch_30_new.pth')

5. 在线难例挖掘(OHEM)

在线难例挖掘:在训练过程中在线的选择困难样本进行训练(选择loss较大的样本)。

思想比较简单,在mmdetection中的应用如下:

以faster rcnn为例子:

_base_ = './faster_rcnn_r50_fpn_1x_coco.py'
train_cfg = dict(rcnn=dict(sampler=dict(type='OHEMSampler')))

第一行为你训练模型的配置文件,第二行把采样方式设置为在线难例挖掘。

todo:

(1). GIoULoss  已经完成

(2). 在线难例挖掘  已经完成

(3). 混合精度训练

(4). 可变形卷积

(5). 多尺度训练

(6). 多尺度测试与数据增强测试

(7). Albu数据增强库的使用

(8). 模型融合

(9). 过分割测试

(10). mosaic数据增强

(11). PAFPN

(12). 样本均衡抑制长尾分布问题

mmdetection 模型训练技巧相关推荐

  1. 大模型训练技巧|单卡多卡|训练性能评测

    原视频:[单卡.多卡 BERT.GPT2 训练性能[100亿模型计划]] 此笔记主要参考了李沐老师的视频,感兴趣的同学也可以去看视频- 视频较长,这里放上笔记,与大家分享- 大模型对于计算资源的要求越 ...

  2. 【Pytorch神经网络理论篇】 24 神经网络中散度的应用:F散度+f-GAN的实现+互信息神经估计+GAN模型训练技巧

    1 散度在无监督学习中的应用 在神经网络的损失计算中,最大化和最小化两个数据分布间散度的方法,已经成为无监督模型中有效的训练方法之一. 在无监督模型训练中,不但可以使用K散度JS散度,而且可以使用其他 ...

  3. 高效又稳定的ChatGPT大模型训练技巧总结,让训练事半功倍!

    文|python 前言 近期,ChatGPT成为了全网热议的话题.ChatGPT是一种基于大规模语言模型技术(LLM, large language model)实现的人机对话工具.现在主流的大规模语 ...

  4. 李宏毅老师《机器学习》课程笔记-2.1模型训练技巧

    注:本文是我学习李宏毅老师<机器学习>课程 2021/2022 的笔记(课程网站 ),文中图片除了两幅是我自己绘制外,其余图片均来自课程 PPT.欢迎交流和多多指教,谢谢! 文章目录 Le ...

  5. ML(10) - 模型训练技巧

    模型技巧 交叉验证 Pipeline 网格搜索 偏差(Bias)和方差(Variance) 模型正则化(Regularization) 正则化基本概念 正则化种类(scikit-learn) 交叉验证 ...

  6. 深度学习模型训练技巧

    博主以前都是拿别人的模型别人的数据做做分类啊,做做目标检测,搞搞学习,最近由于导师的工程需求,自己构造网络,用自己的数据来跑网络,才发现模型训练真的是很有讲究,很有技巧在里面,最直接的几个超参数的设置 ...

  7. 计算机视觉中的数据预处理与模型训练技巧总结

    来源丨机器学习小王子,转载自丨极市平台 导读 针对图像分类任务提升准确率的方法主要有两条:一个是模型的修改,另一个是各种数据处理和训练的技巧.本文在精读论文的基础上,总结了图像分类任务的11个tric ...

  8. 【干货】计算机视觉中的数据预处理与模型训练技巧总结

    来源丨机器学习小王子 编辑丨极市平台 针对图像分类任务提升准确率的方法主要有两条:一个是模型的修改,另一个是各种数据处理和训练的技巧.本文在精读论文的基础上,总结了图像分类任务的11个tricks. ...

  9. 模型训练技巧——warm up

    1. pytorch 中学习率的调节策略 (1)等间隔调整学习率 StepLR (2)按需调整学习率 MultiStepLR (3)指数衰减调整学习率 ExponentialLR (4)余弦退火调整学 ...

最新文章

  1. GoogleLog(GLog)源码分析
  2. 有关cmd.Parameters.Clear()
  3. UVa 11825 (状压DP) Hackers' Crackdown
  4. qdbus 复杂类型
  5. 你遇到过哪些理工科的实验高手,他们有哪些优秀的思维习惯?
  6. mysql索引类型normal,unique,full text
  7. js金额千分位显示_JavaScript 格式化数字、金额、千分位、保留几位小数
  8. 使用freemarker模板生成html文件(一)
  9. android源码学习-Handler机制及其六个核心点
  10. rs232接口_串口,COM口,TTL,RS232,RS485,UART的区别详解
  11. Saving Tang Monk II(bfs+优先队列)
  12. Python笔记:第三方IP代理服务与爬虫IP代理
  13. 计算机桌面文件删除不掉是怎么了,如何解决电脑桌面文件无法删除问题
  14. V4L2 pixel format 格式参考
  15. 美和易思——互联网技术学院返校周测题
  16. Vue仿淘宝购物车网页
  17. 浅谈在windows系统下esp8266和esp32开发共存一个eclipse编译器,非安信可一体化环境,而是搭建自己的eclipes环境。
  18. 计算机运维方向要考什么证,IT运维项目经理考的证
  19. python作业——SVM预测交通流量
  20. mac os版spyder 安装jieba报错 No module named ‘jieba‘

热门文章

  1. 校园招聘Java开发工程师需要掌握的技能
  2. 素数定理 nefu 117
  3. 水星迷你无线路由器ap模式 下要不要启用 dhcp服务器,水星(Mercury)Mini无线路由器Router模式设置...
  4. 全国三甲医院突破3000家,医疗格局正在生变
  5. 【信息收集自动化工具】
  6. MFC实现播放SWF
  7. CDOJ 1281 暴兵的卿学姐 构造题
  8. 在申请版号的路上,游戏厂商该如何做好防沉迷系统
  9. 2022年浙江省中职组“网络空间安全”赛项模块B--Linux渗透测试
  10. Vagrant Boxs