YOLOv5、v7改进之三十一:CrissCrossAttention注意力机制
前 言:作为当前先进的深度学习目标检测算法YOLOv7,已经集合了大量的trick,但是还是有提高和改进的空间,针对具体应用场景下的检测难点,可以不同的改进方法。此后的系列文章,将重点对YOLOv7的如何改进进行详细的介绍,目的是为了给那些搞科研的同学需要创新点或者搞工程项目的朋友需要达到更好的效果提供自己的微薄帮助和参考。由于出到YOLOv7,YOLOv5算法2020年至今已经涌现出大量改进论文,这个不论对于搞科研的同学或者已经工作的朋友来说,研究的价值和新颖度都不太够了,为与时俱进,以后改进算法以YOLOv7为基础,此前YOLOv5改进方法在YOLOv7同样适用,所以继续YOLOv5系列改进的序号。另外改进方法在YOLOv5等其他算法同样可以适用进行改进。希望能够对大家有帮助。
具体改进办法请关注后私信留言!
解决问题:之前改进增加了很多注意力机制的方法,包括比较常规的SE、CBAM等,本文加入CrissCrossAttention注意力机制,该注意力机制为应用在语义分割中的模块,用于可以让网络更加关注待检测目标,提高检测效果
基本原理:
语义分割的Criss-Cross网络(CCNet)的细节。我们首先介绍了CCNet的总体框架。然后,将介绍在水平和垂直方向捕获上下文信息的2D交叉注意力模块。为了获取密集的全局上下文信息,我们建议对交叉注意力模块采用循环操作。为了进一步改进RCCA,我们引入了判别损失函数来驱动RCCA学习类别一致性特征。最后,我们提出了同时利用时间和空间上下文信息的三维交叉注意模块。
添加方法:
第一步:确定添加的位置,作为即插即用的注意力模块,可以添加到YOLOv5网络中的任何地方。
第二步:common.py构建CoordAtt模块。部分代码如下,关注文章末尾,私信后领取。
class CrissCrossAttention(nn.Module):""" Criss-Cross Attention Module"""def __init__(self, in_dim):super(CrissCrossAttention, self).__init__()self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)self.softmax = Softmax(dim=3)self.INF = INFself.gamma = nn.Parameter(torch.zeros(1))def forward(self, x):m_batchsize, _, height, width = x.size()proj_query = self.query_conv(x)proj_query_H = proj_query.permute(0, 3, 1, 2).contiguous().view(m_batchsize * width, -1, height).permute(0, 2,1)proj_query_W = proj_query.permute(0, 2, 1, 3).contiguous().view(m_batchsize * height, -1, width).permute(0, 2,1)proj_key = self.key_conv(x)proj_key_H = proj_key.permute(0, 3, 1, 2).contiguous().view(m_batchsize * width, -1, height)proj_key_W = proj_key.permute(0, 2, 1, 3).contiguous().view(m_batchsize * height, -1, width)proj_value = self.value_conv(x)proj_value_H = proj_value.permute(0, 3, 1, 2).contiguous().view(m_batchsize * width, -1, height)proj_value_W = proj_value.permute(0, 2, 1, 3).contiguous().view(m_batchsize * height, -1, width)energy_H = (torch.bmm(proj_query_H, proj_key_H) + self.INF(m_batchsize, height, width)).view(m_batchsize, width,height,height).permute(0,energy_W = torch.bmm(proj_query_W, proj_key_W).view(m_batchsize, height, width, width)concate = self.softmax(torch.cat([energy_H, energy_W], 3))att_H = concate[:, :, :, 0:height].permute(0, 2, 1, 3).contiguous().view(m_batchsize * width, height, height)# print(concate)# print(att_H)att_W = concate[:, :, :, height:height + width].contiguous().view(m_batchsize * height, width, width)out_H = torch.bmm(proj_value_H, att_H.permute(0, 2, 1)).view(m_batchsize, width, -1, height).permute(0, 2, 3, 1)out_W = torch.bmm(proj_value_W, att_W.permute(0, 2, 1)).view(m_batchsize, height, -1, width).permute(0, 2, 1, 3)# print(out_H.size(),out_W.size())return self.gamma * (out_H + out_W) + x
第三步:yolo.py中注册 CrissCrossAttention模块
elif m is CrissCrossAttention:c1, c2 = ch[f], args[0]if c2 != no:c2 = make_divisible(c2 * gw, 8)args = [c1, *args[1:]]
第四步:修改yaml文件,本文以修改head(特征融合网络)为例,将原C3模块后加入该模块。
backbone:# [from, number, module, args][[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2[-1, 1, Conv, [128, 3, 2]], # 1-P2/4[-1, 3, C3, [128]],[-1, 1, Conv, [256, 3, 2]], # 3-P3/8[-1, 6, C3, [256]],[-1, 1, Conv, [512, 3, 2]], # 5-P4/16[-1, 9, C3, [512]],[-1, 1, Conv, [1024, 3, 2]], # 7-P5/32[-1, 3, C3, [1024]],[-1, 1, SPPF, [1024, 5]], # 9[-1, 1, CrissCrossAttention, [1024]],]
第五步:将train.py中改为本文的yaml文件即可,开始训练。
结 果:本人在遥感数据集上进行实验,有涨点效果。需要请关注留言。
预告一下:下一篇内容将继续分享深度学习算法相关改进方法。有兴趣的朋友可以关注一下我,有问题可以留言或者私聊我哦
PS:该方法不仅仅是适用改进YOLOv5,也可以改进其他的YOLO网络以及目标检测网络,比如YOLOv7、v6、v4、v3,Faster rcnn ,ssd等。
最后,希望能互粉一下,做个朋友,一起学习交流。
YOLOv5、v7改进之三十一:CrissCrossAttention注意力机制相关推荐
- 改进YOLOv5系列:13.添加CrissCrossAttention注意力机制
- 目标检测算法——YOLOv5/YOLOv7改进之结合CBAM注意力机制
深度学习Tricks,第一时间送达 论文题目:<CBAM: Convolutional Block Attention Module> 论文地址: https://arxiv.org/p ...
- YOLOv5、v7改进之三十二:引入SKAttention注意力机制
前 言:作为当前先进的深度学习目标检测算法YOLOv7,已经集合了大量的trick,但是还是有提高和改进的空间,针对具体应用场景下的检测难点,可以不同的改进方法.此后的系列文章,将重点对YOLOv7 ...
- 改进YOLOv5系列:16.添加SKAttention注意力机制
最新创新点改进推荐 -
- 改进YOLOv5系列:21.添加CBAM注意力机制
- 《YOLOv5/v7改进实战专栏》专栏介绍 专栏目录
- 目标检测算法——YOLOv5/v7改进之结合最强视觉识别模块CotNet(Transformer)
- [YOLOv7/YOLOv5系列算法改进NO.33]引入GAMAttention注意力机制
前 言:作为当前先进的深度学习目标检测算法YOLOv7,已经集合了大量的trick,但是还是有提高和改进的空间,针对具体应用场景下的检测难点,可以不同的改进方法.此后的系列文章,将重点对YOLOv7 ...
- 小目标检测3_注意力机制_Self-Attention
主要参考: (强推)李宏毅2021/2022春机器学习课程 P38.39 李沐老师:64 注意力机制[动手学深度学习v2] 手把手带你Yolov5 (v6.1)添加注意力机制(一)(并附上30多种顶会 ...
最新文章
- 有没有什么好的C++视频教程?
- Know More About Oracle Row Lock
- mysql navicat授权_Mysql授权允许远程访问解决Navicat for MySQL连接mysql提示客户端不支持服务器请求的身份验证协议;考虑升级MySQL客户端...
- 归并排序 Java实现 简单易懂
- Ardunio开发实例-雨滴传感器
- Node之HTTPS客户端
- vdbench安装及使用
- 在注册表里删除没用的服务
- 小甲鱼c语言_Tip:一起做一个平平无奇的程序小天才吧
- 数字信号处理---模拟信号数字处理方法
- 【大数据】即席查询引擎Presto简单介绍
- 第5章 演绎推理与归纳推理
- 断网重启路由器就好_电脑断网重启路由器就好了怎么回事
- Educational Codeforces Round 60 (Rated for Div. 2) E. Decypher the String(构造)
- unity 纹理压缩格式‘_游戏制作行业为什么使用TGA格式的贴图而不使用PNG格式?...
- 计算机硬件找不到网络适配器,图文学习网络适配器不见了
【操作教程】
的恢复方法_...
- 渐变:线性渐变、径向渐变
- 无业务不技术:那些誓用区块链重塑的行业,发展怎么样了?
- python大气校正_全自动多源遥感影像大气校正方法
- 古代玻璃制品的化学成分分析与鉴别
热门文章
- MMDetection3d对KITT数据集的训练与评估介绍
- Codec2入门:框架解析
- Linux启用显卡opengl,如何使你的Nvidia显卡支持OpenGL?
- JavaScript学习笔记|数据类型——Object类型、for in循环
- SpringBoot项目中使用set方法后,自动保存问题
- 成都市2021年高考三诊成绩查询,2020年成都各校高三“三诊”成绩一览表
- java线程(16)——死锁讲解,白雪公主与灰姑娘抢口红和镜子的案例
- 一个即将走向社会的软件技术专业学生的感想
- 关于机房环境监控系统基础知识
- 初来北京的两三天-我被这个地方感动的想流泪