目录

一、概要

二、具体解析

1. 相对位置索引计算第一步

2. 相对位置索引计算第二步

3. 相对位置索引计算第三步


一、概要

Swin Transformer采用了相对位置编码的概念。

那么相对位置编码的作用是什么呢?  

 解释:在解释相对位置编码之前,我们需要先了解一下在NLP中Position Encoder即PE,NLP中Position_Encoder理解

在Swin Transformer中,将特征图 如按7*7 的窗口大小划分为多个小窗格,单独在每个小窗格内进行Attention计算。这样一来,窗口内就相当于有        49个Token即49个像素值,这些像素是有一定的位置关系的,故在Attention计算时,需要考虑这些像素的位置关系,故提出了相对位置编码,其与NLP        中的PE是有异曲同工之妙的。

而不同的是NLP中是在QK.T之前加入了Position信息,而Swin Transformer是在QK.T之后加入的相对位置信息,但是在效果上都是一样的。

维度解析:

如果特征图的大小为2*2*N(N表示每个像素点的channels),那么经过拉直之后Q、K、V的维度都为4*N,那么QK.T 的维度就是4*4,其中第一个4表示4个像素点第二个4表示对于每个像素点,相对包括自己在内的四个像素点的重要程度;而相对位置编码要得到的结果也需要是4*4,其每行表示四个像素相对于某个固定像素的位置编码值

那么我们求出的相对位置编码就是对应的编码值吗?

答案是否定的,求出的相对位置编码只是对应的位置索引,其索引值取值范围为 0 ~ K,而这个索引其实对应的是一个长度为K的可学习向量

这个可学习向量会在训练过程中逐步更新,而相对位置索引,就是提供索引值,从这个可学习向量中得到最终的位置编码值。如下图所示:

而接下来我们要做的就是,用尽可能形象的方式,解释明白这个相对位置索引矩阵是怎么获取的,计算公式为:

其中的B就是根据相对位置索引矩阵(上图右侧)中的每个像素位置的索引,从可学习向量中获取的值,并组成的编码矩阵(上图左侧)

二、具体解析

假设输入的feature map高宽都为2,那么首先我们可以构建出每个像素的绝对位置(左下方的矩阵),对于每个像素的绝对位置是使用行号y和列号x表示的。

1. 相对位置索引计算第一步

比如蓝色的像素对应的是第0行第0列所以绝对位置索引是( 0 , 0 ),蓝色像素使用q与所有像素k进行匹配过程中,是以蓝色像素为参考点而相对位置偏置Bias就是相对每个像素情况下,不同QK的偏移值

那么其他像素相对于该蓝色像素的相对位置是多少呢?

用蓝色像素的绝对位置索引与其他位置索引进行相减,就得到其他位置相对蓝色像素的相对位置索引,如下图所示

黄色(0,1)位置:(0,0) - (0,1) = (0,-1)

红色(1,0)位置:(0,0) - (1,0) = (-1,0)

绿色(1,1)位置:(0,0) - (1,1) = (-1,-1)

蓝色(0,0)位置相对于自己那就是(0,0)-(0,0)= (0,0)

如下图所示,每个位置都是相对于蓝色(0,0)位置的相对值,其实就是差值。

将其拉直后就为:

同理,当其他位置作为相对位置时,计算方式是一样的,都是让当前元素与其他四个位置的坐标位置相减。结果分别为:

将它们拉直后,分别为:

那么将上面拉直后的结果,放在一起的话,如下图所示:

那么用代码是怎么计算的呢,是每个位置单独计算后,再拼接在一起的吗?

答案是否定的,往往在矩阵相关的计算中,都会以矩阵的方式进行统一计算。代码如下:

先整体来看,后面回分步解析。

# 获取特征图所有像素点的位置坐标
coords_h = torch.arange(2)
coords_w = torch.arange(2)
coords = torch.meshgrid([coords_h, coords_w])
# 横纵坐标合并后并拉直
coords = torch.stack(coords)
coords_flatten = torch.flatten(coords, 1)
# 计算坐标的相对位置差值
relative_coords_first = coords_flatten[:, :, None]
relative_coords_second = coords_flatten[:, None, :]
relative_coords = relative_coords_first - relative_coords_second
relative_coords = relative_coords.permute(1, 2, 0).contiguous() 

分步解析:

(1)获取所有像素点的横坐标与纵坐标

a. 获取纵坐标的取值范围

coords_h = torch.arange(2)
coords_w = torch.arange(2)
'''
coords_h:[0,1]
coords_w:[0,1]
'''

b.获取所有位置的纵坐标与横坐标

coords = torch.meshgrid([coords_h, coords_w])
'''
coords[0]:
[[0,0][1,1]]
shape与特征图大小相同2x2,每个位置的值表示该像素点
的纵坐标,第一行纵坐标均为0,第二行纵坐标均为1coords[0]:
[[0,1[0,1]]
shape与特征图大小相同2x2,每个位置的值表示该像素点
的横坐标,第一列横坐标均为0,第二列横坐标均为1
'''

c. 上面的coords是一个列表,里面包括两个矩阵,即横坐标矩阵与纵坐标矩阵

将横坐标矩阵与纵坐标矩阵拼接起来,torch.stack,增加一个dim=0维度并拼接

coords = torch.stack(coords)
'''
coords:
shape:(2,2,2),第一个2表示横纵两种坐标,后面的2表示两行两列
'''
coords_flatten = torch.flatten(coords, 1)
'''横坐标与纵坐标分别拉直torch.flatten(coords,1)表示从第1个维度起拉直,shape(2,2,2) -> (2,4)
'''

d. 一共四个像素点,让每个像素点都其他包括自己在内的四个像素点横纵坐标求差值

所以就需要以行为单位(纵坐标与横坐标)每行各赋值4次,相当于每个像素的横纵坐标都复制4次,用于与四个像素点进行计算。

relative_coords_first = coords_flatten[:, :, None]
'''增加一个维度,用于在以列为单位复制4次shape(2,4,1)
'''

横坐标与纵坐标都分别复制了4份。

复制后的,每行表示每个像素的横或纵坐标复制了4次。

relative_coords_first = coords_flatten[:, None, :]
'''增加一个维度,用于在以行为单位复制4次shape:(2,1,4)
'''

所有坐标都分别复制4次。

复制后,每行表示所有像素的横或纵坐标。

relative_coords_first = coords_flatten[:, :, None]
relative_coords_second = coords_flatten[:, None, :]
relative_coords = relative_coords_first - relative_coords_second
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
'''上面的相减采用了广播机制,其广播的流程与上述的复制过程是一致的
'''

  上面这种做法是为了什么?

是为了两个矩阵相减,得到的结果就相当于,四个像素点依次作为主像素点时,其他四个像素相对于该主像素点的相对位置。如下图所示:

2. 相对位置索引计算第二步

注意!!! 这里描述的一直是相对位置索引,并不是相对位置偏执参数。后面我们会根据相对位置索引去取对应的参数。

上面已经计算出来相对某一个像素,其他像素点与其的坐标差值,如下:

但是上面的结果是二维的,而最终获取的位置参数表对于每个Head来说是一维的,故需要将上面的这个结果转换为一维的形式。由于索引值的范围为[-M+1,M-1],原始的相对位置索引加上M-1,使得索引值大于等于0,变为[0,2M-2]。

为什么要将索引值变为大于等于0呢?

这个问题其实很简单,因为我们在最后从参数表中获取最终值的方式,是通过索引,而索引值是不小于0的。

代码如下:

relative_coords[:, :, 0] += 2 - 1
relative_coords[:, :, 1] += 2 - 1
relative_coords[:, :, 1] += 2 - 1

3. 相对位置索引计算第三步

对与每行,即不同像素间,希望得到的索引位置是不同的,但是如果直接横纵坐标相加的话,往往会出现像素不同,索引相同的情况,如下所示:

所以最后将所有横坐标都乘上2M-1,最后再将横坐标和纵坐标求和,这样每行不同像素间得到的索引就具有独一性。

relative_coords[:, :, 0] *= 2 * 2 - 1

最后将行标和列标进行相加,得到独一的一维的索引,这样即保证了相对位置关系,而且不会出现上述0 +1 = 1  + 0 的问题了,是不是很神奇。

relative_position_index = relative_coords.sum(-1)

至此就计算出了相对位置的索引,其并不是公式中的位置偏置参数。

真正使用到的可训练参数使保存在相对位置偏置表 relative position bias table中的,这个表的size为9,因为上面矩阵中索引值为0到8 是9个数。

即N = (2M-1)* (2M-1) = (4-1) * (4-1) =9.其是可训练的,随着训练过程,其内部的数值是不断优化更新的。

self.relative_position_bias_table = nn.Parameter(torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) 

如relative position bias table如下所示:

所以可以,操作相对位置索引的数值,依次从table中获取对应的参数

至此,最终的相对位置编码才计算完毕。总体流程如下图所示:

Swin Transformer之相对位置编码详解相关推荐

  1. positional encoding位置编码详解:绝对位置与相对位置编码对比

    目录 前言 Why What 绝对位置编码 相对位置编码 Sinusoidal Position Encoding Complex embedding How 前言 相信熟悉BERT的小伙伴对posi ...

  2. Python2.7字符编码详解

    Python2.7字符编码详解 目录 Python2.7字符编码详解 声明 一. 字符编码基础 1.1 抽象字符清单(ACR) 1.2 已编码字符集(CCS) 1.3 字符编码格式(CEF) 1.3. ...

  3. 转1:Python字符编码详解

    Python27字符编码详解 声明 一 字符编码基础 1 抽象字符清单ACR 2 已编码字符集CCS 3 字符编码格式CEF 31 ASCII初创 311 ASCII 312 EASCII 32 MB ...

  4. 可能是最详细的字符编码详解

    Created By JishuBao on 2019-04-02 12:38:22 Recently revised in 2019-04-03 12:38:22   欢迎大家来到技术宝的掘金世界, ...

  5. 字符编码详解及利用C++ STL string遍历中文字符串

    作者:非妃是公主 专栏:<笔记><C++> 博客地址:https://blog.csdn.net/myf_666 个性签:顺境不惰,逆境不馁,以心制境,万事可成.--曾国藩 文 ...

  6. unicode编码详解_转载

    unicode编码详解,一看就懂  转载--https://www.cnblogs.com/hahlzj/p/11908713.html 一.Unicode编码 1 UTF-8 -16 -32编码和U ...

  7. 嵌入式汉字显示原理及GBK编码详解

    嵌入式汉字显示原理及GBK编码详解 ~~~~~~~~        关于各个编码的介绍和转换可以看我的另一篇博客:[C语言实现]十六进制面值转字符串.字符面值转十六进制.UNICODE与GBK互转,U ...

  8. Python字符编码详解

    Python字符编码详解 转自http://www.cnblogs.com/huxi/archive/2010/12/05/1897271.html Python字符编码详解 本文简单介绍了各种常用的 ...

  9. 深入理解transformer中的位置编码

    文章目录 总览 问题1 问题2 问题3 问题4 问题5 问题6 总览 我们今天需要讲解transformer中的位置编码,这其实属于进阶内容.既然你会到这里,我默认你已经看过了transformer的 ...

最新文章

  1. 前后端交互概述与URL地址格式
  2. oracle 物化视图 ORA-23413: 表 xxx.xx 不带实体化视图日志
  3. 百度地图上进行空间插值---反距离加权法
  4. 计算机会计综合作业,20年7月东财《通用财务软件X》综合作业(100分)
  5. 报错,但不影响运行ERROR: JDWP Unable to get JNI 1.2 environment, jvm-GetEnv() return code = -2...
  6. opencv中的Mat类型
  7. chrome应用程序无法启动因为并行配置不正确_Win8打不开软件提示并行配置不正确的解决方法...
  8. Solr(二)创建索引和查询索引的基本应用
  9. 前端可视化开发-编辑器
  10. ai图像处理软件集大成者:Leawo PhotoIns Pro中文版介绍
  11. 面试题:如果办公室一台电脑无法上网,你的排查方法?
  12. 原备案在腾讯云 如何操作新增网站备案
  13. 达梦数据库表被锁住后解锁方法
  14. 对数似然比LLR公式的问题
  15. 隆重纪念鲁宾逊诞辰,不走样,不离谱
  16. rtx2060什么水平_我的GAMING之路 篇八:光追到底是什么鬼?—微星VENTUS RTX2060评测...
  17. 软件项目管理作业实施方案 草案
  18. WEB安全之:Access 数据库 SQL 注入
  19. 树莓派 官方800万摄像头 参数
  20. MongoDB建模场景

热门文章

  1. python爬取加密qq空间_使用python+selenium爬取qq空间好友动态
  2. Ionic2中的相册选择和拍照上传——ImgService
  3. matlab做基尼曲线,计算基尼系数和matplotlib绘制洛伦兹曲线
  4. u盘乱码怎么做?这些正确做法你知道吗?
  5. 掌握软文营销三要素成功写出说服力文案
  6. 网络映射环境搭建的3种方法
  7. node环境变量配置
  8. 你没订单,也许是因为你不够人情味! [外贸 跟进订单 社交营销]
  9. 小米小爱音箱Pro8安装app_小米发布小爱音箱Art电池版:超大容量只要399
  10. Qt Solidworks零件中文特征名转换成英文