paper:Shunted Self-Attention via Multi-Scale Token Aggregation

github:https://github.com/OliverRensu/Shunted-Transformer

aistudio: No GPU? Portal to try the Shunted Transformer online with weights ported to PaddlePaddle

A design characteristic of ViT models is that every token within the same layer has the same receptive field. This limits the self-attention layer's ability to capture multi-scale features and degrades performance on images that contain objects of different scales. To address this, the authors propose shunted self-attention, which lets each attention layer capture multi-scale information.


Contents

1. Shunted Self-Attention

2. Detail-specific Feedforward Layers

3. Network Architecture

4. Experimental Results

5. Summary


1. Shunted Self-Attention

The core of this paper is the proposed Shunted Self-Attention. A comparison of several ViT attention designs:

ViT: Q, K, and V keep the same number of tokens, which gives a global receptive field but is computationally expensive.
Swin: partitions the feature map into windows and computes self-attention within each window to cut the cost, while the shift operation gradually enlarges the receptive field.
PVT: reduces the number of K/V tokens (spatial reduction) to lower the computation.
Shunted Self-Attention: produces multi-scale K/V inside a single attention layer and then computes self-attention on each scale (see the token-count sketch below).
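To make "multi-scale K/V" concrete, the illustrative snippet below (values assumed for a 224x224 input at stage-1 resolution) simply counts how many key/value tokens survive each downsampling rate r; in the actual model the downsampling is done with strided convolutions.

# Token counts for the key/value sets at different downsampling rates r,
# for an illustrative 56x56 token map (stage-1 resolution of a 224x224 input).
H = W = 56
for r in (1, 2, 4, 8):
    kv_tokens = (H // r) * (W // r)
    print(f"r={r}: {kv_tokens} K/V tokens")
# r=1 keeps all 3136 tokens (fine details), r=8 keeps only 49 (large objects);
# shunted self-attention mixes several r values inside the same layer.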

The computation is as follows:

In the equations above, i indexes the different K/V scales, MTA (multi-scale token aggregation) is the feature-aggregation module with downsampling rate r_i (implemented with a strided convolution), and LE is a depth-wise convolution layer that strengthens the connections between neighboring pixels in V.
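The equation image is not reproduced here; reconstructed from the paper, the computation for the i-th group of heads is:

\[
\begin{aligned}
Q_i &= X W_i^{Q}, \\
K_i,\ V_i &= \mathrm{MTA}(X, r_i)\, W_i^{K},\ \mathrm{MTA}(X, r_i)\, W_i^{V}, \\
V_i &= V_i + \mathrm{LE}(V_i), \\
h_i &= \mathrm{Softmax}\!\left(\frac{Q_i K_i^{\top}}{\sqrt{d_h}}\right) V_i,
\end{aligned}
\]

where d_h is the head dimension and the head outputs h_i are concatenated and projected as in standard multi-head attention.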

Implementation:

import torch
import torch.nn as nn


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None,
                 attn_drop=0., proj_drop=0., sr_ratio=1):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            # Two downsampling branches (MTA): sr1 uses rate sr_ratio, sr2 uses sr_ratio/2
            self.act = nn.GELU()
            if sr_ratio == 8:
                self.sr1 = nn.Conv2d(dim, dim, kernel_size=8, stride=8)
                self.norm1 = nn.LayerNorm(dim)
                self.sr2 = nn.Conv2d(dim, dim, kernel_size=4, stride=4)
                self.norm2 = nn.LayerNorm(dim)
            if sr_ratio == 4:
                self.sr1 = nn.Conv2d(dim, dim, kernel_size=4, stride=4)
                self.norm1 = nn.LayerNorm(dim)
                self.sr2 = nn.Conv2d(dim, dim, kernel_size=2, stride=2)
                self.norm2 = nn.LayerNorm(dim)
            if sr_ratio == 2:
                self.sr1 = nn.Conv2d(dim, dim, kernel_size=2, stride=2)
                self.norm1 = nn.LayerNorm(dim)
                self.sr2 = nn.Conv2d(dim, dim, kernel_size=1, stride=1)
                self.norm2 = nn.LayerNorm(dim)
            self.kv1 = nn.Linear(dim, dim, bias=qkv_bias)
            self.kv2 = nn.Linear(dim, dim, bias=qkv_bias)
            # LE: depth-wise convs that enhance local detail in V for each branch
            self.local_conv1 = nn.Conv2d(dim//2, dim//2, kernel_size=3, padding=1, stride=1, groups=dim//2)
            self.local_conv2 = nn.Conv2d(dim//2, dim//2, kernel_size=3, padding=1, stride=1, groups=dim//2)
        else:
            # Last stage: plain attention with a single full-resolution K/V
            self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
            self.local_conv = nn.Conv2d(dim, dim, kernel_size=3, padding=1, stride=1, groups=dim)
        self.apply(self._init_weights)  # _init_weights is defined in the full repo code

    def forward(self, x, H, W):
        B, N, C = x.shape
        q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        if self.sr_ratio > 1:
            x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
            # Multi-scale token aggregation: two K/V sets at different resolutions
            x_1 = self.act(self.norm1(self.sr1(x_).reshape(B, C, -1).permute(0, 2, 1)))
            x_2 = self.act(self.norm2(self.sr2(x_).reshape(B, C, -1).permute(0, 2, 1)))
            kv1 = self.kv1(x_1).reshape(B, -1, 2, self.num_heads//2, C // self.num_heads).permute(2, 0, 3, 1, 4)
            kv2 = self.kv2(x_2).reshape(B, -1, 2, self.num_heads//2, C // self.num_heads).permute(2, 0, 3, 1, 4)
            k1, v1 = kv1[0], kv1[1]  # B head N C
            k2, v2 = kv2[0], kv2[1]
            # First half of the heads attends to the coarser K/V
            attn1 = (q[:, :self.num_heads//2] @ k1.transpose(-2, -1)) * self.scale
            attn1 = attn1.softmax(dim=-1)
            attn1 = self.attn_drop(attn1)
            v1 = v1 + self.local_conv1(v1.transpose(1, 2).reshape(B, -1, C//2).transpose(1, 2).view(B, C//2, H//self.sr_ratio, W//self.sr_ratio)).view(B, C//2, -1).view(B, self.num_heads//2, C // self.num_heads, -1).transpose(-1, -2)
            x1 = (attn1 @ v1).transpose(1, 2).reshape(B, N, C//2)
            # Second half of the heads attends to the finer K/V
            attn2 = (q[:, self.num_heads // 2:] @ k2.transpose(-2, -1)) * self.scale
            attn2 = attn2.softmax(dim=-1)
            attn2 = self.attn_drop(attn2)
            v2 = v2 + self.local_conv2(v2.transpose(1, 2).reshape(B, -1, C//2).transpose(1, 2).view(B, C//2, H*2//self.sr_ratio, W*2//self.sr_ratio)).view(B, C//2, -1).view(B, self.num_heads//2, C // self.num_heads, -1).transpose(-1, -2)
            x2 = (attn2 @ v2).transpose(1, 2).reshape(B, N, C//2)
            x = torch.cat([x1, x2], dim=-1)
        else:
            kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
            k, v = kv[0], kv[1]
            attn = (q @ k.transpose(-2, -1)) * self.scale
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = (attn @ v).transpose(1, 2).reshape(B, N, C) + self.local_conv(v.transpose(1, 2).reshape(B, N, C).transpose(1, 2).view(B, C, H, W)).view(B, C, N).transpose(1, 2)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
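A quick smoke test for the block above. The repo's `_init_weights` method is not included in this excerpt, so a simple truncated-normal stand-in is attached first purely to make the snippet runnable on its own; the shapes and hyper-parameters are illustrative, not taken from a specific Shunted configuration.

# Hypothetical smoke test; _init_weights below is an assumed stand-in for the
# repo's initializer, attached only so the excerpt can be instantiated here.
import torch
import torch.nn as nn

def _init_weights(self, m):
    if isinstance(m, nn.Linear):
        nn.init.trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

Attention._init_weights = _init_weights

H = W = 28                       # stage-2 resolution for a 224x224 input
x = torch.randn(2, H * W, 128)   # (batch, tokens, channels)
attn = Attention(dim=128, num_heads=4, sr_ratio=4)
out = attn(x, H, W)
print(out.shape)                 # torch.Size([2, 784, 128])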

2. Detail-specific Feedforward Layers

A detail-specific branch (a depth-wise convolution) is added to the MLP to strengthen the connections between neighboring pixels; unlike PVT's MLP, it is wired in with a residual connection.

Note: in the source code, the order of the GELU and the residual connection is the reverse of what the figure shows; refer to the code below.
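Written out from the code (dropout layers omitted for clarity), the order actually used is:

\[
z = \mathrm{FC}_1(x), \qquad \mathrm{FFN}(x) = \mathrm{FC}_2\big(\mathrm{GELU}\big(z + \mathrm{DWConv}(z)\big)\big).
\]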

The code is as follows:

class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.dwconv = DWConv(hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x, H, W):
        x = self.fc1(x)
        x = self.act(x + self.dwconv(x, H, W))  # residual connection; the order differs from the figure, which appears to be drawn incorrectly
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class DWConv(nn.Module):
    def __init__(self, dim=768):
        super(DWConv, self).__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, x, H, W):
        B, N, C = x.shape
        x = x.transpose(1, 2).view(B, C, H, W)
        x = self.dwconv(x)
        x = x.flatten(2).transpose(1, 2)
        return x
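A minimal usage check of the detail-specific FFN; the shapes and the 4x expansion ratio below are just illustrative.

# Token sequence in, same shape out.
import torch

H = W = 14
x = torch.randn(2, H * W, 256)                    # (batch, tokens, channels)
mlp = Mlp(in_features=256, hidden_features=1024)  # 4x expansion, a common choice
out = mlp(x, H, W)
print(out.shape)                                  # torch.Size([2, 196, 256])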

3. Network Architecture

The overall architecture (shown in the paper's figure) follows the common hierarchical design; the difference lies in the Transformer blocks, which adopt the modifications described above. In addition, the network uses neither a cls_token nor a pos_embedding.
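For intuition only, here is a minimal sketch of how the pieces could be assembled into a 4-stage pyramid. The per-stage sr_ratios of [8, 4, 2, 1] match the attention code above, but the patch embedding, depths, and widths are simplified assumptions rather than the official Shunted-T/S/B configurations (see the repo for those); it also assumes the Attention class can be initialized, e.g. via the `_init_weights` stand-in from the earlier snippet.

# Illustrative 4-stage backbone skeleton built from the Attention and Mlp blocks above.
import torch
import torch.nn as nn

class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4, sr_ratio=1):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(dim, num_heads=num_heads, sr_ratio=sr_ratio)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = Mlp(dim, hidden_features=dim * mlp_ratio)

    def forward(self, x, H, W):
        x = x + self.attn(self.norm1(x), H, W)
        x = x + self.mlp(self.norm2(x), H, W)
        return x

class TinyShuntedLikeNet(nn.Module):
    # Hypothetical dims/heads/depths chosen for illustration only.
    def __init__(self, dims=(64, 128, 256, 512), heads=(2, 4, 8, 16),
                 depths=(1, 1, 1, 1), sr_ratios=(8, 4, 2, 1), num_classes=1000):
        super().__init__()
        self.embeds = nn.ModuleList()
        self.stages = nn.ModuleList()
        in_ch = 3
        for i, dim in enumerate(dims):
            stride = 4 if i == 0 else 2  # 4x downsampling first, then 2x per stage
            self.embeds.append(nn.Conv2d(in_ch, dim, kernel_size=stride, stride=stride))
            self.stages.append(nn.ModuleList(
                [Block(dim, heads[i], sr_ratio=sr_ratios[i]) for _ in range(depths[i])]))
            in_ch = dim
        self.norm = nn.LayerNorm(dims[-1])
        self.head = nn.Linear(dims[-1], num_classes)

    def forward(self, x):
        for embed, blocks in zip(self.embeds, self.stages):
            x = embed(x)                              # B, C, H, W
            B, C, H, W = x.shape
            x = x.flatten(2).transpose(1, 2)          # B, N, C tokens (no cls_token, no pos_embedding)
            for blk in blocks:
                x = blk(x, H, W)
            x = x.transpose(1, 2).reshape(B, C, H, W)  # back to a feature map for the next stage
        x = self.norm(x.flatten(2).transpose(1, 2)).mean(dim=1)  # global average pooling
        return self.head(x)

model = TinyShuntedLikeNet()  # requires the _init_weights stand-in (or the repo's own) to be defined
print(model(torch.randn(1, 3, 224, 224)).shape)       # torch.Size([1, 1000])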

4. Experimental Results

Performance on ImageNet-1k is reported in the paper's comparison table.

5. Summary

This work is very similar to PVT: the main changes are to the self-attention module and the MLP module, and they bring very good results. A nice piece of work.
