arXiv-2020


文章目录

  • 1 Background and Motivation
  • 2 Related Work
  • 3 Advantages / Contributions
  • 4 GridMask
  • 5 Experiments
    • 5.1 Image Classification
    • 5.2 Object Detection on COCO Dataset
    • 5.3 Semantic Segmentation on Cityscapes
    • 5.4 Expand Grid as Regularization
  • 6 Conclusion(own)

1 Background and Motivation

数据增广方法可以有效的缓解模型的过拟合

现有的数据增广方法可以大致分成如下3类

  • spatial transformation(random scale, crop, flip and random rotation)
  • color distortion( brightness, hue)
  • information dropping(random erasing, cutout,HaS)

好的 information dropping 数据增广方法要 achieve reasonable balance between deletion and reserving of regional information on the images

删太多,把数据变成了噪声

删太少,目标没啥变化,失去了增广的意义


本文,作者提出GridMask,deletes uniformly distributed areas and finally forms a grid shape,在多个任务的公开数据集上效果均有提升

2 Related Work

  • spatial transformation(random scale, crop, flip and random rotation)
  • color distortion( brightness, hue)
  • information dropping(random erasing, cutout,HaS)

3 Advantages / Contributions

提出 GridMask structured data augmentation 方法,在公开的分类、目标检测、分割的benchmark 上比 baseline 好

4 GridMask


作用形式
x ~ = x × M \widetilde{x}= x \times M x =x×M

其中 x ∈ R H × W × C x \in \mathbb{R}^{H \times W \times C} x∈RH×W×C 为 输入图像, x ~ ∈ R H × W × C \widetilde{x} \in \mathbb{R}^{H \times W \times C} x ∈RH×W×C 为增广后的图像, M ∈ { 0 , 1 } H × W M \in \{0,1\}^{H \times W} M∈{0,1}H×W 为 binary mask that stores pixels to be removed,0 的话表示挡住,1 的话表示保留

形成 M M M 的话有 4 个超参数 ( r , d , δ x , δ y ) (r, d, \delta_x, \delta_y) (r,d,δx​,δy​)


1)Choice of r r r

r r r is the ratio of the shorter gray edge in a unit,determines the keep ratio of an input image,值介于 0~1 之间

the keep ratio k k k of a given mask M M M as

k = s u m ( M ) H × W k = \frac{sum(M)}{H \times W} k=H×Wsum(M)​

r r r 和 k k k 的关系是

k = 1 − ( 1 − r ) 2 = 2 r − r 2 k = 1-(1-r)^2 = 2r-r^2 k=1−(1−r)2=2r−r2

r r r 的值小于1, r r r 和 k k k 正相关

k k k 越大,灰色区域越多,遮挡越少
k k k 越小,黑色区域越多,遮挡越多

2)Choice of d d d

d d d is the length of one unit

一个 unit 内(橙色虚线框),灰色区域的长度为 l = r × d l = r \times d l=r×d

d = r a n d o m ( d m i n , d m a x ) d = random(d_{min}, d_{max}) d=random(dmin​,dmax​)


这么画歧义更合适

3)Choice of δ x \delta_x δx​ and δ y \delta_y δy​

δ x \delta_x δx​ and δ y \delta_y δy​ are the distances between the first intact unit and boundary of the image. can shift the mask

δ x ( δ y ) = r a n d o m ( 0 , d − 1 ) \delta_x(\delta_y) = random(0, d-1) δx​(δy​)=random(0,d−1)

4)Statistics of Unsuccessful Cases

99 percent of an object is removed or reserved, we call it a failure case

GridMask has lower chance to yield failure cases than Cutout and HaS

5)The Scheme to Use GridMask

increase the probability of GridMask linearly with the training epochs until an upper bound P is achieved.

中间的概率用 p p p 表示,后续实验中有涉及到

5 Experiments

Datasets

  • ImageNet
  • COCO
  • Cityscapes

5.1 Image Classification

1)ImageNet

比 Cutout 和 HaS 更好,It is because we handle the aforementioned failure cases better

Benefit to CNN

focus on large important regions

2)CIFAR10

Combined with AutoAugment, we achieve SOTA result on these models.

3)Ablation Study

(1)Hyperparameter r r r

r 越大,mask 1 越多,遮挡的越少,说明数据比较复杂

r 越小,mask 1 越少,遮挡的越多,说明数据比较简单

we should keep more information on complex datasets to avoid under-fitting, and delete more on simple datasets to reduce over-fitting

(2)Hyperparameter d d d

the diversity of d can increase robustness of the network

(3)Variations of GridMask

reversed GridMask:keep what we drop in GridMask, and drop what we keep in GridMask


效果不错,也印证了 GridMask 有很好的 balance between deletion and reserving

random GridMask:drop a block in every unit with a certain probability of p u p_u pu​.

p u p_u pu​ 越大,越贴近原始 GridMask

效果不行

5.2 Object Detection on COCO Dataset


不加 GridMask,training epochs 越多,过拟合越严重,加了以后,训练久一点, 精度还有上升空间

5.3 Semantic Segmentation on Cityscapes

5.4 Expand Grid as Regularization

联合 GridMask 和 Mixup,ImageNet 上 SOTA

6 Conclusion(own)

GridMask Data Augmentation


代码实现,考虑了旋转增广,所以 mask 生成的时候是在以原图对角线为边长的情况下生成的,最后取原图区域
https://github.com/dvlab-research/GridMask/blob/master/imagenet_grid/utils/grid.py

import torch
import numpy as np
import math
import PIL.Image as Image
import torchvision.transforms as T
import matplotlib.pyplot as pltclass Grid(object):def __init__(self, d1=96, d2=224, rotate=1, ratio=0.5, mode=1, prob=1.):self.d1 = d1self.d2 = d2self.rotate = rotateself.ratio = ratio # rself.mode = mode # reversed?self.st_prob = self.prob = prob # pdef set_prob(self, epoch, max_epoch):self.prob = self.st_prob * min(1, epoch / max_epoch)def forward(self, img):if np.random.rand() > self.prob:return imgh = img.size(1)w = img.size(2)# 1.5 * h, 1.5 * w works fine with the squared images# But with rectangular input, the mask might not be able to recover back to the input image shape# A square mask with edge length equal to the diagnoal of the input image # will be able to cover all the image spot after the rotation. This is also the minimum square.hh = math.ceil((math.sqrt(h * h + w * w)))d = np.random.randint(self.d1, self.d2)# d = self.d# maybe use ceil? but i guess no big differenceself.l = math.ceil(d * self.ratio)mask = np.ones((hh, hh), np.float32)st_h = np.random.randint(d)  # delta yst_w = np.random.randint(d)  # delta xfor i in range(-1, hh // d + 1):s = d * i + st_ht = s + self.ls = max(min(s, hh), 0)t = max(min(t, hh), 0)mask[s:t, :] *= 0for i in range(-1, hh // d + 1):s = d * i + st_wt = s + self.ls = max(min(s, hh), 0)t = max(min(t, hh), 0)mask[:, s:t] *= 0r = np.random.randint(self.rotate)mask = Image.fromarray(np.uint8(mask))mask = mask.rotate(r)mask = np.asarray(mask)mask = mask[(hh - h) // 2:(hh - h) // 2 + h, (hh - w) // 2:(hh - w) // 2 + w] # 这里结合原理图方便看懂一些mask = torch.from_numpy(mask).float().cuda()if self.mode == 1:mask = 1 - maskmask = mask.expand_as(img)img = img.cuda() * maskreturn imgif __name__ == "__main__":image = Image.open("2.jpg").convert("RGB")tr = T.Compose([T.Resize((224,224)),T.ToTensor()])x = tr(image)gridmask_image = Grid(d1=64, d2=96).forward(x)print(gridmask_image.shape)# print(gridmask_image.shape())fig, axs = plt.subplots(1,2)to_plot = lambda x: x.permute(1,2,0).cpu().numpy()axs[0].imshow(to_plot(x))axs[1].imshow(to_plot(gridmask_image))plt.show()

【GridMask】《GridMask Data Augmentation》相关推荐

  1. 【LeetCode】《剑指Offer》第Ⅴ篇⊰⊰⊰ 39 - 47题

    [LeetCode]<剑指Offer>第Ⅴ篇⊰⊰⊰ 39 - 47题 文章目录 [LeetCode]<剑指Offer>第Ⅴ篇⊰⊰⊰ 39 - 47题 39. 数组中出现次数超过 ...

  2. 【原创】【专栏】《Linux设备驱动程序》--- LDD3源码目录结构和源码分析经典链接

    http://blog.csdn.net/geng823/article/details/37567557 [原创][专栏]<Linux设备驱动程序>--- LDD3源码目录结构和源码分析 ...

  3. 搜索引擎早期重要论文推荐系列【7】《Searching the Web》

    搜索引擎早期重要论文推荐系列[7]<Searching the Web> - pennyliang的专栏 - 博客频道 - CSDN.NET 搜索引擎早期重要论文推荐系列[7]<Se ...

  4. 【LeetCode】《剑指Offer》第Ⅰ篇⊰⊰⊰ 3 - 11题

    [LeetCode]<剑指Offer>第Ⅰ篇⊰⊰⊰ 3 - 11题 文章目录 [LeetCode]<剑指Offer>第Ⅰ篇⊰⊰⊰ 3 - 11题 03. 数组中重复的数字(ea ...

  5. 【Python】《Python语言程序设计》(嵩天 、黄天羽 、礼欣)测验单项选择题答案与解析合辑

    [Python]<Python语言程序设计>(嵩天 .黄天羽 .礼欣)测验单项选择题答案与解析合辑 测验1:Python基本语法元素(第1周) 测验2:Python基本图形绘制(第2周) ...

  6. 【ACM- OJ】《九折?》C++

    [ACM- OJ]<九折?>C++ 题目描述 输入 输出 样例输入 样例输出 提示 AC代码 题目描述 bfs的作业比较简单,所以有的同学会凭借强大的编码能力说作业九折? 出题人为了不让这 ...

  7. 【资源】《动手学数据分析》开源教程完整发布!

    作者:陈安东,湖南大学,Datawhale成员 1. 开源初衷 对于任何一个将来要实际运用的技能,通过实战,自己亲自将一行行代码敲出来,然后达到自己想要的效果,这个过程是最好的学习方式. 最开始接触了 ...

  8. 方舟生存进化mysql_【游戏】《方舟生存进化》怎么联机 搭建服务器联机教程

    想玩下<方舟生存进化>,跟小伙伴们一起 和小伙伴们一起联机打恐龙,是该有多爽!<方舟生存进化>是一款题材十分新颖的沙盒生存游戏,那么方舟生存进化怎么联机?下面为大家介绍< ...

  9. 【WebGL】《WebGL编程指南》读书笔记——第2章

    一.前言 最近看了<WebGL编程指南>这本书,发现还是很有意思的,故每章阅读后做个笔记. 二.正文 Example1:在canvas中绘制2D矩形 <!DOCTYPE html&g ...

最新文章

  1. linux c语言 glibc 错误 munmap,Linux内存分配小结--malloc、brk、mmap
  2. 最穷的日子,你是如何熬过来的?
  3. centeros下安装python
  4. 牛客网练习赛26B(简单的dp)
  5. SpringMVC表单标签
  6. 20144303 《Java程序设计》第一周学习总结
  7. 论文笔记_S2D.35-2017-IROS_利用CNNs联合预测RGB图像的深度、法线和表面曲率
  8. arduino学习系列——DHT11温湿度传感器的使用
  9. 华三服务器管理口地址_各种服务器、存储默认管理IP地址以及用户名密码
  10. C51单片机串口通信之上位机交互
  11. how the sold to party and ship to party determined in IDOC#
  12. 【jdk1.8特性】之Function
  13. 北航超算运行matlab,工信部网:北航学子荣获ASC19世界大学生超算竞赛最高计算性能奖...
  14. CIO如何在企业并购中生存
  15. CodeVS3287[NOIP2013] 货车运输【Kruskal+倍增求LCA】
  16. python初学者游戏开发团队
  17. vue中加载腾讯地图(html形式)
  18. 3年Python编程自学经历,分享一些心得经验
  19. 关于iphone、QQ通讯录、飞聊联系人排序设计的思考
  20. stm32使用SPI对W25Q64--8M字节FLASH的读写

热门文章

  1. 怎么看互联网反垄断下的大厂未来?
  2. 2022-2028年中国祭祀用品行业市场竞争状况及发展趋向分析报告
  3. 百度网盘无法下载怎么办?
  4. h3c端口聚合实现服务器增加带宽,H3C动态链路聚合对接服务器双网卡
  5. 脉冲星 1 月脉动 | Pulsar 2.5.0 和 Pulsarctl 0.3.0 发布,多项活动期待您的参与
  6. java集成腾讯地图并获取用户附近商家
  7. C++ winpcap网络抓包代码实现,以及抓包内容解析。
  8. 修改docker自定义网桥后的默认网卡名称
  9. 《C语言》课程设计——火车票信息管理系统
  10. 2021年全球灌装设备收入大约1194.6百万美元,预计2028年达到1604.7百万美元