一、数据集相关代码解读

创建dataloader(damo/dataset/build.py)

在damo/apis/detector_trainer.py的158行,及174-203行中,DAMO-YOLO分别对train_dataloader和val_dataloader进行了创建,并进行了iters_per_epoch的计算,用于后续Iters-based的模型训练。

# dataloader
self.train_loader, self.val_loader, iters = self.get_data_loader(cfg)

build_dataset函数创建数据集类,若为训练模式,且mosaic_mixup为True时,则会使用MosaicWrapper对dataset类进行封装。

    def get_data_loader(self, cfg):train_dataset = build_dataset(cfg,cfg.dataset.train_ann,is_train=True,mosaic_mixup=cfg.train.augment.mosaic_mixup)val_dataset = build_dataset(cfg, cfg.dataset.val_ann, is_train=False)iters_per_epoch = math.ceil(len(train_dataset[0]) /cfg.train.batch_size)  # train_dataset is a list, however,

创建完dataset类后,即可创建dataloader对数据集进行读取。在dataloader创建函数中,作者基于config提供的batch_size、augmentations、total_epochs、num_workers进行相关超参设置。

        train_loader = build_dataloader(train_dataset,cfg.train.augment,batch_size=cfg.train.batch_size,start_epoch=self.start_epoch,total_epochs=cfg.train.total_epochs,num_workers=cfg.miscs.num_workers,is_train=True,size_div=32)val_loader = build_dataloader(val_dataset,cfg.test.augment,batch_size=cfg.test.batch_size,num_workers=cfg.miscs.num_workers,is_train=False,size_div=32)return train_loader, val_loader, iters_per_epoch

创建COCO数据集(damo/dataset/datasets/coco.py)

__init__解读

COCODataset继承于pycocotools库的CocoDetection类。将json标注中的类别id和连续id进行相互映射,保存在json_category_id_to_contiguous_id和contiguous_category_id_to_json_id两个字典里面。

class COCODataset(CocoDetection):def __init__(self, ann_file, root, transforms=None):super(COCODataset, self).__init__(root, ann_file)# sort indices for reproducible resultsself.ids = sorted(self.ids)self.json_category_id_to_contiguous_id = {v: i + 1for i, v in enumerate(self.coco.getCatIds())}self.contiguous_category_id_to_json_id = {v: kfor k, v in self.json_category_id_to_contiguous_id.items()}self.id_to_img_map = {k: v for k, v in enumerate(self.ids)}self._transforms = transforms

__getitem__解读

该函数在damo/apis/detector_trainer.py中的271行enumerate(self.train_loader)中被调用。

    def __getitem__(self, inp):if type(inp) is tuple:idx = inp[1]else:idx = inpimg, anno = super(COCODataset, self).__getitem__(idx)

从json文件中读出标注框、类别、keypoints等信息,对标注框及iscrowd标签的类别进行过滤。

        # filter crowd annotations# TODO might be better to add an extra fieldanno = [obj for obj in anno if obj['iscrowd'] == 0]boxes = [obj['bbox'] for obj in anno]boxes = torch.as_tensor(boxes).reshape(-1, 4)  # guard against no boxestarget = BoxList(boxes, img.size, mode='xywh').convert('xyxy')classes = [obj['category_id'] for obj in anno]classes = [self.json_category_id_to_contiguous_id[c] for c in classes]classes = torch.tensor(classes)target.add_field('labels', classes)if anno and 'keypoints' in anno[0]:keypoints = [obj['keypoints'] for obj in anno]target.add_field('keypoints', keypoints)target = target.clip_to_image(remove_empty=True)

作者将图像从PIL格式转为numpy格式,之后进行通用的数据增强处理,返回图像、标注、图像id。

# PIL to numpy arrayimg = np.asarray(img)  # rgbif self._transforms is not None:img, target = self._transforms(img, target)return img, target, idx

pull_item解读

pull_item函数主要用于mosaic增强时,读取额外的三张图像。读数据的流程__getitem__基本一致。

    def pull_item(self, idx):img, anno = super(COCODataset, self).__getitem__(idx)# filter crowd annotations# TODO might be better to add an extra fieldanno = [obj for obj in anno if obj['iscrowd'] == 0]boxes = [obj['bbox'] for obj in anno]boxes = torch.as_tensor(boxes).reshape(-1, 4)  # guard against no boxestarget = BoxList(boxes, img.size, mode='xywh').convert('xyxy')target = target.clip_to_image(remove_empty=True)classes = [obj['category_id'] for obj in anno]classes = [self.json_category_id_to_contiguous_id[c] for c in classes]

区别在于pull_item中,若标注中存在segmentation信息,作者将会读出用于进行框的refine.

        obj_masks = []for obj in anno:obj_mask = []if 'segmentation' in obj:for mask in obj['segmentation']:obj_mask += maskif len(obj_mask) > 0:obj_masks.append(obj_mask)seg_masks = [np.array(obj_mask, dtype=np.float32).reshape(-1, 2)for obj_mask in obj_masks]res = np.zeros((len(target.bbox), 5))for idx in range(len(target.bbox)):res[idx, 0:4] = target.bbox[idx]res[idx, 4] = classes[idx]

另外的区别为,作者在Mosaic图进行拼接完成后再去做augmentation,因此,pull_item函数中不包含augmentation操作,将image从PIL转为numpy格式后直接返回。

        img = np.asarray(img)  # rgbreturn img, res, seg_masks, idx

使用MosaicWrapper对CocoDataset封装,进行Mosaic、Mixup数据增强(damo/dataset/datasets/mosaic_wrapper.py)

__init__解读

在初始化过程中,会传入待封装的dataset类、输入尺度input_dim、数据增强方式transforms及(degree, scale, shear)等各种超参。

class MosaicWrapper(torch.utils.data.dataset.Dataset):"""Detection dataset wrapper that performs mixup for normal dataset."""def __init__(self,dataset,img_size,mosaic_prob=1.0,mixup_prob=1.0,transforms=None,degrees=10.0,translate=0.1,mosaic_scale=(0.1, 2.0),mixup_scale=(0.5, 1.5),shear=2.0,*args):super().__init__()self._dataset = datasetself.input_dim = img_sizeself._transforms = transformsself.degrees = degreesself.translate = translateself.scale = mosaic_scaleself.shear = shearself.mixup_scale = mixup_scaleself.mosaic_prob = mosaic_probself.mixup_prob = mixup_probself.local_rank = get_rank()

__getitem__解读

若训练时mosaic_mixup为True且使用MosaicWrapper封装了dataset,则damo/apis/detector_trainer.py中的271行enumerate(self.train_loader)将调用这个函数。

首先读出第一张图像作为基础图像。

def __getitem__(self, inp):if type(inp) is tuple:enable_mosaic_mixup = inp[0]idx = inp[1]else:enable_mosaic_mixup = Falseidx = inpimg, labels, segments, img_id = self._dataset.pull_item(idx)

若使用Mosaic数据增强,则基于random.randint随机选出三张其他的图像,将它们拼接为一张大图。四张图分别放在左上、右上、左下、右下四个位置。

    if enable_mosaic_mixup:if random.random() < self.mosaic_prob:mosaic_labels = []mosaic_segments = []input_h, input_w = self.input_dim[0], self.input_dim[1]yc = int(random.uniform(0.5 * input_h, 1.5 * input_h))xc = int(random.uniform(0.5 * input_w, 1.5 * input_w))# 3 additional image indicesindices = [idx] + [random.randint(0,len(self._dataset) - 1) for _ in range(3)]for i_mosaic, index in enumerate(indices):img, _labels, _segments, img_id = self._dataset.pull_item(index)h0, w0 = img.shape[:2]  # orig hwscale = min(1. * input_h / h0, 1. * input_w / w0)img = cv2.resize(img, (int(w0 * scale), int(h0 * scale)),interpolation=cv2.INTER_LINEAR)# generate output mosaic image(h, w, c) = img.shape[:3]if i_mosaic == 0:mosaic_img = np.full((input_h * 2, input_w * 2, c),114,dtype=np.uint8)  # pad 114(l_x1, l_y1, l_x2,l_y2), (s_x1, s_y1, s_x2, s_y2) = get_mosaic_coordinate(mosaic_img, i_mosaic, xc, yc, w, h, input_h, input_w)mosaic_img[l_y1:l_y2, l_x1:l_x2] = img[s_y1:s_y2,s_x1:s_x2]padw, padh = l_x1 - s_x1, l_y1 - s_y1

标签也进行相应的平移及尺度变换,同时若标注时有分割信息,则利用分割的标注信息对框进行再次校正。

                    labels = _labels.copy()# Normalized xywh to pixel xyxy formatif _labels.size > 0:labels[:, 0] = scale * _labels[:, 0] + padwlabels[:, 1] = scale * _labels[:, 1] + padhlabels[:, 2] = scale * _labels[:, 2] + padwlabels[:, 3] = scale * _labels[:, 3] + padhsegments = [xyn2xy(x, scale, padw, padh) for x in _segments]mosaic_segments.extend(segments)mosaic_labels.append(labels)if len(mosaic_labels):mosaic_labels = np.concatenate(mosaic_labels, 0)np.clip(mosaic_labels[:, 0],0,2 * input_w,out=mosaic_labels[:, 0])np.clip(mosaic_labels[:, 1],0,2 * input_h,out=mosaic_labels[:, 1])np.clip(mosaic_labels[:, 2],0,2 * input_w,out=mosaic_labels[:, 2])np.clip(mosaic_labels[:, 3],0,2 * input_h,out=mosaic_labels[:, 3])if len(mosaic_segments):assert input_w == input_hfor x in mosaic_segments:np.clip(x, 0, 2 * input_w,out=x)  # clip when using random_perspective()

之后对图像及标注框进行平移、缩放等仿射变换变化。

                img, labels = random_affine(mosaic_img,mosaic_labels,mosaic_segments,target_size=(input_w, input_h),degrees=self.degrees,translate=self.translate,scales=self.scale,shear=self.shear,)

若mixup_prob不为0且random.random()<mixup_prob,则对图像再次进行mixup数据增强。

            # -----------------------------------------------------------------# CopyPaste: https://arxiv.org/abs/2012.07177# -----------------------------------------------------------------if (not len(labels) == 0 and random.random() < self.mixup_prob):img, labels = self.mixup(img, labels, self.input_dim)

将标注转为BoxList格式后,进行通用数据增强,最后返回图像、标注、图像id。

            # transfer labels to BoxListh_tmp, w_tmp = img.shape[:2]boxes = [label[:4] for label in labels]boxes = torch.as_tensor(boxes).reshape(-1, 4)areas = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])valid_idx = areas > 4target = BoxList(boxes[valid_idx], (w_tmp, h_tmp), mode='xyxy')classes = [label[4] for label in labels]classes = torch.tensor(classes)[valid_idx]target.add_field('labels', classes.long())if self._transforms is not None:img, target = self._transforms(img, target)# -----------------------------------------------------------------# img_info and img_id are not used for training.# They are also hard to be specified on a mosaic image.# -----------------------------------------------------------------return img, target, img_id

将图像封装为ImageList(damo/structures/image_list.py)

在damo/dataset/collate_batch.py中的第16行以及damo/detectors/detector.py中的第54行,to_image_list(x)会将输入的tensor封装为ImageList,并Padding到size_divisible的整数倍。

    elif isinstance(tensors, (tuple, list)):if max_size is None:max_size = tuple(max(s) for s in zip(*[img.shape for img in tensors]))if size_divisible > 0:import mathstride = size_divisiblemax_size = list(max_size)max_size[1] = int(math.ceil(max_size[1] / stride) * stride)max_size[2] = int(math.ceil(max_size[2] / stride) * stride)max_size = tuple(max_size)batch_shape = (len(tensors), ) + max_sizebatched_imgs = tensors[0].new(*batch_shape).zero_()  # + 114for img, pad_img in zip(tensors, batched_imgs):pad_img[:img.shape[0], :img.shape[1], :img.shape[2]].copy_(img)image_sizes = [im.shape[-2:] for im in tensors]pad_sizes = [batched_imgs.shape[-2:] for im in batched_imgs]return ImageList(batched_imgs, image_sizes, pad_sizes)else:raise TypeError('Unsupported type for to_image_list: {}'.format(type(tensors)))

二、模型构建代码解读

detector代码解读(damo/detectors/detector.py)

__init__解读

在damo/apis/detector_trainer.py的115行执行build_local_model(self.cfg, self.device)进行模型构建,detector类作为入口,会在init中会将backbone、neck、head按照config的配置进行结构初始化。

    def __init__(self, config):super().__init__()self.backbone = build_backbone(config.model.backbone)self.neck = build_neck(config.model.neck)self.head = build_head(config.model.head)self.config = config

forward解读

在damo/apis/detector_trainer.py的第285行和第300行,模型进行前向推理,会调用forward函数。forward函数内包含蒸馏训练和普通训练两个分支。若tea为True,则为老师模型的前向推理,直接返回neck的特征。若stu为true,则开启了蒸馏训练,除了返回Head的输出外,还返回neck的输出用于特征蒸馏。

​def forward(self, x, targets=None, tea=False, stu=False):images = to_image_list(x)feature_outs = self.backbone(images.tensors)  # list of tensorfpn_outs = self.neck(feature_outs)if tea:return fpn_outselse:outputs = self.head(fpn_outs,targets,imgs=images,)if stu:return outputs, fpn_outselse:return outputs
模型结构代码(damo/base_models/backbones/, necks/, heads/)
模型结构部分直接参考论文的示意图会更加清晰。MAE-NAS构建部分,直接参考官方给出的NAS教程:https://github.com/alibaba/lightweight-neural-architecture-search/blob/main/scripts/damo-yolo/Tutorial_NAS_for_DAMO-YOLO_cn.md三、loss计算模块(damo/base_models/losses/)
loss在damo/base_models/heads/zero_head.py的111-115行被定义,包括DistributionFocalLoss、QualityFocalLoss以及GIOULoss。self.loss_dfl = DistributionFocalLoss(loss_weight=0.25)self.loss_cls = QualityFocalLoss(use_sigmoid=False,beta=2.0,loss_weight=1.0)self.loss_bbox = GIoULoss(loss_weight=2.0)
在zero_head.py的375-400行,loss计算被调用。loss_qfl为魔改版的Focal Loss,从实现上能看到,依然还是保留分类的向量,但是对应类别位置的置信度的物理含义不再是分类的score,而是改为质量预测的score。loss_dfl以类似交叉熵的形式去优化与标签 y 最接近的一左一右两个位置的概率,从而让网络快速地聚焦到目标位置的邻近区域的分布中去。loss_bbox以常用的GIOU进行loss计算。最后三者loss相加作为总的Loss返回,进行反向传播。loss_qfl = self.loss_cls(cls_scores, (labels, label_scores),avg_factor=num_total_pos)pos_inds = torch.nonzero((labels >= 0) & (labels < self.num_classes),as_tuple=False).squeeze(1)weight_targets = cls_scores.detach()weight_targets = weight_targets.max(dim=1)[0][pos_inds]norm_factor = max(reduce_mean(weight_targets.sum()).item(), 1.0)if len(pos_inds) > 0:loss_bbox = self.loss_bbox(decoded_bboxes[pos_inds],bbox_targets[pos_inds],weight=weight_targets,avg_factor=1.0 * norm_factor,)loss_dfl = self.loss_dfl(bbox_before_softmax[pos_inds].reshape(-1, self.reg_max + 1),dfl_targets[pos_inds].reshape(-1),weight=weight_targets[:, None].expand(-1, 4).reshape(-1),avg_factor=4.0 * norm_factor,)else:loss_bbox = bbox_preds.sum() / norm_factor * 0.0loss_dfl = bbox_preds.sum() / norm_factor * 0.0logger.info(f'No Positive Samples on {bbox_preds.device}! May cause performance decrease. loss_bbox:{loss_bbox:.3f}, loss_dfl:{loss_dfl:.3f}, loss_qfl:{loss_qfl:.3f} ')total_loss = loss_qfl + loss_bbox + loss_dfl
​

四、DAMO-YOLO实操

按照官网给的安装教程在我的Linux环境安装好以后,一条命令就把DAMO-YOLO训练起来啦。

python -m torch.distributed.launch --nproc_per_node=8 tools/train.py -f configs/damoyolo_tinynasL25_S.py

我使用的是v100 16G的机器,花了一天半时间完成了S的训练。跑的是非蒸馏的训练版本,精度和文章汇报的差不多。总的来说还是一个非常不错的工作,期待这个工作的持续更新。

DAMO-YOLO全流程代码解读相关推荐

  1. WGCNA分析 | 全流程代码分享 | 代码二

    – 关于WGNCA的教程,本次的共有三期教程,我们同时做了三个分析的比较,差异性相对还是比较大的,详情可看WGCNA分析 | 你的数据结果真的是准确的吗??,这里面我们只是做了输出图形的比较差异,具体 ...

  2. 【Heritrix基础教程之4】开始一个爬虫抓取的全流程代码分析

    在创建一个job后,就要开始job的运行,运行的全流程如下: 1.在界面上启动job 2.index.jsp 查看上述页面对应的源代码 <a href='"+request.getCo ...

  3. 音频数据建模全流程代码示例:通过讲话人的声音进行年龄预测

    来源:DeepHub IMBA 本文约6100字,建议阅读10+分钟 本文展示了从EDA.音频预处理到特征工程和数据建模的完整源代码演示. 大多数人都熟悉如何在图像.文本或表格数据上运行数据科学项目. ...

  4. 音频数据的建模全流程代码示例:通过讲话人的声音进行年龄预测

    大多数人都熟悉如何在图像.文本或表格数据上运行数据科学项目.但处理音频数据的样例非常的少见.在本文中,将介绍如何在机器学习的帮助下准备.探索和分析音频数据.简而言之:与其他的形式(例如文本或图像)类似 ...

  5. 267019条猫眼数据加持,原来你是这样的《流浪地球》——python数据分析全流程代码实现!

    2019年春节档,<流浪地球>横空出世,在强势口碑加持下,上映两周多票房即突破40亿! 与之相随的主题也霸占了春节期间的话题榜. 作为一部现象级的电影,笔者也很感兴趣,特意爬取了2月5日( ...

  6. 20分钟让你了解OpenGL——OpenGL全流程详细解读

    导语: 对于开发者来说,学习OpenGL或者其他图形API都不是一件容易的事情.即使是一些对OpenGL有一些经验的开发者,往往也未必对OpenGL有完整.全面的理解.市面上的OpenGL文章往往零碎 ...

  7. DGIOT平台实时展示OPC上报数据全流程代码剖析

    [小 迪 导读]:OPC软件作为工业自动化领域应用最广泛的软件,深受工业控制人员的喜爱.但也有许多情况下,OPC软件并不能满足实际的使用需求: 使用场景 1.OPC只在内网运行,希望可以将数据传递至外 ...

  8. 金融风控评分卡建模全流程!

    ↑↑↑关注后"星标"Datawhale 每日干货 & 每月组队学习,不错过 Datawhale干货 作者:桔了个仔,南洋理工大学,数据科学家 知乎丨https://zhua ...

  9. 写一段代码提高内存占用_记录一次生产环境中Redis内存增长异常排查全流程!...

    点击上方 IT牧场 ,选择 置顶或者星标 技术干货每日送达 最近 DBA 反馈线上的一个 Redis 资源已经超过了预先设计时的容量,并且已经进行了两次扩容,内存增长还在持续中,希望业务方排查一下容量 ...

最新文章

  1. 实践自定义UI—RLF...(RelativeLayout LinearLayout FrameLayout....)
  2. python打印星星居中_python中怎么打印星星
  3. SHELL 脚本小技巧
  4. WordPress 多媒体库添加分类和标签支持
  5. Golang语言slice实现原理及使用方法
  6. 导入数据_导入外部数据的三个技巧
  7. libsvm3.22——使用指南
  8. 数据解决方案:原力大数据教你如何撰写数据分析报告
  9. Google 产品速查手册大全
  10. 如何实现 java 接口中的部分方法
  11. opencv继承配库
  12. 分享8个超酷的HTML5相册动画应用
  13. Java 多个文件压缩为一个zip文件
  14. matlab如何编newton-raphson,Matlab中的Newton Raphsons方法?
  15. 去掉Win7快捷方式小箭头
  16. 计算机上键盘无法输入法,为什么键盘打不出字 大家都会用鼠标点击输入法图...
  17. Netty4实战第六章:ChannelHandler
  18. 中央空调系统运行原理以及相关设备介绍
  19. Rockchip HDMI 软件开发指南
  20. 台式计算机idc数据排名,2019年电脑销量排行_IDC:2019年中国PC市场预测销量持续走低...

热门文章

  1. 使用Python连接阿里云盘
  2. codeSourcery 交叉编译环境搭建
  3. 博弈论 斯坦福game theory stanford week 2.0_
  4. 加密U盘专业加密芯片方案
  5. 获取不带后缀名的Excel文件名Python
  6. iOS 自动生成各种尺寸的App Icon 和 Launch Image( App Icon Gear)
  7. mysql数据库从入门到高级
  8. NginxProxyManager实现unraid反向代理
  9. EasyPoi导出Excel实现标记颜色
  10. 成长中的SEO,应该避免这12个过时的优化策略(转载自:https://www.duiji.net)