CenterNet 数据加载解析
本文主要解读CenterNet如何加载数据,并将标注信息转化为CenterNet规定的高斯分布的形式。
1. YOLOv3和CenterNet流程对比
CenterNet和Anchor-Based的方法不同,以YOLOv3为例,大致梳理一下模型的框架和数据处理流程。
YOLOv3是一个经典的单阶段的目标检测算法,图片进入网络的流程如下:
- 对图片进行resize,长和宽都要是32的倍数。
- 图片经过网络的特征提取后,空间分辨率变为原来的1/32。
- 得到的Tensor去代表图片不同尺度下的目标框,其中目标框的表示为(x,y,w,h,c),分别代表左上角坐标,宽和高,含有某物体的置信度。
- 训练完成后,测试的时候需要使用非极大抑制算法得到最终的目标框。
CenterNet是一个经典的Anchor-Free目标检测方法,图片进入网络流程如下:
- 对图片进行resize,长和宽一般相等,并且至少为4的倍数。
- 图片经过网络的特征提取后,得到的特征图的空间分辨率依然比较大,是原来的1/4。这是因为CenterNet采用的是类似人体姿态估计中用到的骨干网络,基于heatmap提取关键点的方法需要最终的空间分辨率比较大。
- 训练的过程中,CenterNet得到的是一个heatmap,所以标签加载的时候,需要转为类似的heatmap热图。
- 测试的过程中,由于只需要从热图中提取目标,这样就不需要使用NMS,降低了计算量。
2. CenterNet部分详解
设输入图片为 I ∈ R W × H × 3 I\in R^{W\times H\times 3} I∈RW×H×3, W代表图片的宽,H代表高。CenterNet的输出是一个关键点热图heatmap。
Y ^ ∈ [ 0 , 1 ] W R × H R × C \hat{Y}\in[0,1]^{\frac{W}{R}\times\frac{H}{R}\times C} Y^∈[0,1]RW×RH×C
其中R代表输出的stride大小,C代表关键点的类型的个数。
举个例子,在COCO数据集目标检测中,R设置为4,C的值为80,代表80个类别。
如果 Y ^ x , y , c = 1 \hat{Y}_{x,y,c}=1 Y^x,y,c=1代表检测到一个物体,表示对类别c来说,(x,y)这个位置检测到了c类的目标。
既然输出是热图,标签构建的ground truth也必须是热图的形式。标注的内容一般包含(x1,y1,x2,y2,c),目标框左上角坐标、右下角坐标和类别c,按照以下流程转为ground truth:
- 得到原图中对应的中心坐标 p = ( x 1 + x 2 2 , y 1 + y 2 2 ) p=(\frac{x1+x2}{2}, \frac{y1+y2}{2}) p=(2x1+x2,2y1+y2)
- 得到下采样后的feature map中对应的中心坐标 p ~ = ⌊ p R ⌋ \tilde{p}=\lfloor \frac{p}{R}\rfloor p~=⌊Rp⌋, R代表下采样倍数,CenterNet中R为4
- 如果输入图片为512,那么输出的feature map的空间分辨率为[128x128], 将标注的目标框以高斯核的方式将关键点分布到特征图上:
Y x y c = e x p ( − ( x − p ~ x ) 2 + ( y − p ~ y ) 2 2 σ p 2 ) Y_{xyc}=exp(-\frac{(x-\tilde p_x)^2+(y-\tilde p_y)^2}{2\sigma ^2_p}) Yxyc=exp(−2σp2(x−p~x)2+(y−p~y)2)
其中 σ p \sigma_p σp是一个与目标大小相关的标准差(代码中设置的是)。对于特殊情况,相同类别的两个高斯分布发生了重叠,重叠元素间最大的值作为最终元素。下图是知乎用户OLDPAN分享的高斯分布图。
3. 代码部分
datasets/pascal.py 的代码主要从getitem函数入手,以下代码已经做了注释,其中最重要的两个部分一个是如何获取高斯半径(gaussian_radius函数),一个是如何将高斯分布分散到heatmap上(draw_umich_gaussian函数)。
def __getitem__(self, index):img_id = self.images[index]img_path = os.path.join(self.img_dir, self.coco.loadImgs(ids=[img_id])[0]['file_name'])ann_ids = self.coco.getAnnIds(imgIds=[img_id])annotations = self.coco.loadAnns(ids=ann_ids)labels = np.array([self.cat_ids[anno['category_id']]for anno in annotations])bboxes = np.array([anno['bbox']for anno in annotations], dtype=np.float32)if len(bboxes) == 0:bboxes = np.array([[0., 0., 0., 0.]], dtype=np.float32)labels = np.array([[0]])bboxes[:, 2:] += bboxes[:, :2] # xywh to xyxyimg = cv2.imread(img_path)height, width = img.shape[0], img.shape[1]# 获取中心坐标pcenter = np.array([width / 2., height / 2.],dtype=np.float32) # center of imagescale = max(height, width) * 1.0 # 仿射变换flipped = Falseif self.split == 'train':# 随机选择一个尺寸来训练scale = scale * np.random.choice(self.rand_scales)w_border = get_border(128, width)h_border = get_border(128, height)center[0] = np.random.randint(low=w_border, high=width - w_border)center[1] = np.random.randint(low=h_border, high=height - h_border)if np.random.random() < 0.5:flipped = Trueimg = img[:, ::-1, :]center[0] = width - center[0] - 1# 仿射变换trans_img = get_affine_transform(center, scale, 0, [self.img_size['w'], self.img_size['h']])img = cv2.warpAffine(img, trans_img, (self.img_size['w'], self.img_size['h']))# 归一化img = (img.astype(np.float32) / 255.)if self.split == 'train':# 对图片的亮度对比度等属性进行修改color_aug(self.data_rng, img, self.eig_val, self.eig_vec)img -= self.meanimg /= self.stdimg = img.transpose(2, 0, 1) # from [H, W, C] to [C, H, W]# 对Ground Truth heatmap进行仿射变换trans_fmap = get_affine_transform(center, scale, 0, [self.fmap_size['w'], self.fmap_size['h']]) # 这时候已经是下采样为原来的四分之一了# 3个最重要的变量hmap = np.zeros((self.num_classes, self.fmap_size['h'], self.fmap_size['w']), dtype=np.float32) # heatmapw_h_ = np.zeros((self.max_objs, 2), dtype=np.float32) # width and heightregs = np.zeros((self.max_objs, 2), dtype=np.float32) # regression# indexsinds = np.zeros((self.max_objs,), dtype=np.int64)# 具体选择哪些indexind_masks = np.zeros((self.max_objs,), dtype=np.uint8)for k, (bbox, label) in enumerate(zip(bboxes, labels)):if flipped:bbox[[0, 2]] = width - bbox[[2, 0]] - 1# 对检测框也进行仿射变换bbox[:2] = affine_transform(bbox[:2], trans_fmap)bbox[2:] = affine_transform(bbox[2:], trans_fmap)# 防止越界bbox[[0, 2]] = np.clip(bbox[[0, 2]], 0, self.fmap_size['w'] - 1)bbox[[1, 3]] = np.clip(bbox[[1, 3]], 0, self.fmap_size['h'] - 1)# 得到高和宽h, w = bbox[3] - bbox[1], bbox[2] - bbox[0]if h > 0 and w > 0:obj_c = np.array([(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2], dtype=np.float32) # 中心坐标-浮点型obj_c_int = obj_c.astype(np.int32) # 整型的中心坐标# 根据一元二次方程计算出最小的半径radius = max(0, int(gaussian_radius((math.ceil(h), math.ceil(w)), self.gaussian_iou)))# 得到高斯分布draw_umich_gaussian(hmap[label], obj_c_int, radius)w_h_[k] = 1. * w, 1. * h# 记录偏移量regs[k] = obj_c - obj_c_int # discretization error# 当前是obj序列中的第k个 = fmap_w * cy + cx = fmap中的序列数inds[k] = obj_c_int[1] * self.fmap_size['w'] + obj_c_int[0]# 进行mask标记ind_masks[k] = 1return {'image': img, 'hmap': hmap, 'w_h_': w_h_, 'regs': regs, 'inds': inds, 'ind_masks': ind_masks, 'c': center, 's': scale, 'img_id': img_id}
4. heatmap上应用高斯核
heatmap上使用高斯核有很多需要注意的细节。CenterNet官方版本实际上是在CornerNet的基础上改动得到的,有很多祖传代码。
在使用高斯核前要考虑这样一个问题,下图来自于CornerNet论文中的图示,红色的是标注框,但绿色的其实也可以作为最终的检测结果保留下来。那么这个问题可以转化为绿框在红框多大范围以内可以被接受。使用IOU来衡量红框和绿框的贴合程度,当两者IOU>0.7的时候,认为绿框也可以被接受,反之则不被接受。
那么现在问题转化为,如何确定半径r, 让红框和绿框的IOU大于0.7。
以上是三种情况,其中蓝框代表标注框,橙色代表可能满足要求的框。这个问题最终变为了一个一元二次方程有解的问题,同时由于半径必须为正数,所以r的取值就可以通过求根公式获得。
def gaussian_radius(det_size, min_overlap=0.7):# gt框的长和宽height, width = det_sizea1 = 1b1 = (height + width)c1 = width * height * (1 - min_overlap) / (1 + min_overlap)sq1 = np.sqrt(b1 ** 2 - 4 * a1 * c1)r1 = (b1 + sq1) / (2 * a1)a2 = 4b2 = 2 * (height + width)c2 = (1 - min_overlap) * width * heightsq2 = np.sqrt(b2 ** 2 - 4 * a2 * c2)r2 = (b2 + sq2) / (2 * a2)a3 = 4 * min_overlapb3 = -2 * min_overlap * (height + width)c3 = (min_overlap - 1) * width * heightsq3 = np.sqrt(b3 ** 2 - 4 * a3 * c3)r3 = (b3 + sq3) / (2 * a3)return min(r1, r2, r3)
可以看到这里的公式和上图计算的结果是一致的,需要说明的是,CornerNet最开始版本中这里出现了错误,分母不是2a,而是直接设置为2。CenterNet也延续了这个bug,CenterNet作者回应说这个bug对结果的影响不大,但是根据issue的讨论来看,有一些人通过修正这个bug以后,可以让AR提升1-3个百分点。以下是有bug的版本,CornerNet最新版中已经修复了这个bug。
def gaussian_radius(det_size, min_overlap=0.7):height, width = det_sizea1 = 1b1 = (height + width)c1 = width * height * (1 - min_overlap) / (1 + min_overlap)sq1 = np.sqrt(b1 ** 2 - 4 * a1 * c1)r1 = (b1 + sq1) / 2a2 = 4b2 = 2 * (height + width)c2 = (1 - min_overlap) * width * heightsq2 = np.sqrt(b2 ** 2 - 4 * a2 * c2)r2 = (b2 + sq2) / 2a3 = 4 * min_overlapb3 = -2 * min_overlap * (height + width)c3 = (min_overlap - 1) * width * heightsq3 = np.sqrt(b3 ** 2 - 4 * a3 * c3)r3 = (b3 + sq3) / 2return min(r1, r2, r3)
同时有一些人认为圆并不普适,提出了使用椭圆来进行计算,也有人在issue中给出了推导,感兴趣的可以看以下链接:https://github.com/princeton-vl/CornerNet/issues/110
5. 高斯分布添加到heatmap上
def gaussian2D(shape, sigma=1):m, n = [(ss - 1.) / 2. for ss in shape]y, x = np.ogrid[-m:m + 1, -n:n + 1]h = np.exp(-(x * x + y * y) / (2 * sigma * sigma))h[h < np.finfo(h.dtype).eps * h.max()] = 0# 限制最小的值return hdef draw_umich_gaussian(heatmap, center, radius, k=1):# 得到直径diameter = 2 * radius + 1gaussian = gaussian2D((diameter, diameter), sigma=diameter / 6) # sigma是一个与直径相关的参数# 一个圆对应内切正方形的高斯分布x, y = int(center[0]), int(center[1])height, width = heatmap.shape[0:2]# 对边界进行约束,防止越界left, right = min(x, radius), min(width - x, radius + 1)top, bottom = min(y, radius), min(height - y, radius + 1)# 选择对应区域masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right]# 将高斯分布结果约束在边界内masked_gaussian = gaussian[radius - top:radius + bottom, radius - left:radius + right]if min(masked_gaussian.shape) > 0 and min(masked_heatmap.shape) > 0: # TODO debugnp.maximum(masked_heatmap, masked_gaussian * k, out=masked_heatmap)# 将高斯分布覆盖到heatmap上,相当于不断的在heatmap基础上添加关键点的高斯,# 即同一种类型的框会在一个heatmap某一个类别通道上面上面不断添加。# 最终通过函数总体的for循环,相当于不断将目标画到heatmapreturn heatmap
使用matplotlib对gaussian2D进行可视化。
import numpy as np
y,x = np.ogrid[-4:5,-3:4]
sigma = 1
h=np.exp(-(x*x+y*y)/(2*sigma*sigma))
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
fig = plt.figure()
ax = Axes3D(fig)
ax.plot_surface(x,y,h)
plt.show()
6. 参考
[1]https://zhuanlan.zhihu.com/p/66048276
[2]https://www.cnblogs.com/shine-lee/p/9671253.html
[3]https://zhuanlan.zhihu.com/p/96856635
[4]http://xxx.itp.ac.cn/pdf/1808.01244
[5]https://github.com/princeton-vl/CornerNet/issues/110
CenterNet 数据加载解析相关推荐
- dataset__getitem___PyTorch源码解析与实践(1):数据加载Dataset,Sampler与DataLoader
献给学习PyTorch在路上或者计划较深入理解PyTorch的同行者们 写在前面 笔者一直使用tf,大势所趋决定转PyTorch,这个系列就作为我学习PyTorch的笔记与心得. 网络上PyTorch ...
- nuScenes自动驾驶数据集:格式转换,模型的数据加载(二)
文章目录 一.nuScenes数据集格式精解 二.nuScenes数据格式转换(To COCO) 数据格式转换框架 2.1 核心:convert_nuScenes.py解析 其他格式转换文件 2.1. ...
- nuScenes自动驾驶数据集:数据格式精解,格式转换,模型的数据加载 (一)
nuScenes数据集及nuScenes开发工具包简介 文章目录 nuScenes数据集及nuScenes开发工具包简介 1.1. nuScenes数据集简介: 1.2 数据采集: 1.2.1 传感器 ...
- PyTorch数据加载处理
PyTorch数据加载处理 PyTorch提供了许多工具来简化和希望数据加载,使代码更具可读性. 1.下载安装包 • scikit-image:用于图像的IO和变换 • pandas:用于更容易地进行 ...
- hive外部表改为内部表_3000字揭秘Greenplum的外部数据加载——外部表
外部表是greenplum的一种数据表,它与普通表不同的地方是:外部表是用来访问存储在greenplum数据库之外的数据.如普通表一样,可使用SQL对外部表进行查询和插入操作.外部表主要用于Green ...
- python使用matplotlib, seaborn画图时候的数据加载
写在前面的话 当我们使用python来画图的时候,我觉得最难的部分应该是数据加载.因为尽管官网的教程给出了怎么画出某个图片的示例,但是数据往往是随机产生的,这些数据和我们需要的数据往往是不符的.这个时 ...
- Python之pandas数据加载、存储
Python之pandas数据加载.存储 0. 输入与输出大致可分为三类: 0.1 读取文本文件和其他更好效的磁盘存储格式 2.2 使用数据库中的数据 0.3 利用Web API操作网络资源 1. 读 ...
- Hibernate懒加载解析
Hibernate懒加载解析 在Hibernate框架中,当我们要访问的数据量过大时,明显用缓存不太合适, 因为内存容量有限 ,为了减少并发量,减少系统资源的消耗,这时Hibernate用懒加载机制来 ...
- MPP 二、Greenplum数据加载
Loading external data into greenplum database table using different ways... Greenplum 有常规的COPY加载方法,有 ...
最新文章
- VC下提前注入进程的一些方法2——远线程带参数
- 正向最大匹配算法 python代码_中文分词算法之最大正向匹配算法(Python版)
- footer在最低显示
- 334. Increasing Triplet Subsequence
- Spark: sortBy和sortByKey函数详解
- poj 1160 dp
- CCF NOI1044 最近元素
- linux核心设计ebpf,Linux eBPF介绍
- 路由器需要多大内存?
- 独立游戏开发日志:2021年2月12日 改进版反弹跳
- 基于SRP创建自定义渲染管线
- ImageJ的单细胞荧光强度分析
- Linux--用xmanager远程管理的设定过程
- halcon学习拓展系列—修改图片分辨率算子modify_image_size(尺度不缩放)
- 解决默认浏览器被劫持
- 网易CEO丁磊:手机网游是未来新趋势
- ps把白底图片改为透明
- 【ArcGIS教程】(1)带有经纬度的EXCEL数据如何转换为shp矢量数据?
- php where 时间条件,thinkphp5日期时间查询比较和whereTime使用方法
- MacBook进水记