文章目录

  • 前言
  • 完整代码
  • GitHub链接
  • 附录:可视化resnet50最后一层特征图

前言

 之前一直比较好奇Conditional Detr中如何可视化各个头部的空间注意力热图的,于是,本人尝试在Detr基础上实现了一个demo,可以无脑化运行,先上最终的效果图:
 代码中我已经加了详细注释,文末有GitHub链接。

完整代码

# #------------------------------------------------------------#
# 可视化Detr方法:
# spatial attention weight : (cq + oq)*pk
# combined attention weight: (cq + oq)*(memory + pk)
# 其中:
#     pk:原始特征图的位置编码;
#     oq:训练好的object queries
#     cq:decoder最后一层self-attn中的输出query
#     memory:encoder的输出
# #------------------------------------------------------------#
# 在此基础上只要稍微修改便可可视化ConditionalDetr的Fig1特征图
# #------------------------------------------------------------#
# 代码参考自:https://github.com/facebookresearch/detr/tree/colab
# #------------------------------------------------------------#import math
import numpy as npfrom PIL import Image
import requests
import matplotlib.pyplot as pltimport ipywidgets as widgets
from IPython.display import display, clear_outputimport torch
from torch import nn
from torchvision.models import resnet50
import torchvision.transforms as T
from torch.nn.functional import dropout,linear,softmax
torch.set_grad_enabled(False)def box_cxcywh_to_xyxy(x):x_c, y_c, w, h = x.unbind(1)b = [(x_c - 0.5 * w), (y_c - 0.5 * h),(x_c + 0.5 * w), (y_c + 0.5 * h)]return torch.stack(b, dim=1)def rescale_bboxes(out_bbox, size):img_w, img_h = sizeb = box_cxcywh_to_xyxy(out_bbox)b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)return b# COCO classes
CLASSES = ['N/A', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus','train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A','stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse','sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack','umbrella', 'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis','snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove','skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass','cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich','orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake','chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A','N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard','cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A','book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier','toothbrush'
]
# colors for visualization
COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],[0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]# standard PyTorch mean-std input image normalization
transform = T.Compose([T.Resize(800),T.ToTensor(),T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])# 加载线上的模型
model = torch.hub.load('facebookresearch/detr', 'detr_resnet50', pretrained=True)
model.eval()
# 获取训练好的参数
for name, parameters in model.named_parameters():# 获取训练好的object queries,即pq:[100,256]if name == 'query_embed.weight':pq = parameters# 获取解码器的最后一层的交叉注意力模块中q和k的线性权重和偏置:[256*3,256],[768]if name == 'transformer.decoder.layers.5.multihead_attn.in_proj_weight':in_proj_weight = parametersif name == 'transformer.decoder.layers.5.multihead_attn.in_proj_bias':in_proj_bias = parameters
# 线上下载图像
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
im = Image.open(requests.get(url, stream=True).raw)
# img_path = '/home/wujian/000000039769.jpg'
# im = Image.open(img_path)# mean-std normalize the input image (batch-size: 1)
img = transform(im).unsqueeze(0)# propagate through the model
outputs = model(img)# keep only predictions with 0.7+ confidence
probas = outputs['pred_logits'].softmax(-1)[0, :, :-1]
keep = probas.max(-1).values > 0.9# convert boxes from [0; 1] to image scales
bboxes_scaled = rescale_bboxes(outputs['pred_boxes'][0, keep], im.size)# use lists to store the outputs via up-values
conv_features, enc_attn_weights, dec_attn_weights = [], [], []
cq = []     # 存储detr中的 cq
pk =  []    # 存储detr中的 encoder pos
memory = [] # 存储encoder的输出特征图memory# 注册hook
hooks = [# 获取resnet最后一层特征图model.backbone[-2].register_forward_hook(lambda self, input, output: conv_features.append(output)),# 获取encoder的图像特征图memorymodel.transformer.encoder.register_forward_hook(lambda self, input, output: memory.append(output)),# 获取encoder的最后一层layer的self-attn weightsmodel.transformer.encoder.layers[-1].self_attn.register_forward_hook(lambda self, input, output: enc_attn_weights.append(output[1])),# 获取decoder的最后一层layer中交叉注意力的 weightsmodel.transformer.decoder.layers[-1].multihead_attn.register_forward_hook(lambda self, input, output: dec_attn_weights.append(output[1])),# 获取decoder最后一层self-attn的输出cqmodel.transformer.decoder.layers[-1].norm1.register_forward_hook(lambda self, input, output: cq.append(output)),# 获取图像特征图的位置编码pkmodel.backbone[-1].register_forward_hook(lambda self, input, output: pk.append(output)),
]# propagate through the model
outputs = model(img)# 用完的hook后删除
for hook in hooks:hook.remove()# don't need the list anymore
conv_features = conv_features[0]       # [1,2048,25,34]
enc_attn_weights = enc_attn_weights[0] # [1,850,850]   : [N,L,S]
dec_attn_weights = dec_attn_weights[0] # [1,100,850]   : [N,L,S] --> [batch, tgt_len, src_len]
memory = memory[0] # [850,1,256]cq = cq[0]    # decoder的self_attn:最后一层输出[100,1,256]
pk = pk[0]    # [1,256,25,34]# 绘制postion embedding
pk = pk.flatten(-2).permute(2,0,1)           # [1,256,850] --> [850,1,256]
pq = pq.unsqueeze(1).repeat(1,1,1)           # [100,1,256]
q = pq + cq
#------------------------------------------------------#
#   1) k = pk,则可视化: (cq + oq)*pk
#   2_ k = pk + memory,则可视化 (cq + oq)*(memory + pk)
#   读者可自行尝试
#------------------------------------------------------#
k = pk
# k = pk + memory
#------------------------------------------------------## 将q和k完成线性层的映射,代码参考自nn.MultiHeadAttn()
_b = in_proj_bias
_start = 0
_end = 256
_w = in_proj_weight[_start:_end, :]
if _b is not None:_b = _b[_start:_end]
q = linear(q, _w, _b)_b = in_proj_bias
_start = 256
_end = 256 * 2
_w = in_proj_weight[_start:_end, :]
if _b is not None:_b = _b[_start:_end]
k = linear(k, _w, _b)scaling = float(256) ** -0.5
q = q * scaling
q = q.contiguous().view(100, 8, 32).transpose(0, 1)
k = k.contiguous().view(-1, 8, 32).transpose(0, 1)
attn_output_weights = torch.bmm(q, k.transpose(1, 2))attn_output_weights = attn_output_weights.view(1, 8, 100, 850)
attn_output_weights = attn_output_weights.view(1 * 8, 100, 850)
attn_output_weights = softmax(attn_output_weights, dim=-1)
attn_output_weights = attn_output_weights.view(1, 8, 100, 850)# 后续可视化各个头
attn_every_heads = attn_output_weights # [1,8,100,850]
attn_output_weights = attn_output_weights.sum(dim=1) / 8 # [1,100,850]#-----------#
#   可视化
#-----------#
# get the feature map shape
h, w = conv_features['0'].tensors.shape[-2:]fig, axs = plt.subplots(ncols=len(bboxes_scaled), nrows=10, figsize=(22, 28))  # [11,2]
colors = COLORS * 100# 可视化
for idx, ax_i, (xmin, ymin, xmax, ymax) in zip(keep.nonzero(), axs.T, bboxes_scaled):# 可视化decoder的注意力权重ax = ax_i[0]ax.imshow(dec_attn_weights[0, idx].view(h, w))ax.axis('off')ax.set_title(f'query id: {idx.item()}',fontsize = 30)# 可视化框和类别ax = ax_i[1]ax.imshow(im)ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,fill=False, color='blue', linewidth=3))ax.axis('off')ax.set_title(CLASSES[probas[idx].argmax()],fontsize = 30)# 分别可视化8个头部的位置特征图for head in range(2, 2 + 8):ax = ax_i[head]ax.imshow(attn_every_heads[0, head-2, idx].view(h,w))ax.axis('off')ax.set_title(f'head:{head-2}',fontsize = 30)
fig.tight_layout()        # 自动调整子图来使其填充整个画布
plt.show()

GitHub链接

https://github.com/wulele2/Detr-heat-map-visualization,给个star吧,太不容易了。若有问题欢迎+vx:wulele2541612007,拉你进群探讨交流。

附录:可视化resnet50最后一层特征图

'''
代码来源于facebook_detr
'''
#导入包
import requests
from PIL import Image
import matplotlib.pyplot as pltimport torch
import torch.nn as nn
from torchvision.models import resnet50
import torchvision.transforms as T
torch.set_grad_enabled(False)# 获取一张图像
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
im = Image.open(requests.get(url, stream=True).raw)plt.imshow(im)
plt.show()# 构造图像变换
transform = T.Compose([T.Resize(800),              # 将图像进行Resize,符合短边变换原则T.ToTensor(),               # 将[0,255] --> [0,1]之间的张量T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 给每个通道进行归一化
])# PLI --> tensor
img = transform(im).unsqueeze(0) # [1,3,h,w]# 构造模型
model = resnet50(pretrained=True)# 创建一个list存储特征图特征
fms = []
# 定制hook
def hook(module, input, output):fms.append(output)
# 注册hook
handle = model.layer4.register_forward_hook(hook)#forward
model(img)# 用完后删除hook
handle.remove()# 可视化
plt.figure(figsize=(16, 10))   # 画布大小
ax = plt.gca()                 # 获取坐标轴
ax.imshow(fms[0].squeeze(0)[0])# 可视化第一个channel
ax.axis('off')                 # 关闭坐标轴
plt.show()                     # 展示


Detr空间注意力热图及语义注意力热图可视化相关推荐

  1. ICCV 2019 | 基于关联语义注意力模型的图像修复

    作者丨薛洁婷 学校丨北京交通大学硕士生 研究方向丨图像翻译 论文引入 图像修复问题的关键是维持修复后图像的全局语义一致性以及破损区域的细节纹理合理性.近期关于图像修复问题的研究主要集中于通过使用空间注 ...

  2. 用双注意力模块来做语义分割

    作者|Umer Rasheed  编译|ronghuaiyang 导读 本文对双注意网络进行场景分割进行简要概述. 论文链接:https://arxiv.org/abs/1809.02983 图1,双 ...

  3. 特征图注意力_从数据结构到算法:图网络方法初探

    作者 | 朱梓豪 来源 | 机器之心 原文 | 从数据结构到算法:图网络方法初探 如果说 2019 年机器学习领域什么方向最火,那么必然有图神经网络的一席之地.其实早在很多年前,图神经网络就以图嵌入. ...

  4. 图学习——04.HAN(异构图注意力网络)

    HAN(Heterogeneous Graph Attention Network) 包含不同类型节点和连接的异构图 异构图的定义 定义如下图 V代表顶点,A是顶点所属的类别,ε代表边,R是边所属的类 ...

  5. 特征图注意力_CBAM:卷积块注意力模块

    此篇文章内容源自 CBAM: Convolutional Block Attention Module,若侵犯版权,请告知本人删帖. 原论文下载地址: https://arxiv.org/pdf/18 ...

  6. Pytorch:Transformer(Encoder编码器-Decoder解码器、多头注意力机制、多头自注意力机制、掩码张量、前馈全连接层、规范化层、子层连接结构、pyitcast) part1

    日萌社 人工智能AI:Keras PyTorch MXNet TensorFlow PaddlePaddle 深度学习实战(不定时更新) Encoder编码器-Decoder解码器框架 + Atten ...

  7. CVPR 2019 开源论文 | 基于空间自适应归一化的图像语义合成

    作者丨武广 学校丨合肥工业大学硕士生 研究方向丨图像生成 深度学习在算力的推动下不断的发展,随着卷积层的堆叠,模型的层数是越来越深,理论上神经网络中的参数越多这样对数据的拟合和分布描述就能越细致.然而 ...

  8. 论文浅尝 | 面向视觉常识推理的层次语义增强方向图网络

    论文笔记整理:刘克欣,天津大学硕士 链接:https://dl.acm.org/doi/abs/10.1145/3475731.3484957 动机 视觉常识推理(VCR)任务旨在促进认知水平相关推理 ...

  9. 图 子类 数据库_构造知识图的语义模型

    幼稚园见解, 使用基于图的结构捕获数据源的语义 如果您是知识图和相关概念的新手,例如从数据源到本体的映射,我邀请您阅读以下入门文章. 知识图(KG)是用于捕获和构建大量多关系数据的有效工具,可以通过查 ...

最新文章

  1. ubantu14下vim的配置...
  2. static unsigned short,int ,char
  3. ROS2学习(十二).ROS概念 - RQt工具的使用
  4. 使用 Drone 构建 Coding 项目
  5. linux subversion 根目录检出,经验总结:详解Linux下Subversion的安装配置记录 下
  6. 专业英语笔记(Line Feed and Type Conversion)
  7. crop和resize操作区别
  8. sqk,按分钟统计平均值
  9. QML Label/Text 文本居中显示
  10. Android如何查看手机网卡信息和ip信息
  11. 谈谈EventTime以及Watermark
  12. 不填写内容用哪个斜杠代替_反斜杠
  13. 学UI设计,用对这5款设计软件是关键
  14. 简单快速!分享给你一款在线jpg格式转换器
  15. xp计算机管理下的服务显示不出来,WinXP系统任务栏不显示打开窗口的三种解决方案...
  16. 三国无双之雄霸天下java下载,三国之雄霸天下
  17. LabVIEW调用第三方exe软件或操作操作控制第三方软件界面的控件,如操控烧录软件
  18. spider mysql_MySQL中间件Spider引擎初探
  19. Cisco交换机路由器密码破解
  20. 搜寻马航MH370有多烧钱?澳方花费惊人

热门文章

  1. 全媒体运营 之 受众分析
  2. NTP时间服务器同步时钟系统安装汇总分享
  3. ionic2 tabs 自定义图标
  4. 辛小鬼家的Python基础语法
  5. ABP-oracle多数据库配置
  6. KNN实现手写数字识别(Python-OpenCV)
  7. 【23考研】计算机择校信息库—北京高校计算机相关专业22专业目录分类汇总(按专业课分类汇总)
  8. ImageLoader全局类配置 及图片展示配置(自定义缓存目录SD卡根目录)
  9. android+点击屏幕隐藏键盘,Android 显示和隐藏软键盘的方法(手动)
  10. 编译java程序时用于指定生成class_(15 )在编译 Java 程序时,用于指定生成 .class 文件位置的选项是A ) -g B ) -d C ) -verbose D...