文章目录

  • 前言
  • 一、jupyter notebook运行代码
  • 二、Pycharm(非jupyter notebook的IDE)
  • 总结

前言

一篇文章结尾看到很多同学没有找到DETR的预测代码,这里我分享一下官方写的DETR的预测代码,其实就是github上DETR官方写的一个Jupyter notebook,可能需要梯子才能访问,这里我贴出来。因为DETR比较难训练,我觉得可以用官方的预训练模型来看看效果。


如果要用自己训练的模型来推理,可以参考我的这篇文章:DETR训练自己的数据集


顺便说一下,这个模型是简易版,其实就是DETR论文最后一页贴出来的代码,mAP大概要比源码的低2-3个点,但已经很强了。只用了50行代码就把DETR模型架构展示出来了,真的是非常简洁的一个模型!

一、jupyter notebook运行代码

如果用的是jupyter notebook,用这个代码,修改代码末尾的图片路径即可

from PIL import Image
import matplotlib.pyplot as plt
%config InlineBackend.figure_format = 'svg'
import ipywidgets as widgets
from IPython.display import display, clear_output
import cv2
import torch
from torch import nn
from torchvision.models import resnet50
import torchvision.transforms as T
torch.set_grad_enabled(False)
from torchvision import models# COCO classes
CLASSES = ['N/A', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus','train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A','stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse','sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack','umbrella', 'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis','snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove','skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass','cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich','orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake','chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A','N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard','cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A','book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier','toothbrush'
]# colors for visualization
COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],[0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]# standard PyTorch mean-std input image normalization
transform = T.Compose([T.Resize(800),T.ToTensor(),T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])# for output bounding box post-processing
def box_cxcywh_to_xyxy(x):x_c, y_c, w, h = x.unbind(1)b = [(x_c - 0.5 * w), (y_c - 0.5 * h),(x_c + 0.5 * w), (y_c + 0.5 * h)]return torch.stack(b, dim=1)def rescale_bboxes(out_bbox, size):img_w, img_h = sizeb = box_cxcywh_to_xyxy(out_bbox)b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)return bdef plot_results(pil_img, prob, boxes):plt.figure(figsize=(16,10))plt.imshow(pil_img)ax = plt.gca()colors = COLORS * 100for p, (xmin, ymin, xmax, ymax), c in zip(prob, boxes.tolist(), colors):ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,fill=False, color=c, linewidth=3))cl = p.argmax()text = f'{CLASSES[cl]}:{p[cl]:0.2f}'ax.text(xmin, ymin, text, fontsize=15,bbox=dict(facecolor='yellow', alpha=0.5))plt.axis('off')plt.show()def detect(im, model, transform):# mean-std normalize the input image (batch-size: 1)img = transform(im).unsqueeze(0)# demo model only support by default images with aspect ratio between 0.5 and 2# if you want to use images with an aspect ratio outside this range# rescale your image so that the maximum size is at most 1333 for best resultsassert img.shape[-2] <= 1600 and img.shape[-1] <= 1600, 'demo model only supports images up to 1600 pixels on each side'# propagate through the modeloutputs = model(img)# keep only predictions with 0.7+ confidenceprobas = outputs['pred_logits'].softmax(-1)[0, :, :-1]keep = probas.max(-1).values > 0.7# convert boxes from [0; 1] to image scalesbboxes_scaled = rescale_bboxes(outputs['pred_boxes'][0, keep], im.size)return probas[keep], bboxes_scaledclass DETRdemo(nn.Module):"""Demo DETR implementation.Demo implementation of DETR in minimal number of lines, with thefollowing differences wrt DETR in the paper:* learned positional encoding (instead of sine)* positional encoding is passed at input (instead of attention)* fc bbox predictor (instead of MLP)The model achieves ~40 AP on COCO val5k and runs at ~28 FPS on Tesla V100.Only batch size 1 supported."""def __init__(self, num_classes, hidden_dim=256, nheads=8,num_encoder_layers=6, num_decoder_layers=6):super().__init__()# create ResNet-50 backboneself.backbone = resnet50()del self.backbone.fc# create conversion layerself.conv = nn.Conv2d(2048, hidden_dim, 1)# create a default PyTorch transformerself.transformer = nn.Transformer(hidden_dim, nheads, num_encoder_layers, num_decoder_layers)# prediction heads, one extra class for predicting non-empty slots# note that in baseline DETR linear_bbox layer is 3-layer MLPself.linear_class = nn.Linear(hidden_dim, num_classes + 1)self.linear_bbox = nn.Linear(hidden_dim, 4)# output positional encodings (object queries)self.query_pos = nn.Parameter(torch.rand(100, hidden_dim))# spatial positional encodings# note that in baseline DETR we use sine positional encodingsself.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))def forward(self, inputs):# propagate inputs through ResNet-50 up to avg-pool layerx = self.backbone.conv1(inputs)x = self.backbone.bn1(x)x = self.backbone.relu(x)x = self.backbone.maxpool(x)x = self.backbone.layer1(x)x = self.backbone.layer2(x)x = self.backbone.layer3(x)x = self.backbone.layer4(x)# convert from 2048 to 256 feature planes for the transformerh = self.conv(x)# construct positional encodingsH, W = h.shape[-2:]pos = torch.cat([self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),], dim=-1).flatten(0, 1).unsqueeze(1)# propagate through the transformerh = self.transformer(pos + 0.1 * h.flatten(2).permute(2, 0, 1),self.query_pos.unsqueeze(1)).transpose(0, 1)# finally project transformer outputs to class labels and bounding boxesreturn {'pred_logits': self.linear_class(h), 'pred_boxes': self.linear_bbox(h).sigmoid()}if __name__=="__main__":detr = DETRdemo(num_classes=91)state_dict = torch.hub.load_state_dict_from_url(url='https://dl.fbaipublicfiles.com/detr/detr_demo-da2a99e9.pth',map_location='cpu', check_hash=True)detr.load_state_dict(state_dict)detr.eval()# 这里修改自己的图片路径即可,注意路径不要有中文,否则要加一行解码的代码im = cv2.imread("test4.JPG")im= cv2.cvtColor(im, cv2.COLOR_BGR2RGB)im=Image.fromarray(im)scores, boxes= detect(im, detr, transform)plot_results(im, scores, boxes)

这是用我自己拍的图片跑的效果:

二、Pycharm(非jupyter notebook的IDE)

from PIL import Image
import matplotlib.pyplot as plt
import cv2
import torch
from torch import nn
from torchvision.models import resnet50
import torchvision.transforms as Ttorch.set_grad_enabled(False)# COCO classes
CLASSES = ['N/A', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus','train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A','stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse','sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack','umbrella', 'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis','snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove','skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass','cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich','orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake','chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A','N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard','cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A','book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier','toothbrush'
]# colors for visualization
COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],[0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]# standard PyTorch mean-std input image normalization
transform = T.Compose([T.Resize(800),T.ToTensor(),T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])# for output bounding box post-processing
def box_cxcywh_to_xyxy(x):x_c, y_c, w, h = x.unbind(1)b = [(x_c - 0.5 * w), (y_c - 0.5 * h),(x_c + 0.5 * w), (y_c + 0.5 * h)]return torch.stack(b, dim=1)def rescale_bboxes(out_bbox, size):img_w, img_h = sizeb = box_cxcywh_to_xyxy(out_bbox)b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)return bdef plot_results(pil_img, prob, boxes):plt.figure(figsize=(16, 10))plt.imshow(pil_img)ax = plt.gca()colors = COLORS * 100for p, (xmin, ymin, xmax, ymax), c in zip(prob, boxes.tolist(), colors):ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,fill=False, color=c, linewidth=3))cl = p.argmax()text = f'{CLASSES[cl]}:{p[cl]:0.2f}'ax.text(xmin, ymin, text, fontsize=15,bbox=dict(facecolor='yellow', alpha=0.5))plt.axis('off')plt.show()def detect(im, model, transform):# mean-std normalize the input image (batch-size: 1)img = transform(im).unsqueeze(0)# demo model only support by default images with aspect ratio between 0.5 and 2# if you want to use images with an aspect ratio outside this range# rescale your image so that the maximum size is at most 1333 for best resultsassert img.shape[-2] <= 1600 and img.shape[-1] <= 1600, 'demo model only supports images up to 1600 pixels on each side'# propagate through the modeloutputs = model(img)# keep only predictions with 0.7+ confidenceprobas = outputs['pred_logits'].softmax(-1)[0, :, :-1]keep = probas.max(-1).values > 0.7# convert boxes from [0; 1] to image scalesbboxes_scaled = rescale_bboxes(outputs['pred_boxes'][0, keep], im.size)return probas[keep], bboxes_scaledclass DETRdemo(nn.Module):"""Demo DETR implementation.Demo implementation of DETR in minimal number of lines, with thefollowing differences wrt DETR in the paper:* learned positional encoding (instead of sine)* positional encoding is passed at input (instead of attention)* fc bbox predictor (instead of MLP)The model achieves ~40 AP on COCO val5k and runs at ~28 FPS on Tesla V100.Only batch size 1 supported."""def __init__(self, num_classes, hidden_dim=256, nheads=8,num_encoder_layers=6, num_decoder_layers=6):super().__init__()# create ResNet-50 backboneself.backbone = resnet50()del self.backbone.fc# create conversion layerself.conv = nn.Conv2d(2048, hidden_dim, 1)# create a default PyTorch transformerself.transformer = nn.Transformer(hidden_dim, nheads, num_encoder_layers, num_decoder_layers)# prediction heads, one extra class for predicting non-empty slots# note that in baseline DETR linear_bbox layer is 3-layer MLPself.linear_class = nn.Linear(hidden_dim, num_classes + 1)self.linear_bbox = nn.Linear(hidden_dim, 4)# output positional encodings (object queries)self.query_pos = nn.Parameter(torch.rand(100, hidden_dim))# spatial positional encodings# note that in baseline DETR we use sine positional encodingsself.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))def forward(self, inputs):# propagate inputs through ResNet-50 up to avg-pool layerx = self.backbone.conv1(inputs)x = self.backbone.bn1(x)x = self.backbone.relu(x)x = self.backbone.maxpool(x)x = self.backbone.layer1(x)x = self.backbone.layer2(x)x = self.backbone.layer3(x)x = self.backbone.layer4(x)# convert from 2048 to 256 feature planes for the transformerh = self.conv(x)# construct positional encodingsH, W = h.shape[-2:]pos = torch.cat([self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),], dim=-1).flatten(0, 1).unsqueeze(1)# propagate through the transformerh = self.transformer(pos + 0.1 * h.flatten(2).permute(2, 0, 1),self.query_pos.unsqueeze(1)).transpose(0, 1)# finally project transformer outputs to class labels and bounding boxesreturn {'pred_logits': self.linear_class(h),'pred_boxes': self.linear_bbox(h).sigmoid()}if __name__ == "__main__":detr = DETRdemo(num_classes=91)state_dict = torch.hub.load_state_dict_from_url(url='https://dl.fbaipublicfiles.com/detr/detr_demo-da2a99e9.pth',map_location='cpu', check_hash=True)detr.load_state_dict(state_dict)detr.eval()im = cv2.imread("test4.JPG")im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)im = Image.fromarray(im)scores, boxes = detect(im, detr, transform)plot_results(im, scores, boxes)

总结

这里再贴以下官方github:DETR
真的是非常贴心的开源代码!

【DETR】DETR预测/推理代码相关推荐

  1. 计算机视觉算法——基于Transformer的目标检测(DETR / Deformable DETR / DETR 3D)

    计算机视觉算法--基于Transformer的目标检测(DETR / Deformable DETR / DETR 3D) 计算机视觉算法--基于Transformer的目标检测(DETR / Def ...

  2. cnn 预测过程代码_FPN的Tensorflow代码详解——特征提取

    @TOC   特征金字塔网络最早于2017年发表于CVPR,与Faster RCNN相比其在多池度特征预测的方式使得其在小目标预测上取得了较好的效果.FPN也作为mmdeteciton的Neck模块, ...

  3. PointPillars点云检测在OpenPCDet推理代码详解

    在之前的文章中已经详细解析了PointPillars的论文和训练代码的实现和详解,可以参考之前的博客:PointPillars论文解析和OpenPCDet代码解析_NNNNNathan的博客-CSDN ...

  4. R语言机器学习Caret包(Caret包是分类和回归训练的简称)、数据划分、数据预处理、模型构建、模型调优、模型评估、多模型对比、模型预测推理

    R语言机器学习Caret包(Caret包是分类和回归训练的简称).数据划分.数据预处理.模型构建.模型调优.模型评估.多模型对比.模型预测推理 目录

  5. R语言使用caretEnsemble包的caretStack函数把多个机器学习模型融合成一个模型、构建融合(集成)预测模型、使用融合模型进行预测推理

    R语言使用caretEnsemble包的caretStack函数把多个机器学习模型融合成一个模型.构建融合(集成)预测模型.自定义融合模型的trainControl参数.method参数.评估指标参数 ...

  6. pandas基于时序数据计算模型预测推理需要的统计数据(累计时间、长度变化、变化率、方差、均值、最大、最小等):范围内的统计量、变化率、获得数据集最后的几条数据的统计量、变化率、获得范围内的统计量

    pandas基于时序数据计算模型预测推理需要的统计数据(累计时间.长度变化.变化率.方差.均值.最大.最小等):范围内的统计量.变化率.获得数据集最后的几条数据的统计量.变化率.获得范围内的统计量 目 ...

  7. R语言构建文本分类模型:文本数据预处理、构建词袋模型(bag of words)、构建xgboost文本分类模型、xgboost模型预测推理并使用混淆矩阵评估模型、可视化模型预测的概率分布

    R语言构建文本分类模型:文本数据预处理.构建词袋模型(bag of words).构建xgboost文本分类模型.xgboost模型预测推理并使用混淆矩阵评估模型.可视化模型预测的概率分布 目录

  8. pandas基于时序数据计算模型预测推理需要的统计数据(累计时间、长度变化、变化率、方差、均值、最大、最小等):数据持续的时间(分钟)、获得某一节点之后的数据总变化量、获得范围内的统计量

    pandas基于时序数据计算模型预测推理需要的统计数据(累计时间.长度变化.变化率.方差.均值.最大.最小等):数据持续的时间(分钟).获得某一节点之后的数据总变化量.获得范围内的统计量 目录

  9. R语言构建xgboost模型:使用xgboost的第一颗树(前N颗树)进行预测推理或者使用全部树进行预测推理、比较误分类率指标

    R语言构建xgboost模型:使用xgboost的第一颗树(前N颗树)进行预测推理或者使用全部树进行预测推理.比较误分类率指标 目录

最新文章

  1. 测试晶面间距软件_丽江导电橡胶电阻率测试仪生产商
  2. 右键新建里面没有word和excel_Windows10系统下如何将Sublime Text3添加到右键快捷菜单?...
  3. Android:problem opening wizard the selected wizard could not be started
  4. JavaScript新知:sessionStorage and localStorage
  5. WPF TabControl Unload俩次的解决方案
  6. 4个轮子+1部手机=长城眼里的智能汽车现状
  7. flask_socketio 用法:
  8. 使用freemarker生成java文件(其他文件也可以)
  9. 【Android 安全】DEX 加密 ( Proguard 简介 | Proguard 相关网址 | Proguard 混淆配置 )
  10. web前端知识点太多_web前端入门必学的16个知识点,都来看一下吧
  11. linux服务器拷贝目录文件夹,linux两台服务器之间文件/文件夹拷贝
  12. PHP排序算法的复习和总结
  13. 浙江交换机厂家带你全面了解什么是工业交换机?
  14. 1,2,2,3,3,4,4,4,......
  15. noip复赛电脑有excel吗_指南 | 现在就必须了解的信息学竞赛(高一学生)
  16. freeswitch 查看当前注册用户命令
  17. hdu 1251 字典树,指针版
  18. python 获取几小时之前,几分钟前,几天前,几个月前,及几年前的具体时间
  19. LXC与宿主机共享目录(七)
  20. 易语言注册机:界面绘制及皮肤模块的引用

热门文章

  1. 计算机组成原理——存储器
  2. 几个常用信号分解算法的python包
  3. 人脸识别闸机助力线下展览与演出
  4. Nordic NRF-52 深度解析
  5. Slicer学习笔记(四十六)slicer 常用的几个模块
  6. 操作手册 : AD 及 LDAP 操作
  7. matlab编程误差分析,基于MATLAB的圆度误差分析.pdf
  8. java生成两种二维码
  9. 密位测距离口诀_密位点瞄准镜怎么测距?
  10. 如何设计一个安全的对外接口