RFBNet模型推理过程

import argparse
import cv2
import numpy as np
import os
import torch
import torch.backends.cudnn as cudnn
from torch.autograd import Variablefrom data import BaseTransform, VOC_300
from data import VOC_CLASSES as labelmap
from layers.functions import Detect, PriorBox
from utils.timer import Timerparser = argparse.ArgumentParser(description='Receptive Field Block Net')parser.add_argument('-v', '--version', default='RFB_vgg',help='RFB_vgg ,RFB_E_vgg or RFB_mobile version.')
parser.add_argument('-s', '--size', default='300',help='300 or 512 input size.')
parser.add_argument('-d', '--dataset', default='VOC',help='VOC or COCO version')
parser.add_argument('-m', '--trained_model', default=r'weights/7690.pth',type=str, help='Trained state_dict file path to open')
parser.add_argument('--save_folder', default='eval/', type=str,help='Dir to save results')
parser.add_argument('--cuda', default=False, type=bool,help='Use cuda to train model')
parser.add_argument('--retest', default=False, type=bool,help='test cache results')
args = parser.parse_args()if not os.path.exists(args.save_folder):os.mkdir(args.save_folder)cfg = VOC_300,if args.version == 'RFB_vgg':from models.RFB_Net_vgg import build_net
elif args.version == 'RFB_E_vgg':from models.RFB_Net_E_vgg import build_netpriorbox = PriorBox(cfg)
with torch.no_grad():priors = priorbox.forward()if args.cuda:priors = priors.cuda()def py_cpu_nms(dets, thresh):"""Pure Python NMS baseline."""x1 = dets[:, 0]y1 = dets[:, 1]x2 = dets[:, 2]y2 = dets[:, 3]scores = dets[:, 4]areas = (x2 - x1 + 1) * (y2 - y1 + 1)order = scores.argsort()[::-1]keep = []while order.size > 0:i = order[0]keep.append(i)xx1 = np.maximum(x1[i], x1[order[1:]])yy1 = np.maximum(y1[i], y1[order[1:]])xx2 = np.minimum(x2[i], x2[order[1:]])yy2 = np.minimum(y2[i], y2[order[1:]])w = np.maximum(0.0, xx2 - xx1 + 1)h = np.maximum(0.0, yy2 - yy1 + 1)inter = w * hovr = inter / (areas[i] + areas[order[1:]] - inter)inds = np.where(ovr <= thresh)[0]order = order[inds + 1]return keepclass ObjectDetector:def __init__(self, net, detection, transform, num_classes=21, cuda=False, max_per_image=300, thresh=0.005):self.net = netself.detection = detectionself.transform = transformself.max_per_image = 300self.num_classes = num_classesself.max_per_image = max_per_imageself.cuda = cudaself.thresh = threshdef predict(self, img):scale = torch.Tensor([img.shape[1], img.shape[0],img.shape[1], img.shape[0]]).cpu().numpy()_t = {'im_detect': Timer(), 'misc': Timer()}assert img.shape[2] == 3with torch.no_grad():x = transform(img).unsqueeze(0)if self.cuda:x = x.cuda()scale = scale.cuda()_t['im_detect'].tic()out = net(x)  # forward passboxes, scores = self.detection.forward(out, priors)detect_time = _t['im_detect'].toc()boxes = boxes[0]scores = scores[0]boxes = boxes.cpu().numpy()scores = scores.cpu().numpy()# scale each detection back up to the imageboxes *= scale_t['misc'].tic()all_boxes = [[] for _ in range(num_classes)]for j in range(1, num_classes):inds = np.where(scores[:, j] > self.thresh)[0]if len(inds) == 0:all_boxes[j] = np.zeros([0, 5], dtype=np.float32)continuec_bboxes = boxes[inds]c_scores = scores[inds, j]#print(scores[:, j])c_dets = np.hstack((c_bboxes, c_scores[:, np.newaxis])).astype(np.float32, copy=False)# keep = nms(c_bboxes,c_scores)keep = py_cpu_nms(c_dets, 0.45)keep = keep[:self.max_per_image]c_dets = c_dets[keep, :]all_boxes[j] = c_detsif self.max_per_image > 0:image_scores = np.hstack([all_boxes[j][:, -1] for j in range(1, num_classes)])if len(image_scores) > self.max_per_image:image_thresh = np.sort(image_scores)[-self.max_per_image]for j in range(1, num_classes):keep = np.where(all_boxes[j][:, -1] >= image_thresh)[0]all_boxes[j] = all_boxes[j][keep, :]nms_time = _t['misc'].toc()# print('net time: ', detect_time)# print('post time: ', nms_time)print('time: ', detect_time + nms_time)return all_boxesCOLORS = [(255, 0, 0), (0, 255, 0), (0, 0, 255)]
FONT = cv2.FONT_HERSHEY_SIMPLEXif __name__ == '__main__':# load netimg_dim = 300num_classes = 2net = build_net('test', num_classes)  # initialize detectorstate_dict = torch.load(args.trained_model, map_location='cpu')# create new OrderedDict that does not contain `module.`from collections import OrderedDictnew_state_dict = OrderedDict()for k, v in state_dict.items():head = k[:7]if head == 'module.':name = k[7:]  # remove `module.`else:name = knew_state_dict[name] = vnet.load_state_dict(new_state_dict)net.eval()print('Finished loading model!')print(net)# load dataif args.cuda:net = net.cuda()cudnn.benchmark = True# evaluationtop_k = 300detector = Detect(num_classes, 0, cfg)rgb_means = (104, 117, 123)rgb_std = (1, 1, 1)transform = BaseTransform(img_dim, rgb_means, (2, 0, 1))object_detector = ObjectDetector(net, detector, transform)# cap = cv2.VideoCapture(0)# while True:##     ret, image = cap.read()#     detect_bboxes = object_detector.predict(image)#     for class_id, class_collection in enumerate(detect_bboxes):#         if len(class_collection) > 0:#             for i in range(class_collection.shape[0]):#                 if class_collection[i, -1] > 0.6:#                     pt = class_collection[i]#                     cv2.rectangle(image, (int(pt[0]), int(pt[1])), (int(pt[2]),#                                                                     int(pt[3])), COLORS[i % 3], 3)#                     cv2.putText(image, labelmap[class_id], (int(pt[0]), int(pt[1])), FONT,#                                 0.5, (255, 255, 255), 2)#     cv2.imshow('result', image)#     cv2.waitKey(10)image = cv2.imread('eval/5.jpg')detect_bboxes = object_detector.predict(image)for class_id,class_collection in enumerate(detect_bboxes):if len(class_collection)>0:for i in range(class_collection.shape[0]):if class_collection[i,-1]>0.6:pt = class_collection[i]cv2.rectangle(image, (int(pt[0]), int(pt[1])), (int(pt[2]),int(pt[3])), COLORS[i % 3], 2)cv2.putText(image, labelmap[class_id], (int(pt[0]), int(pt[1])), FONT,0.5, (255, 255, 255), 2)cv2.imshow('result',image)cv2.waitKey()

RFBNet模型推理相关推荐

  1. MindSpore模型推理

    MindSpore模型推理 如果想在应用中使用自定义的MindSpore Lite模型,需要告知推理器模型所在的位置.推理器加载模型的方式有以下三种: • 加载本地模型. • 加载远程模型. • 混合 ...

  2. Python时间序列模型推理预测实战:时序推理数据预处理(特征生成、lstm输入结构组织)、模型加载、模型预测结果保存、条件判断模型循环运行

    Python时间序列模型推理预测实战:时序推理数据预处理(特征生成.lstm输入结构组织).模型加载.模型预测结果保存.条件判断模型循环运行 目录

  3. Python使用tkinter构建一个多元回归预测模型GUI界面(接受用户输入数据并给出模型推理结果)

    Python使用tkinter构建一个多元回归预测模型GUI界面(接受用户输入数据并给出模型推理结果) 目录

  4. Python基于statsmodels包构建多元线性回归模型:模型构建、模型解析、模型推理预测

    Python基于statsmodels包构建多元线性回归模型:模型构建.模型解析.模型推理预测 目录

  5. ML之回归预测:利用Lasso、ElasticNet、GBDT等算法构建集成学习算法AvgModelsR对国内某平台上海2020年6月份房价数据集【12+1】进行回归预测(模型评估、模型推理)

    ML之回归预测:利用Lasso.ElasticNet.GBDT等算法构建集成学习算法AvgModelsR对国内某平台上海2020年6月份房价数据集[12+1]进行回归预测(模型评估.模型推理) 目录 ...

  6. 千元显卡玩转百亿大模型,清华推出工具包BMInf让模型推理轻而易举

    最近在工业界与学术界,最热门的方向莫过于预训练语言模型.而具有百亿乃至千亿参数的大规模预训练语言模型,更是业界与学术界发力的热点. 但现在大模型的应用却有着较高的门槛,排队申请或需要付费的API.较长 ...

  7. 使用tensoflow serving来部署模型推理节点

    使用tensoflow serving来部署模型推理节点 这里使用的时docker来进行模型的部署,主要是docker更轻便和方便. 1.训练一个分类模型 分类模型一般很简单,这里我已经训练好一个(测 ...

  8. onnx模型推理(python)

    onnx模型推理(python) 以下ONNX一个检测模型的推理过程,其他模型稍微修改即可 # -*-coding: utf-8 -*-import os, syssys.path.append(os ...

  9. XEngine:深度学习模型推理优化

    摘要:从显存优化,计算优化两个方面来分析一下如何进行深度学习模型推理优化. 本文分享自华为云社区<XEngine-深度学习推理优化>,作者: ross.xw. 前言 深度学习模型的开发周期 ...

  10. 深度模型推理在腾讯游戏的应用与实践(王者荣耀、和平精英等均有应用)

    猜你喜欢 0.图模型在信息流推荐系统中的原理和实践1.如何搭建一套个性化推荐系统?2.从零开始搭建创业公司后台技术栈3.全民K歌推荐系统算法.架构及后台实现4.微博推荐算法实践与机器学习平台演进5.腾 ...

最新文章

  1. 2020 年技术趋势一览:AutoML、联邦学习、云寡头时代的终结
  2. Hive SemanticException
  3. oracle 复制数据 insert into、as select
  4. Pandas——筛选数据(loc、iloc)
  5. 伺服电机转矩常数的标定方法
  6. 怎么用linux给苹果手机降级,【教程】iPhone降级_iPhone系统怎么降级_手机中国
  7. UE4 相对坐标转世界坐标
  8. 基于Android studio的WIFI搜索显示与WIFI打开
  9. 【开发环境】 Ubuntu14.04 安装Skyeye 1.3.5过程
  10. 无法启动此程序,因为计算机中丢失MSVCP120.dll文件、应用程序无法正常启动0xc000007b
  11. linux中tomcat部署项目步骤以及命令
  12. 线性链表实现对二进制数加1运算
  13. 【MySQL】SQL优化
  14. Java高级开发必备--Docker进阶(一篇详细教程,进阶Docker)
  15. Ubuntu 蓝屏拯救
  16. 试题 算法训练 生活大爆炸版石头剪刀布
  17. [EULAR文摘] 超声腱鞘炎对RA早期诊断的价值
  18. 解决PDapp占据C盘空间的方法
  19. 小程序源码:百变头像框制作微信小程序源码下载,免服务器和域名
  20. 有温度传感器的风机控制系统C语言,基于单片机的暖风机的设计任务书、开题报告...

热门文章

  1. PCB绘制成长日记1
  2. java jui_急求用带jui界面写的java聊天程序!!!
  3. 装配uwsgi和nginx rabbitMQ
  4. 计算机 电脑 整机 加密,如何加密文件夹
  5. 2进制 , 8进制 , 10进制 , 16进制 , 介绍 及 相互转换 及 快速转换的方法
  6. 外接USB蓝牙设置无法启动
  7. 微软Kinect是怎么做到的
  8. Excel文档中字符型数据转化为数字类型
  9. adb 切换usb模式_adb调试命令,adb强制打开usb调试,adb命令打开usb调试
  10. AXI协议中的BURST