目录

前言

一、make_roi_box_loss_evaluator()函数

二、FastRCNNLossComputation

1、__init__()函数

2、match_targets_to_proposals()函数

3、prepare_targets()函数

4、subsample()函数

5、__call__()函数


前言

上一篇博客已经介绍完box_head的inference文件,我们知道了box_head在inference阶段是如何进行筛选box(Proposals),最后得到输出的instances结果,本篇博客将介绍在box_head阶段的loss是如何进行计算的,有了前面RPN的loss文件介绍,box_head的loss文件介绍将会简单很多,它涉及到的函数和RPN的loss文件基本是类似的。

一、make_roi_box_loss_evaluator()函数

box_head的计算loss相关操作在your_project/maskrcnn_benchmark/modeling/roi_heads/box_head/loss.py文件中,我们首先来看看make_roi_box_loss_evaluator()这个函数:

def make_roi_box_loss_evaluator(cfg):# 匹配器 用于给RPN输出给ROI_head部分的Proposals分配真实的标签matcher = Matcher(cfg.MODEL.ROI_HEADS.FG_IOU_THRESHOLD,cfg.MODEL.ROI_HEADS.BG_IOU_THRESHOLD,allow_low_quality_matches=False,)# box的编解码器bbox_reg_weights = cfg.MODEL.ROI_HEADS.BBOX_REG_WEIGHTSbox_coder = BoxCoder(weights=bbox_reg_weights)# 在box_head预测得到的Proposals中筛选正负样本用于训练fg_bg_sampler = BalancedPositiveNegativeSampler(cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE, cfg.MODEL.ROI_HEADS.POSITIVE_FRACTION)# 这个不管它!cls_agnostic_bbox_reg = cfg.MODEL.CLS_AGNOSTIC_BBOX_REG# 损失的计算  用于计算整个box_head部分的lossloss_evaluator = FastRCNNLossComputation(matcher,fg_bg_sampler,box_coder,cls_agnostic_bbox_reg)return loss_evaluator

这部分代码是不是跟RPN中loss文件很相似,不能说相似,简直一模一样~,因为RPN的输出是Proposals,这些Proposals是作为ROI_heads的一个输入,但是在训练阶段,box_head部分将对这些Proposals选一部分用作训练,作为box_head的输入。

从上面代码中可以看出,整个函数主要由三个类的对象构成,这三个类分别是:

Matcher类:这个类主要是给RPN输出的Proposals分配对应类别标签的。
BalancedPositiveNegativeSampler类:用于筛选上述的哪些Proposals可以当作正负样本用于计算loss的过程。
FastRCNNLossComputation:用于给筛选(和inference阶段筛选机制不一样)过后得到的Proposals计算其对应的loss。

因为Match类BalancedPositiveNegativeSampler类已经在:

maskrcnn-benchmark-master(六):RPN的loss文件

介绍过了,所以本篇只着重介绍FastRCNNLossComputation类

二、FastRCNNLossComputation

在FastRCNNLossComputation类中主要包含有五个函数,它们分别是:__init__()函数、match_targets_to_proposals()函数、prepare_targets()函数、subsample()函数、__call__()函数,它们之间的简单调用关系如下图所示:

1、__init__()函数

我们首先看一下__init__()函数:

class FastRCNNLossComputation(object):"""Computes the loss for Faster R-CNN.Also supports FPN对Faster-RCNN部分的loss进行计算"""def __init__(self,proposal_matcher,fg_bg_sampler,box_coder,cls_agnostic_bbox_reg=False):"""Arguments:proposal_matcher (Matcher)fg_bg_sampler (BalancedPositiveNegativeSampler)box_coder (BoxCoder)"""# 定义用于Proposals标签匹配的 匹配器self.proposal_matcher = proposal_matcher# 定义用于正负样本筛选的 筛选器self.fg_bg_sampler = fg_bg_sampler# 定义box的编解码器self.box_coder = box_coderself.cls_agnostic_bbox_reg = cls_agnostic_bbox_reg

2、match_targets_to_proposals()函数

__init__()函数主要是定义相关的类变量,没有什么好介绍的,下面来看一下match_targets_to_proposals()函数:

    def match_targets_to_proposals(self, proposal, target):# gt 和 RPN输出的Proposals之间的 IOU矩阵match_quality_matrix = boxlist_iou(target, proposal)# 预测边框和对应的gt的索引, 背景边框为-2 , 模糊边框为-1 # eg:matched_idxs[4] = 6 :表示第5个预测边框所分配的GT的id为6matched_idxs = self.proposal_matcher(match_quality_matrix)# Fast RCNN only need "labels" field for selecting the targets、# 获得 GT 的类别标签target = target.copy_with_fields("labels")# get the targets corresponding GT for each proposal# NB: need to clamp the indices because we can have a single# GT in the image, and matched_idxs can be -2, which goes# out of bounds# 将所有的背景边框和模糊边框的标签都对应成第一个gt的标签# 其实就是将target中的box 和label按照Proposals的对应顺序重新排序的一个过程,# 将target中box顺序和matched_idxs中的GT的id顺序保持一致matched_targets = target[matched_idxs.clamp(min=0)]# 将对应的列表索引添加至gt列表中matched_targets.add_field("matched_idxs", matched_idxs)return matched_targets

3、prepare_targets()函数

由此我们可以看出match_targets_to_proposals()函数返回的是一个BoxList对象,这个对象中的box是Proposals所对应的GT的box,labels是Proposals所对应GT的label。

接下来我们开看看prepare_targets()函数:

# 准备类别标签和box偏移量标签def prepare_targets(self, proposals, targets):# 类别标签列表labels = []# 回归box标签列表regression_targets = []# 分别对每一张图片进行操作for proposals_per_image, targets_per_image in zip(proposals, targets):matched_targets = self.match_targets_to_proposals(proposals_per_image, targets_per_image)matched_idxs = matched_targets.get_field("matched_idxs")# 获取每一个target所对应的label标签labels_per_image = matched_targets.get_field("labels")labels_per_image = labels_per_image.to(dtype=torch.int64)# Label background (below the low threshold)# 背景标签bg_inds = matched_idxs == Matcher.BELOW_LOW_THRESHOLDlabels_per_image[bg_inds] = 0# Label ignore proposals (between low and high thresholds)# 被忽视的样本ignore_inds = matched_idxs == Matcher.BETWEEN_THRESHOLDSlabels_per_image[ignore_inds] = -1  # -1 is ignored by sampler# compute regression targets# 计算偏移量target  因为网络预测的结果是偏移量,所以需要生成偏移量标签regression_targets_per_image = self.box_coder.encode(matched_targets.bbox, proposals_per_image.bbox)# 对生成好的类别标签和偏移量标签进行保存labels.append(labels_per_image)regression_targets.append(regression_targets_per_image)return labels, regression_targets

4、subsample()函数

上面的prepare_targets()函数就是返回为Proposals匹配好的类别标签和box偏移量标签,接下来将通过subsample()进行正负样本的筛选,我们来看看相关代码:

    def subsample(self, proposals, targets):"""This method performs the positive/negative sampling, and returnthe sampled proposals.Note: this function keeps a state.Arguments:proposals (list[BoxList])targets (list[BoxList])"""# 获取Proposals分配好的标签labels, regression_targets = self.prepare_targets(proposals, targets)# 获取被分配为正负样本的索引  由BalancedPositiveNegativeSampler类进行分配sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)proposals = list(proposals)# add corresponding label and regression_targets information to the bounding boxesfor labels_per_image, regression_targets_per_image, proposals_per_image in zip(labels, regression_targets, proposals):# 给BoxList类型的Proposals添加标签信息proposals_per_image.add_field("labels", labels_per_image)proposals_per_image.add_field("regression_targets", regression_targets_per_image)# distributed sampled proposals, that were obtained on all feature maps# concatenated via the fg_bg_sampler, into individual feature map levels# 对BoxList类型的Proposals进行正负样本筛选(对应的标签也会一并被筛选出来)for img_idx, (pos_inds_img, neg_inds_img) in enumerate(zip(sampled_pos_inds, sampled_neg_inds)):img_sampled_inds = torch.nonzero(pos_inds_img | neg_inds_img).squeeze(1)proposals_per_image = proposals[img_idx][img_sampled_inds]proposals[img_idx] = proposals_per_image# 得到筛选之后的Proposals(BoxList对象 其中包含有label信息)self._proposals = proposalsreturn proposals

5、__call__()函数

通过subsample()筛选得到可以用于训练阶段的Proposals之后(注意这些Proposals都是从RPN输出的Proposals中进行筛选的),就要进行最后loss计算工作了,我们来看一下__call__()函数:

    def __call__(self, class_logits, box_regression):"""Computes the loss for Faster R-CNN.This requires that the subsample method has been called beforehand.Arguments:class_logits (list[Tensor])box_regression (list[Tensor])Returns:classification_loss (Tensor)box_loss (Tensor)"""# 预测的Proposals类别class_logits = cat(class_logits, dim=0)# 预测的Proposals box偏移量box_regression = cat(box_regression, dim=0)device = class_logits.deviceif not hasattr(self, "_proposals"):raise RuntimeError("subsample needs to be called before")# 获取用于box head训练阶段输入的Proposals和它对应的标签proposals = self._proposals# 获取proposals对应的真实类别标签labels = cat([proposal.get_field("labels") for proposal in proposals], dim=0)# 获取proposals对应的真实box 偏移量regression_targets = cat([proposal.get_field("regression_targets") for proposal in proposals], dim=0)# 计算类别分类lossclassification_loss = F.cross_entropy(class_logits, labels)# get indices that correspond to the regression targets for# the corresponding ground truth labels, to be used with# advanced indexing# 不对负样本的box进行回归loss计算  所以选出正样本的索引sampled_pos_inds_subset = torch.nonzero(labels > 0).squeeze(1)labels_pos = labels[sampled_pos_inds_subset]if self.cls_agnostic_bbox_reg:map_inds = torch.tensor([4, 5, 6, 7], device=device)else:map_inds = 4 * labels_pos[:, None] + torch.tensor([0, 1, 2, 3], device=device)# 计算box 偏移量的回归lossbox_loss = smooth_l1_loss(box_regression[sampled_pos_inds_subset[:, None], map_inds],regression_targets[sampled_pos_inds_subset],size_average=False,beta=1,)box_loss = box_loss / labels.numel()return classification_loss, box_loss

至此box_head的loss文件就算介绍完了,总结一下整个过程就是:

1、给每个Proposals匹配对应的类别标签和box标签,进而计算出box偏移量的回归标签。

2、在对这些匹配好标签的Proposals筛选正负样本。(只有提前匹配好标签才知道哪些是正类哪些是负类嘛)

3、通过网络对Proposals的最后的分类结果和box偏移量的回归结果,结合匹配好的标签计算loss。

同时box_head部分算是已经介绍完了,下一次将展开mask_head的介绍,待续~

码字不易  未经许可  请勿随意转载!

maskrcnn-benchmark-master(十):box_head的loss文件相关推荐

  1. ComicEnhancerPro 系列教程十八:JPG文件长度与质量

    作者:马健 邮箱:stronghorse_mj@hotmail.com 主页:http://www.comicer.com/stronghorse/ 发布:2017.07.23 教程十八:JPG文件长 ...

  2. a4如何打印双面小册子_怎样将十几几十页的长文件文档打印成A4纸对折的小册子?...

    概述 工作和生活中总是会收到十几几十页的长文件或者长文档,需要打印出来传阅的时候,或者打印保存的时候,如果直接用A4纸打印出来,厚厚的一叠,占用地方.不便携带.即使双面打印,也不见得很省.不妨试试本文 ...

  3. 《PyInstaller打包实战指南》第十六节 单文件模式打包PyGame

    第十六节 单文件模式打包PyGame 打包示例源码下载: M to the B / Coffee Breakout · GitLab 版本信息: pygame==1.9.6 pyinstaller== ...

  4. 第十二章 Python文件操作【转】

    12.1 open() open()函数作用是打开文件,返回一个文件对象. 用法格式:open(name[, mode[, buffering[,encoding]]]) -> file obj ...

  5. Python的学习(十六):对文件的操作

    Python的学习(十六):对文件的操作 编码格式的介绍 Python中的解释器使用的是Unicode(内存) .py文件在磁盘上使用UTF-8存储(外存) 如何修改文件格式?不写的话默认为UTF-8 ...

  6. linux 初始化文件失败,linux(十)之初始化文件

    前面写了很多linux的知识,其实很多都是命令的,所以要去多多的练习才能学的更好,加油为了好工作. 要么现在懒惰,未来讨饭.要么现在努力,未来惬意. 一.初始化文件概述 1.1.概述 系统初始化文件是 ...

  7. Android开发笔记(七十四)布局文件优化

    include/merge 布局优化中常常用到include/merge标签,include的含义类似C代码中的include,意思是直接把指定布局片段包含进当前的布局文件.include适用于多个布 ...

  8. StalinLocker:一款会在十分钟之后删除文件和数据的勒索软件

    如果你的电脑屏幕上出现了以下界面,千万不要以为是一款什么游戏的广告. 这是Canthink网络安全攻防实验室新发现的一个名为"StalinLocker (斯大林锁屏者)"或&quo ...

  9. 哭瞎!360云盘将关停,你的几十T照片和文件该怎么办

    IDO老徐刚得到了一个非常不开心的消息,360云盘将停止个人云盘服务... 进行业务转型,在网盘存储.传播内容的合法性和安全性得到彻底解决之前不再考虑恢复,之后转型企业云服务. 而且之前共享的所有资料 ...

最新文章

  1. 异常: cv::Exception,位于内存位置 0x00000059E67CE590 处。
  2. python运行系统找不到指定文件_“系统无法找到指定的文件”当调用Python中的subprocess.Popen...
  3. Linux学习之命令【1】
  4. Spring Environment仅用于初始化代码
  5. css3 animation 动画属性简介
  6. 数据库工具一段时间后打开报错:远程过程调用失败0x800706be
  7. Python机器学习:PCA与梯度上升:04求数据的前n个主成分
  8. 音视频开发(42)---H.264 SVC 简介
  9. 第八届“图灵杯”NEUQ-ACM程序设计竞赛个人赛——G题 贪吃的派蒙
  10. Windows 11 即将发布,微软“强推” Edge 浏览器?
  11. php链接没有下划线,html超链接怎么去掉下划线
  12. 3月4日 第1人称相机世界的坐标系,焦距、焦点、调焦、超焦距、视场角、滑动变焦Dolly zooming,相机内参
  13. 利用ArcGIS Python批量拼接遥感影像(arcpy batch processing)
  14. 列表页——基于Django框架的天天生鲜电商网站项目系列博客(九)
  15. 2020-08-05流量计怎么选你学会了么?
  16. 好看的html页脚,Photoshop教程:设计非常漂亮的网页页脚
  17. Skynet服务器框架系列教程,skynet 服务端框架安装/运行
  18. Sensor Flicker (Sensor banding现象)
  19. Emoji-Chat emoji表情包发送及显示兼容web端、移动端
  20. Comet OJ - 2019国庆欢乐赛 G-字符串(后缀数组)

热门文章

  1. 袁琳 湖北 计算机 导师,袁琳
  2. JavaScript代码块(代码段)
  3. 关于台账自动化管理的分享交流会
  4. 互联网商业竞争中的“狼”与“羊”
  5. U盘装PE,U盘安装WIN7 ISO镜像文件
  6. 《中国大数据发展指数报告(2018年)》全文出炉!(附下载)
  7. AI+影像赛道开启,美图在人工智能领域如何「名利双收」?【楚才国科】
  8. 2019贵州大学计算机专业收分,贵州大学分数线2019
  9. 想象力比知识更重要么?提出问题比解决问题更重要?
  10. MySQL——decimal类型长度