此篇论文已被 AAAI 2022 收录,论文链接请见“阅读原文”。

● 简介 

近年来,以 DETR[1]为代表的基于 transformer 的端到端目标检测算法开始广受大家的关注。这类方法通过一组目标查询来推理物体与图像上下文的关系从而得到最终预测结果,且不需要 NMS 后处理,成为了一种目标检测的新范式。

但是,这类方法尚有一些不足之处。

首先,DETR 解码器的目标查询是一组可学习的向量。这组向量人类难以解释,没有显式的物理意义。同时,目标查询对应的预测结果的分布也没有明显的规律,这也导致模型较难优化。

为了解决上述问题,本文提出了一种基于锚点的查询设计,因此目标查询有了显式的物理意义,且每个查询仅关注对应锚点附近的区域,使得模型更容易优化。

此外,本文还提出了一种 attention 结构的变种,可以显著降低显存消耗,且对于检测任务中较难的 cross attention 依旧能保持精度不降。

如表 1 所示,最终本文算法比 DETR 精度更高,消耗显存更少,速度更快,且收敛更快(所需训练轮次更少)。

表1

● Attention 回顾 ●

首先,我们回顾一下 DETR 中 attention 的形式: ,,

这里 Q、K 和 V 分别为查询、键和值,下标 f 和 P 分别表示特征和位置编码向量,标量  为特征的维度。实际上,Q、K 和 V 还会分别经过一个全连接层,这里为了简洁省略了这部分。

DETR 的解码器包含两种 attention,一种是 self-attention,另一种是 cross-attention。

在 self-attention 中,   和  与  一样,  与  一样。其中  由上一个解码器层的输出得到,第一个解码器层的  被初始化为一个常数向量,如零向量;而  设置为一组可学的向量,为解码器中所有的  共享: ,

在 cross-attention 中,  由之前的 self-attention 的输出得到;而  和  是编码器的输出特征;  是编码器输出特征对应的位置编码向量,DETR 采用了正余弦函数来作为位置编码函数,我们将该位置编码函数记作  ,若编码器特征对应的位置记作  ,那么:   在此解释一下,H, W, C 分别是特征的高、宽和通道数目,而  是预设的目标查询数目。

● 查询设计 ●

通常我们把解码器中的  认作是目标查询,这是因为它负责分辨不同的物体(解码器中的初始  为零向量没有分辨能力)。

如前文所述,DETR 中的目标查询  是一组可学向量,其难以解释且没有显式的物理意义。观察这些目标查询对应的预测结果的分布,如图 1 所示,每个方格中的点表示一个目标查询对应的所有图像预测结果的中心点,可以看到,每个查询都负责非常大的范围,且导致负责的区域有很大的重叠,这种模糊性也使得网络很难优化。

图 1

为了解决这个问题,本文提出基于锚点的查询设计,每个目标查询为锚点坐标的编码,因此具有了显式的物体意义。并且,每个查询仅关注锚点附近的区域,可使得网络模型更易优化。

在基于 CNN 的检测算法中,锚点通常都是特征网格点的坐标。而在本文中,锚点可以更加灵活。可以使用预设的网格位置的锚点,也可以是一组可以随网络学习的位置点。如图 2 所示,我们发现最终学习到的锚点分布与网格点较为相似,都是趋于均匀分布在整个图像上。这可能是因为在整个图像集中,图像的各个位置都会出现物体。

图 2

记锚点为  ,其表示有  个锚点,每个锚点记录点的(x,y)坐标。那么,基于锚点的目标查询则是: () 即目标查询为锚点坐标的编码。那么如何选择位置编码函数呢?最自然地,本文选择与键特征共享一样的位置编码函数: (),() 其中,g 为位置编码函数,它可以是前述的  ,也可以是其它的形式。在本文中我们对启发式的  额外加入了两个全连接层以更好地调整它。

更进一步考虑,有时一个位置可能会出现多个物体。显然,若一个锚点仅能预测一个物体的话,那么该位置的其它物体则需要其它位置的锚点来一同预测。这导致每个锚点负责的区域扩大,增加了其位置模糊性。为了解决这个问题,本文对每个锚点加入多种模式,使其可以有多个预测。

回顾 DETR,其中初始的查询特征为  ,对于  个目标查询来说,每个都只有一种模式  ,其中  表示目标查询的索引。

因此,本文为每个目标查询设置多种模式  ,其中  为模式的数目,是一个较小的值,如  =3。具体而言,本文使用一组可学向量 , 作为目标查询的多种模式。考虑移动不变性,我们希望这些模式与位置无关,因此让各个锚点共享多种模式。如此,我们便可得到增广的初始查询特征  和查询位置编码  。

观察改进后的目标查询对应的预测结果的分布,如图 3 所示,其中最后一行为锚点,前三行是对应锚点的三种模式的预测,可以看到,基于锚点的查询将关注锚点附近的区域,查询对应的预测框中心点都分布在锚点周围。此时查询不需要预测离对应锚点很远的物体,因此其具体特定的语义,从而模型将更容易优化。

图 3

图 4 展示了各个查询模式对应预测的分布,可以发现模式与物体大小存在一定关系,例如大物体几乎都出现在模式(a)中,模式(b)则关注小物体,模式(c)介于两者之间。另外我们还可以发现,所有的模式都会预测小物体,这是因为小物体更容易出现一个位置多个物体的情况。

图  4

● Attention 变种 ●

目前许多的 attention 变种,如 Deformable DETR[2]、Efficient Attention[3]等,都可以大幅度降低 transformer 占用的显存。然而,也许是由于 DETR 类方法中 transformer 解码器的 cross attention 较难,若使用同样的特征,这些方法将会导致一定程度上的精度降低。

本文提出了一种行列特征解耦的 attention 变种(Row-Column Decoupled Attention, RCDA),将键特征解耦为列特征和行特征,再依次进行列 attention 和行 attention。该方法不仅可以降低显存消耗,还可以得到和原先的标准 attention 相似或者更高的精度。

首先,对于键特征  ,先将其解耦为行特征  和列特征  ,本文采用的解耦方式为分别沿着列和行做均值。

接下来,则可以分别计算查询对于行、列键特征的注意力图: ()()

其中, ,,

(),(),,(),(),

最后则依据行列注意力图,对值特征依次沿着行、列进行加权和。不失一般的,我们假设 W≤H,可如下式先沿着列加权,再沿着行加权(若 W>H,则可先沿着行加权,使中间结果的显存占用小一些):

其中  行列解耦的 attention 变种的原理上文便介绍完了,现在我们再来讨论一下它为什么可以节省显存。

在之前的表述中,我们不失一般的假设 Attention 头的数目为 1 以更加简洁,现在我们设其为 M。在标准的 Attention 中,注意力图  为主要的显存占用瓶颈,而在行列解耦的attention中,行列注意力图  和  的显存远小于标准 attention 中的注意力图。

由于特征的通道数目 C 通常大于 M,RCDA 的中间结果 Z 的显存占用要大于行列注意力图,因此我们主要比较 RCDA 的中间结果  与标准Attention 中注意力图  之间的关系。显然,随着图像特征分辨率的增大(H 与 W 增大),标准  attention  的显存占用增长得更快。

行列解耦  attention  较标准  attention  可以节省显存的倍数为:

在默认的设置中,M=8,C=256,因此当特征长边 H 大于 32 时,RCDA 可以节省显存。在目标检测任务中,特征边长 32 是 C5 特征的一个典型值,因此使用 C5 特征显存占用相差不大,使用更大的 C4 特征显存可省 2 倍,依次类推。

● 总体流程 ●

算法的总体流程如图 5 所示,首先通过 CNN 网络提取图像特征,然后再经过transformer 编码器通过 self attention 处理图像特征,输出的图像特征将作为解码器的键和值特征。解码器的查询为前文所述基于锚点的多模式查询,在解码器中,各个查询分别根据注意力图聚合感兴趣的图像特征,最后输出最终的预测结果。预测框的中心点预测相对锚点的偏移量,而框的大小则预测其相对图像的大小。编码器和解码器中的 attention 可以采用标准的 attention,也可以采用本文所述的行列解耦 attention。对于 attention 中各特征的位置编码,则依据其位置使用共享的位置编码函数得到。

图 5

● 实验分析 ●

如表 2 所示,我们比较了本文算法与其它一些算法的性能比较,默认的骨干网络为 ResNet50。可以看到本文算法可以到达较好的性能,且继承了 DETR 无需手工设计锚框、无需 NMS 后处理,且不涉及随机内存访问的优秀性质。

表 2

不涉及随机内存访问(RAM-free)的性质可以减小硬件的访存代价,在实用中对硬件更加友好。

举个例子,假设有个人(专用计算芯片)力气很大(算力很强),他可以轻松地把一叠共 1000 张纸搬到指定的地方(计算处理某个张量)。而假如让他取出其中的第 123 张和第 234 张纸搬到指定的地方,需要搬的纸虽然少了很多(计算量大幅度降低),但是由于需要找到这些指定的纸(随机内存访问),可能会更加费时(访存代价增加)。

通常来说,两阶段的检测算法由于感兴趣的区域(RoI)的坐标对硬件来说随机的,提取感兴趣区域的特征会涉及到随机内存访问。而 Deformable DETR 也涉及到提取特定坐标的特征的情况,因此也非 RAM-free。

如表 3 所示,我们还分析了上述所提各个模块的效果。首先,我们可以看到所提的查询设计,即将锚点(anchors)编码为解码器查询以及为锚点加入多种模式(patterns),可以将性能从 39.3 提升至 44.2,这个显著的提升表明了本文的查询设计较原 DETR 查询设计的优越性。我们还可以看到,将标准的Attention 替换为 RCDA 性能近乎一致,这表明 RCDA 可以无损地降低显存占用。还有一点比较有趣的现象,若我们为原 DETR 的查询加入多种模式,其性能没有明显的变化,我们认为这是因为 DETR 的查询与位置没有明显关系,不能获得解决“一个位置多个物体”问题的收益。

表 3

如表 4 所示,我们比较了使用不同数目的锚点(anchor points)和模式(patterns)。100 个锚点数目过少性能较低,而 900 个锚点性能与 300 个锚点相差仅 0.3,因此我们默认使用 300 个锚点。可以看到,为每个锚点设置多种模式,性能会有明显的提升。另外,当预测结果的数目一致时,即保持锚点数目乘以模式数目的值不变时,多种模式的性能也比一种模式效果更好,这说明了多种模式的提升并非是因为预测的数目增加,而是本质更好。

表 4

如表 5 所示,我们比较了 attention 变种的效果。可以看到,Efficient Attention 虽然可以大幅度降低显存,但是由于 cross attention 较难,效果有明显下降。而本文的 RCDA 将显存占用从 10.5G 降低至 4.4G,而精度却没有明显变化。

表 5

● 总结 ●

本文提出了一个基于 transformer 的检测算法,其实现简单,且比 DETR 精度更高,消耗显存更少,速度更快,且收敛更快。

● 参考文献 ●

[1] Carion N, Massa F, Synnaeve G, et al. End-to-end object detection with transformers[C]//European conference on computer vision. Springer, Cham, 2020: 213-229.

[2] Zhu X, Su W, Lu L, et al. Deformable detr: Deformable transformers for end-to-end object detection[J]. arXiv preprint arXiv:2010.04159, 2020.

[3] Shen Z, Zhang M, Zhao H, et al. Efficient attention: Attention with linear complexities[C]//Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision. 2021: 3531-3539.

论文解读 | 锚点 DETR:基于 transformer 目标检测的查询设计相关推荐

  1. CVPR2020论文解读:3D Object Detection三维目标检测

    CVPR2020论文解读:3D Object Detection三维目标检测 PV-RCNN:Point-Voxel Feature Se tAbstraction for 3D Object Det ...

  2. 论文解读:《功能基因组学transformer模型的可解释性》

    论文解读:<Explainability in transformer models for functional genomics> 1.文章概括 2.背景 3.相关工作 4.方法 4. ...

  3. AAAI 2020 Oral论文--TANet:提升点云3D目标检测的稳健性

    点击上方"深度学习技术前沿",选择"星标"公众号 资源干货,第一时间送达 来自华中科技大学白翔教授组的刘哲的 AAAI Oral 论文<TANet: Ro ...

  4. 如何使用CNN进行物体识别和分类_基于CNN目标检测方法(RCNN系列,YOLO,SSD)

    转载自:基于CNN目标检测方法(RCNN,Fast-RCNN,Faster-RCNN,Mask-RCNN,YOLO,SSD)行人检测 一.研究意义 卷积神经网络(CNN)由于其强大的特征提取能力,近年 ...

  5. 基于轻量级目标检测模型实现手写汉字检测识别计数

    一般手写相关的数据集,应该是手写数字听得最多最多的,手写汉字也有,但是与手写数字或者是手写字母的知名度相比就低了很多很多,在我前面的一篇很早期入门的时候写过一篇文章,如下: <Yolov3目标检 ...

  6. php 教学目标,基于教学目标,共话教学设计

    基于教学目标,共话教学设计 --一课研究团队11月线下集中培训第1天掠影 11月9日,"朱乐平小学数学名师工作站 · 一课研究团队"在杭州胜利实验小学开展集中培训活动,数百名教师齐 ...

  7. 目标检测实战必会!4种基于YOLO目标检测(Python和C++两种版本实现)

    目标检测实战必会!4种基于YOLO目标检测(Python和C++两种版本实现) AI算法修炼营 1周前 以下文章来源于极市平台 ,作者CV开发者都爱看的 极市平台 专注计算机视觉前沿资讯和技术干货,官 ...

  8. 目标检测 - Neck的设计 PAN(Path Aggregation Network)

    目标检测 - Neck的设计 PAN(Path Aggregation Network) flyfish 目标检测器的构成 1. Input:Image,Patches,ImagePyramid 2. ...

  9. 基于单片机家具窗帘控制系统设计、基于单片机路灯教室灯光家具智能控制设计-基于单片机简易电饭煲电饭锅仿真系统设计、基于单片机酒精检测控制系统仿真设计-设计资料

    1426基于单片机酒精检测控制系统仿真设计-全套资料 (1)  学习气体测量传感器的原理和使用,并完成数据采集.调理电路的设计: (2)  学习单片机系统的设计及编程,完成系统整体设计: (3)  通 ...

最新文章

  1. 史上最全数据结构算法之递归系列学习,建议收藏!
  2. 计算机二级公共基础知识总结百度云,计算机二级公共基础知识总结详细版本[精]...
  3. 【php】windows安装PHP5.5+Apache2.4
  4. 腾讯员工吐槽:团队来了个阿里高p,瞬间会议变多,群多了
  5. 获取本地System权限
  6. 运行测试Caused by: java.lang.UnsatisfiedLinkError: no attach in java.library.path错误解决
  7. Oracle中row_number()、rank()、dense_rank() 的区别
  8. express 创建ejs项目,使用html
  9. ACM-Maximum Tape Utilization Ratio
  10. 芒果超媒:子公司与咪咕文化签署合作框架协议
  11. html中rem和em,CSS 中的 rem 和 em 的区别(1)
  12. c语言实现将文本转换为语音,C#文字转换语音朗读或保存MP3、WAV等格式
  13. 更改win11鼠标指针样式
  14. JS时间增加2个小时
  15. Monkey log 分析
  16. 如何看待社会的阴暗面
  17. Python中Print()函数的用法___实例详解(全,例多)
  18. Apple Configurator 2使用教程: 修复或恢复搭载 Apple M1芯片的 Mac!
  19. CATIA CAA二次开发专题(十)---迷宫中穿行(终结篇)
  20. IntelliJ IDEA/Android Studio 翻译插件,可中英互译。

热门文章

  1. 不止GPU!这些硬件也影响着深度学习训练速度
  2. opencv图像傅里叶变换
  3. 空间音频已来,TWS 4.0时代已来!TWS200带来全新听觉体验
  4. 信安Note_day27
  5. 斐讯发力渲染云 打造“互联网+”文创产业链
  6. AdServices归因和iAd归因集成
  7. Matlab的数组索引
  8. 北上广深杭房价高压下,这也许是程序员扎根的唯一出路...
  9. 互联网产品设计进阶(10)关注项目的赢利模式
  10. Eclipse中进行web project开发时遇到httperror 500 错误 jsp support not configured 问题的解决方法