目标检测中NMS(non maximum suppression)
一、原理
参考网址
NMS即non maximum suppression即非极大抑制,顾名思义就是抑制不是极大值的元素,搜索局部的极大值。在最近几年常见的物体检测算法(包括rcnn、sppnet、fast-rcnn、faster-rcnn等)中,最终都会从一张图片中找出很多个可能是物体的矩形框,然后为每个矩形框为做类别分类概率。
Soft NMS法:对于得分小于阈值的边框,不再直接舍弃,而是降低其得分。Soft NMS是对NMS的优化算法,它在不增加额外参数的情况下且只需要对NMS算法进行简单的改动就能提高AP。该Soft-NMS算法在标准数据集PASCAL VOC2007(较R-FCN和Faster-RCNN提升1.7%)和MS-COCO(较R-FCN提升1.3%,较Faster-RCNN提升1.1%)上均有提升。对于大多数数据集而言,作用比较小,提升效果非常不明显,它起作用的地方是大量密集的同类重叠场景,大量密集的不同类重叠场景其实也没什么作用。
二、实现步骤
1、NMS算法实现
参考网址
假设我们已经有了预测的框,每个预测框对应的类别,每个预测框对应的类别得分。本文中使用的案例是n个[x1,y1,x2,y2,confident, 0,0,0,1,0,0] --> n个[左上坐标,右下坐标,置信度,种类的one-hot编码]。
![](/assets/blank.gif)
(1)对于每个类别而言,取该类中计算得分最大的框与其余预测框之间的IoU。
(2)根据设定的阈值,剔除掉该类中IOU大于阈值的预测框。
(3)对于每个类别而言,在(2)的基础上循环(1)-(2)步骤,在该类中已经排好序的情况下,直接下一个最大值作为该类的得分最大的框与其余预测框之间的IoU;直到剩余有用的预测框个数为0。
2、softNMS算法实现
论文地址:https://arxiv.org/pdf/1704.04503v2.pdf
(1)对于每个类别,按照类别得分的从大到小顺序排列
(2)对于每个类别而言,取该类中计算得分最大的框与其余预测框之间的IoU。
(3)根据计算的iou,使用高斯惩罚函数公式
![](/assets/blank.gif)
Si 表示置信度,当iou越大时,该预测框的置信度就会施加更重的惩罚值,从而降低置信度。
(4)对于每个类别而言,在(3)对置信度施加惩罚后,再剔除掉该类中置信度小于阈值的预测框。
(5)重复(1) - (4)步骤,直至剩余的预测框个数为0。
三、代码实现
1、Pytorch版
import numpy as np
import cv2
import matplotlib.pyplot as plt
import torchdef bbox_iou(box1, box2, x1y1x2y2=True):# 坐标分离if not x1y1x2y2:b1_x1, b1_x2 = box1[:, 0] - box1[:, 2]/2, box1[:, 0] + box1[:, 2]/2b1_y1, b1_y2 = box1[:, 1] - box1[:, 3]/2, box1[:, 1] + box1[:, 3]/2b2_x1, b2_x2 = box2[:, 0] - box2[:, 2]/2, box2[:, 0] + box2[:, 2]/2b2_y1, b2_y2 = box2[:, 1] - box2[:, 3]/2, box2[:, 1] + box2[:, 3]/2else:b1_x1, b1_y1, b1_x2, b1_y2 = box1[:, 0], box1[:, 1], box1[:, 2], box1[:, 3]b2_x1, b2_y1, b2_x2, b2_y2 = box2[:, 0], box2[:, 1], box2[:, 2], box2[:, 3]# 计算交集的两坐标(即交集的左上、右下坐标)inter_rect_x1 = torch.max(b1_x1, b2_x1)inter_rect_y1 = torch.max(b1_y1, b2_y1)inter_rect_x2 = torch.min(b1_x2, b2_x2)inter_rect_y2 = torch.min(b1_y2, b2_y2)# 计算交集面积inter_area = torch.clamp(inter_rect_x2 - inter_rect_x1, min=0) *\torch.clamp(inter_rect_y2 - inter_rect_y1, min=0)# 分别计算两个目标框各自的面积b1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1)b2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1)# 交集/并集iou = inter_area / torch.clamp(b1_area + b2_area - inter_area, min= 1e-6)return ioudef nms(boxes, nms_thres=0.5):result = []boxes = torch.Tensor(boxes)#------------------------------------------## 获得预测结果中包含的所有种类#------------------------------------------#unique_labels = boxes[:, -1].cpu().unique()for c in unique_labels:#------------------------------------------## 获得某一类得分筛选后全部的预测结果#------------------------------------------#detections_class = boxes[boxes[:, -1] == c]# 按照存在物体的置信度排序_, conf_sort_index = torch.sort(detections_class[:, 4], descending=True)detections_class = detections_class[conf_sort_index]# 进行非极大抑制max_detections = []while detections_class.size(0):# 取出这一类置信度最高的,一步一步往下判断,判断重合程度是否大于nms_thres,如果是则去除掉max_detections.append(detections_class[0].unsqueeze(0))if len(detections_class) == 1:breakious = bbox_iou(max_detections[-1], detections_class[1:])detections_class = detections_class[1:][ious < nms_thres]# 堆叠max_detections = torch.cat(max_detections).dataresult.append(max_detections)result = torch.cat(result).data# ===================================================================================== ## 案例绘图 ## ===================================================================================== #plt.figure()colors = [(255,0,0),(0,255,0),(0,0,255)]img = np.zeros((616, 616, 3))plt.subplot(121)for i in result:cv2.rectangle(img, (int(i[0]), int(i[1])), (int(i[2]), int(i[3])), colors[int(i[5])], 2)plt.imshow(img)plt.title('nms')plt.subplot(122)img_ = np.zeros((616, 616, 3))for i in boxes:cv2.rectangle(img_, (int(i[0]), int(i[1])), (int(i[2]), int(i[3])), colors[int(i[5])], 2)plt.imshow(img_)plt.title('original')plt.show()# ===================================================================================== #return result # 返回nms后的结果def soft_nms(boxes,conf_thres=0.5,sigma=0.5):result = []boxes = torch.Tensor(boxes)#------------------------------------------## 获得预测结果中包含的所有种类#------------------------------------------#unique_labels = boxes[:, -1].cpu().unique()for c in unique_labels:#------------------------------------------## 获得某一类得分筛选后全部的预测结果#------------------------------------------#detections_class = boxes[boxes[:, -1] == c]# 按照存在物体的置信度排序_, conf_sort_index = torch.sort(detections_class[:, 4], descending=True)detections_class = detections_class[conf_sort_index]# 进行非极大抑制max_detections = []while detections_class.size(0):# 取出这一类置信度最高的,一步一步往下判断,根据iou设置一个对置信度的惩罚因子,去除置信度小于conf_thres的框max_detections.append(detections_class[0].unsqueeze(0))if len(detections_class) == 1:breakious = bbox_iou(max_detections[-1], detections_class[1:])detections_class[1:, 4] = torch.exp(-(ious * ious) / sigma) * detections_class[1:, 4]detections_class = detections_class[1:]detections_class = detections_class[detections_class[:, 4] >= conf_thres]arg_sort = torch.argsort(detections_class[:, 4], descending = True)detections_class = detections_class[arg_sort]# 堆叠max_detections = torch.cat(max_detections).dataresult.append(max_detections)result = torch.cat(result).data# ===================================================================================== ## 案例绘图 ## ===================================================================================== #plt.figure()colors = [(255,0,0),(0,255,0),(0,0,255)]img = np.zeros((616, 616, 3))plt.subplot(121)for i in result:cv2.rectangle(img, (int(i[0]), int(i[1])), (int(i[2]), int(i[3])), colors[int(i[5])], 2)plt.imshow(img)plt.title('nms')plt.subplot(122)img_ = np.zeros((616, 616, 3))for i in boxes:cv2.rectangle(img_, (int(i[0]), int(i[1])), (int(i[2]), int(i[3])), colors[int(i[5])], 2)plt.imshow(img_)plt.title('original')plt.show()# ===================================================================================== #return result # 返回nms后的结果if __name__ == '__main__':boxes = np.array([[100, 110, 210, 210, 0.71, 0.7, 0.3, 0.5],[250, 250, 420, 420, 0.8, 0.1, 0.8, 0.6],[220, 200, 320, 330, 0.92, 0.2, 0.5, 1.0],[120, 100, 210, 210, 0.72, 0.8, 0.2, 0.3],[230, 240, 325, 330, 0.81, 0.1, 0.9, 0.2],[220, 230, 315, 340, 0.91, 0.2, 0.7, 0.6]])# 转换成[x1,y1,x2,y2,confident,种类下标]boxes = np.hstack((boxes[...,:5],np.expand_dims(np.argmax(boxes[:,5:],axis=-1),axis=-1)))nms(boxes,nms_thres=0.5)soft_nms(boxes,0.5,0.1) #boxes,threshold,λ(gauss函数参数)
效果图:
![](/assets/blank.gif)
2、Numpy版
import numpy as np
import cv2
import matplotlib.pyplot as plt'''numpy版
'''
import numpy as np
def bbox_iou(box1, box2, x1y1x2y2=True):if not x1y1x2y2:b1_x1, b1_x2 = box1[:, 0] - box1[:, 2]/2, box1[:, 0] + box1[:, 2]/2b1_y1, b1_y2 = box1[:, 1] - box1[:, 3]/2, box1[:, 1] + box1[:, 3]/2b2_x1, b2_x2 = box2[:, 0] - box2[:, 2]/2, box2[:, 0] + box2[:, 2]/2b2_y1, b2_y2 = box2[:, 1] - box2[:, 3]/2, box2[:, 1] + box2[:, 3]/2else:b1_x1, b1_y1, b1_x2, b1_y2 = box1[:, 0], box1[:, 1], box1[:, 2], box1[:, 3]b2_x1, b2_y1, b2_x2, b2_y2 = box2[:, 0], box2[:, 1], box2[:, 2], box2[:, 3]inter_rect_x1 = np.maximum(b1_x1, b2_x1)inter_rect_y1 = np.maximum(b1_y1, b2_y1)inter_rect_x2 = np.minimum(b1_x2, b2_x2)inter_rect_y2 = np.minimum(b1_y2, b2_y2)inter_area = np.maximum(inter_rect_x2 - inter_rect_x1, 0) *\np.maximum(inter_rect_y2 - inter_rect_y1, 0)b1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1)b2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1)iou = inter_area / np.maximum(b1_area + b2_area - inter_area, 1e-6)return ioudef nms(boxes, nms_thres=0.5):result = []#------------------------------------------## 获得预测结果中包含的所有种类#------------------------------------------#unique_labels = np.unique(boxes[:, -1])for c in unique_labels:#------------------------------------------## 获得某一类得分筛选后全部的预测结果#------------------------------------------#detections_class = boxes[boxes[:, -1] == c]# 按照存在物体的置信度排序conf_sort_index = np.argsort(detections_class[:, 4])[::-1]detections_class = detections_class[conf_sort_index]# 进行非极大抑制max_detections = []while len(detections_class) != 0:# 取出这一类置信度最高的,一步一步往下判断,判断重合程度是否大于nms_thres,如果是则去除掉max_detections.append(np.expand_dims(detections_class[0],axis=0))if len(detections_class) == 1:breakious = bbox_iou(max_detections[-1], detections_class[1:])detections_class = detections_class[1:][ious < nms_thres]# 堆叠max_detections = np.concatenate(max_detections)result.append(max_detections)result = np.concatenate(result)# ===================================================================================== ## 案例绘图 ## ===================================================================================== #plt.figure()colors = [(255,0,0),(0,255,0),(0,0,255)]img = np.zeros((616, 616, 3))plt.subplot(121)for i in result:cv2.rectangle(img, (int(i[0]), int(i[1])), (int(i[2]), int(i[3])), colors[int(i[5])], 2)plt.imshow(img)plt.title('nms')plt.subplot(122)img_ = np.zeros((616, 616, 3))for i in boxes:cv2.rectangle(img_, (int(i[0]), int(i[1])), (int(i[2]), int(i[3])), colors[int(i[5])], 2)plt.imshow(img_)plt.title('original')plt.show()# ========================================================================================== #return result # 返回nms后的结果def soft_nms(boxes,conf_thres=0.5,sigma=0.5):result = []#------------------------------------------## 获得预测结果中包含的所有种类#------------------------------------------#unique_labels = np.unique(boxes[:, -1])for c in unique_labels:#------------------------------------------## 获得某一类得分筛选后全部的预测结果#------------------------------------------#detections_class = boxes[boxes[:, -1] == c]# 按照存在物体的置信度排序conf_sort_index = np.argsort(detections_class[:, 4])[::-1]detections_class = detections_class[conf_sort_index]# 进行非极大抑制max_detections = []while len(detections_class) != 0:# 取出这一类置信度最高的,一步一步往下判断,根据iou设置一个对置信度的惩罚因子,去除置信度小于conf_thres的框max_detections.append(np.expand_dims(detections_class[0],axis=0))if len(detections_class) == 1:breakious = bbox_iou(max_detections[-1], detections_class[1:])detections_class[1:, 4] = np.exp(-(ious * ious) / sigma) * detections_class[1:, 4]detections_class = detections_class[1:]detections_class = detections_class[detections_class[:, 4] >= conf_thres]arg_sort = np.argsort(detections_class[:, 4])[::-1]detections_class = detections_class[arg_sort]# 堆叠max_detections = np.concatenate(max_detections)result.append(max_detections)result = np.concatenate(result)# =================================================================================== ## 案例绘图 ## =================================================================================== #plt.figure()colors = [(255,0,0),(0,255,0),(0,0,255)]img = np.zeros((616, 616, 3))plt.subplot(121)for i in result:cv2.rectangle(img, (int(i[0]), int(i[1])), (int(i[2]), int(i[3])), colors[int(i[5])], 2)plt.imshow(img)plt.title('nms')plt.subplot(122)img_ = np.zeros((616, 616, 3))for i in boxes:cv2.rectangle(img_, (int(i[0]), int(i[1])), (int(i[2]), int(i[3])), colors[int(i[5])], 2)plt.imshow(img_)plt.title('original')plt.show()# ========================================================================================== #return result # 返回nms后的结果if __name__ == '__main__':boxes = np.array([[100, 110, 210, 210, 0.71, 0.7, 0.3, 0.5],[250, 250, 420, 420, 0.8, 0.1, 0.8, 0.6],[220, 200, 320, 330, 0.92, 0.2, 0.5, 1.0],[120, 100, 210, 210, 0.72, 0.8, 0.2, 0.3],[230, 240, 325, 330, 0.81, 0.1, 0.9, 0.2],[220, 230, 315, 340, 0.91, 0.2, 0.7, 0.6]])# 转换成[x1,y1,x2,y2,confident,种类下标]boxes = np.hstack((boxes[...,:5],np.expand_dims(np.argmax(boxes[:,5:],axis=-1),axis=-1)))nms(boxes,nms_thres=0.5)soft_nms(boxes,0.5,0.1) #boxes,threshold,λ(gauss函数参数)
效果图:
![](/assets/blank.gif)
三、总结
本文章分别由算法原理、算法实现步骤和代码展示三个部分组成。如果你觉得本章论文对你有帮助,请点个
目标检测中NMS(non maximum suppression)相关推荐
- 目标检测中NMS和mAP指标中的的IoU阈值和置信度阈值
有时候路走的太远,会忘了为什么要出发. 学习亦如是 在目标检测中,经常看到置信度阈值和IoU阈值这两个关键参数,且NMS计算和mAP计算中都会有这两个,那它们的区别是什么?本文就这个问题做一次总结. ...
- 目标检测中NMS(非极大抑制)的概念理解
参考博客 物体检测中常用的几个概念迁移学习.IOU.NMS理解 目标定位和检测系列(3):交并比(IOU)和非极大值抑制(NMS)的python实现 一.NMS(非极大抑制)概念 NMS即non ma ...
- 目标检测中NMS的GPU实现(来自于Faster R-CNN中的nms_kernel.cu文件)
最近要修改Faster R-CNN中实现的GPU版的NMS代码,于是小白的我就看起了CUDA编程,当然也只是浅显地阅读一些教程,快速入门而已,所以具体需要注意的以及一些思想,大家移步此博主的系列教程: ...
- 【目标检测】NMS和soft-NMS详解及代码实现
1. NMS 1.1. NMS概述 非极大值抑制(Non-Maximum Suppression, NMS),顾名思义就是抑制不是极大值的元素,用于目标检测中,就是提取置信度高的目标检测框,而抑制置信 ...
- 综合评价模型的缺点_【必备】目标检测中的评价指标有哪些?
在人工智能领域,机器学习的效果需要用各种指标来评价.当一个目标检测模型建立好了之后,即模型训练已经完成,我们就可以利用这个模型进行分类识别.那么该如何去评价这个模型的性能呢? 上期我们一起学习了全卷积 ...
- 目标检测中的样本不平衡处理方法——OHEM, Focal Loss, GHM, PISA
GitHub 简书 CSDN 文章目录 1. 前言 2. OHEM 3. Focal Loss 3.1 Cross Entropy 3.2 Balanced Cross Entropy 3.3 Foc ...
- 目标检测中的Two-stage的检测算法
比较详细,作个备份 什么是目标检测(object detection): 目标检测(object detection),就是在给定的一张图片中精确找到物体所在的位置,并标注出物体的类别.所以,目标检测 ...
- 目标检测中特征融合技术(YOLO v4)(上)
目标检测中特征融合技术(YOLO v4)(上) 论文链接:https://arxiv.org/abs/1612.03144 Feature Pyramid Networks for Object De ...
- 目标检测中的Tricks
点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 来自 | 知乎 作者 | roger 链接 | https: ...
最新文章
- iOS开发UI篇—UITableview控件基本使用
- OBYC中的GBB一般修改的解释
- android jar 反射,android 第三方jar库 反射得到自己的资源ID
- Visual Studio 2010 SDK
- fastjson 大写转小写 字段_对象转json字符串,属性首字母大写自动变为小写
- 朗沃20140414
- android自定义表格布局
- ANS1编码详解(二)--编码规则
- Google Play 新增付款功能一览表
- 使用dom4j来解析xml文件或xml字符串
- 奇安信Java后端一面
- 判断bug属于前端还是后端
- 该如何选择Java开发和嵌入式开发
- CSS 设置垂直居中
- 今年国庆,我选择给自己充电
- 闪蝶-COBOL代码分析工具
- Au 中英文版本切换批处理文件
- “GANs 之父”Goodfellow亲身传授:深度学习未来的8大方向和入门AI必备的三大技能
- js+css+html制作简易留言板
- centos上开通FTP,真正可用
热门文章
- UESTC 1642 老当益壮, 宁移白首之心? 欧拉回路、Fleury算法
- Linux C/C++ or 嵌入式面试之《多进程多线程编程系列》(4) 进程同步和通信的方式有哪些?
- W3B x Sui Hacker House|深入了解Sui和Move语言
- Day19 网络编程
- python学习笔记之-展平函数ravel和flatten及两者的区别
- 打印机打印为什么显示服务器脱机,打印机显示服务器脱机怎么办
- 二~十进制计数器仿真原理(基于proteus)
- leetcode 467 c语言. Unique Substrings in Wraparound String
- CorePlot_1.5.1 绘制散点图(折线图、曲线图、直方图)
- 第13期5G消息云课堂大咖分享|联动云通信副总裁王鹏