前言

本篇将介绍build_roi_box_head()函数,这个函数是在your_project/maskrcnn_benchmark/modeling/roi_heads/box_head/box_head.py文件中,

def build_roi_box_head(cfg, in_channels):"""Constructs a new box head.By default, uses ROIBoxHead, but if it turns out not to be enough, just register a new classand make it a parameter in the config"""# 主要返回一个ROIBoxHead类对象return ROIBoxHead(cfg, in_channels)

一、ROIBoxHead类

从代码可知,build_roi_box_head()主要是返回一个ROIBoxHead类对象,看来我们主要需要了解的目标就是这个ROIBoxHead类了,这个类也是在your_project/maskrcnn_benchmark/modeling/roi_heads/box_head/box_head.py文件中,我们首先看一下__init__()函数:

class ROIBoxHead(torch.nn.Module):"""Generic Box Head class."""def __init__(self, cfg, in_channels):super(ROIBoxHead, self).__init__()# ROI层中的特征提取器(先进行ROI Align,后续有没有特征提取操作看具体head的方法)# 因为RPN提取的Proposals大小都不太一样,为了使得这些Proposals的图片特征大小一样,# 需要进行ROI Align操作得到大小一样的特征。self.feature_extractor = make_roi_box_feature_extractor(cfg, in_channels)# ROI层中的边框预测类(用于类别的分类和box的回归~)self.predictor = make_roi_box_predictor(cfg, self.feature_extractor.out_channels)# 下面这两个和RPN中的很相像# ROI层中的后处理类(inference过程 进行NMS操作和box解码等操作)self.post_processor = make_roi_box_post_processor(cfg)# 训练过程计算lossself.loss_evaluator = make_roi_box_loss_evaluator(cfg)

接着我们看一下该类的forward()函数,了解该类的一个处理流程:

    def forward(self, features, proposals, targets=None):"""Arguments:features (list[Tensor]): feature-maps from possibly several levelsproposals (list[BoxList]): proposal boxestargets (list[BoxList], optional): the ground-truth targets.Returns:x (Tensor): the result of the feature extractorproposals (list[BoxList]): during training, the subsampled proposalsare returned. During testing, the predicted boxlists are returnedlosses (dict[Tensor]): During training, returns the losses for thehead. During testing, returns an empty dict.x是特征提取器提取的特征proposals分为两种情况:1、在训练阶段,这是采样好用于训练的Proposals2、在测试阶段,这是预测好的boxlistsloss也分两种情况:  1、在训练阶段,这是box_head的模块的loss值。2、在测试阶段,这是一个空的字典。"""if self.training:# Faster R-CNN subsamples during training the proposals with a fixed# positive / negative ratio# 筛选用于训练阶段计算loss的Proposalswith torch.no_grad():proposals = self.loss_evaluator.subsample(proposals, targets)# extract features that will be fed to the final classifier. The# feature_extractor generally corresponds to the pooler + heads# feature_extractor是用来提取特征传输给最终的分类器# feature_extractor是由pooler 层 + heads 组成的x = self.feature_extractor(features, proposals)# final classifier that converts the features into predictions# 分类器进行最后的预测class_logits, box_regression = self.predictor(x)if not self.training:# 如果不是训练阶段,则要对预测的结果进行后处理  最后输出检测结果result = self.post_processor((class_logits, box_regression), proposals)return x, result, {}# 计算box的回归loss和类别的分类lossloss_classifier, loss_box_reg = self.loss_evaluator([class_logits], [box_regression])return (x,proposals,dict(loss_classifier=loss_classifier, loss_box_reg=loss_box_reg),)

由ROIBoxHead类可以看出来,主要涉及了四个函数:

1、make_roi_box_feature_extractor():包含有ROI Align操作,用来提取size一致的特征。

2、make_roi_box_predictor():feature_extractor提取的特征进行类别分类和box的回归。

3、make_roi_box_post_processor():如果是inference过程,通过该函数对预测的结果进行筛选,输出最终的检测结果(RPN的RPNPostProcessor类是不是很相似~)

4、make_roi_box_loss_evaluator():如果是训练过程,通过该函数对预测的结果筛选出正负样本用于计算box_head模块的loss。

接下来我将一一介绍(整体结构简图如下所示):

二、make_roi_box_feature_extractor()

该函数在your_project/maskrcnn_benchmark/modeling/roi_heads/box_head/roi_box_feature_extractors.py文件中,我们来看看相关代码:

def make_roi_box_feature_extractor(cfg, in_channels):# 使用注册器获取该ROI_BOX_FEATURE_EXTRACTORS模块的对象# 对应的ROI_BOX_FEATURE_EXTRACTORS模块都定义在该函数上面func = registry.ROI_BOX_FEATURE_EXTRACTORS[cfg.MODEL.ROI_BOX_HEAD.FEATURE_EXTRACTOR]return func(cfg, in_channels)

使用注册器来获取参数中定义的feature_extractor,我们在该文件中找到了三种feature_extractor,它们分别是:ResNet50Conv5ROIFeatureExtractorFPN2MLPFeatureExtractorFPNXconv1fcFeatureExtractor。如果你自己想要重新定义一个,你也可以按照这些类的结构,重新写一个feature extractor,并给该类注册对应的名称,在参数文件对应位置使用该名称。我接下来对稍微简单一些的FPN2MLPFeatureExtractor类的代码做一个介绍:

# 在注册器中进行注册
@registry.ROI_BOX_FEATURE_EXTRACTORS.register("FPN2MLPFeatureExtractor")
class FPN2MLPFeatureExtractor(nn.Module):"""Heads for FPN for classification"""def __init__(self, cfg, in_channels):super(FPN2MLPFeatureExtractor, self).__init__()# Proposals经过ROI Align之后得到size大小resolution = cfg.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTIONscales = cfg.MODEL.ROI_BOX_HEAD.POOLER_SCALESsampling_ratio = cfg.MODEL.ROI_BOX_HEAD.POOLER_SAMPLING_RATIO# 进行ROI Align操作pooler = Pooler(output_size=(resolution, resolution),scales=scales,sampling_ratio=sampling_ratio,)# ROI Align之后得到的维度input_size = in_channels * resolution ** 2# 全连接层的输出维度representation_size = cfg.MODEL.ROI_BOX_HEAD.MLP_HEAD_DIMuse_gn = cfg.MODEL.ROI_BOX_HEAD.USE_GN# 定义ROI Align的类变量self.pooler = pooler# 定义全连接层的类变量self.fc6 = make_fc(input_size, representation_size, use_gn)self.fc7 = make_fc(representation_size, representation_size, use_gn)# 提取特征之后得到最终的输出维度self.out_channels = representation_size# 进行提取特征操作def forward(self, x, proposals):# ROI Align操作x = self.pooler(x, proposals)# 进行展平 作为全连接层的输出x = x.view(x.size(0), -1)# 进行全连接层操作x = F.relu(self.fc6(x))x = F.relu(self.fc7(x))# 返回提取的特征return x

总的来看feature_extractor的相关代码相对还是比较好懂的。

三、 make_roi_box_predictor()

接下来看用作类别分类判断和box回归的roi_box_predictor()函数,这个函数在your_project/maskrcnn_benchmark/modeling/roi_heads/box_head/roi_box_predictors.py文件中:

def make_roi_box_predictor(cfg, in_channels):func = registry.ROI_BOX_PREDICTOR[cfg.MODEL.ROI_BOX_HEAD.PREDICTOR]return func(cfg, in_channels)

可以看出这个函数和make_roi_box_feature_extractor()函数基本类似,都是通过注册器获取相关所需的对象,因此我们将重点关注到该函数上方被注册的2个predictor类上:FastRCNNPredictor类和FPNPredictor类,下面就以FastRCNNPredictor类的代码为例,做简要的介绍:

# 对特征先进行池化,再使用边框分类器进行分类和边框回归器进行回归
# 首先在注册器ROI_BOX_PREDICTOR上注册该类
@registry.ROI_BOX_PREDICTOR.register("FastRCNNPredictor")
class FastRCNNPredictor(nn.Module):def __init__(self, config, in_channels):super(FastRCNNPredictor, self).__init__()assert in_channels is not None# 输入维度num_inputs = in_channels# 分类的类别数= 类别数 + 1(背景)num_classes = config.MODEL.ROI_BOX_HEAD.NUM_CLASSES# 进行全局平均池化self.avgpool = nn.AdaptiveAvgPool2d(1)# 全连接层用于分类self.cls_score = nn.Linear(num_inputs, num_classes)num_bbox_reg_classes = 2 if config.MODEL.CLS_AGNOSTIC_BBOX_REG else num_classes# 全连接层用于box的坐标回归self.bbox_pred = nn.Linear(num_inputs, num_bbox_reg_classes * 4)# 类别分类参数初始化nn.init.normal_(self.cls_score.weight, mean=0, std=0.01)nn.init.constant_(self.cls_score.bias, 0)# box回归参数初始化nn.init.normal_(self.bbox_pred.weight, mean=0, std=0.001)nn.init.constant_(self.bbox_pred.bias, 0)# 执行过程def forward(self, x):# 平均池化x = self.avgpool(x)# 展平x = x.view(x.size(0), -1)# 使用全连接层进行分类cls_logit = self.cls_score(x)# 使用全连接层进行box回归bbox_pred = self.bbox_pred(x)# 返回结果return cls_logit, bbox_pred

总的来看roi_box_predictor相关代码也是比较好懂的,就是正常pytorch写的网络结构代码。

接下来将会介绍 make_roi_box_post_processor()函数,它是在box_head的inference.py文件中,以及make_roi_box_loss_evaluator()函数,它是在box_head的loss.py中,由于这两个部分的内容有些多,所以决定放到下个博客进行介绍:

maskrcnn-benchmark-master(九):box_head的inference文件

maskrcnn-benchmark-master(十):box_head的loss文件

待续~

maskrcnn-benchmark-master(八):build_roi_box_head()函数相关推荐

  1. Python第八课(函数1)

    Python第八课(函数1)    >>>转到思维导图>>>转到中二青年 函数的返回值 函数内要想返回给调用者值 必须用关键字return 不写return:函数默 ...

  2. excel最常用的八个函数_Excel最常用的几个函数,我都帮你整理好了!

    计算机二级考试中Excel表格经常需要用到函数公式,有些小伙伴经常会把函数公式给混淆. 在备考二级的过程中,我也经常会遇到这种情况:所以,在学习函数公式的过程中,我打算把这些公式都写下来. 我希望自己 ...

  3. JavaScript学习笔记(八)--- 函数表达式

    1.递归 实现一: function factorial(num){if(num<=1){return 1;}else{return num*factorial(num-1);} } alert ...

  4. shell编程入门步步高(八、函数)

    函数 函数是一些命令的集合,作用是让程序模块化. 语法: fuction 函数名() { 命令 } 或者 函数名() { 命令 } 或者 fuction 函数名 { 命令 }

  5. excel最常用的八个函数_Excel中最常用的快捷键

    我们的学习和工作中,Excel都是必不可少的工具.掌握一些常用的Excel快捷键对提高我们的工作效率很有帮助,下面就给大家列举一些在Excel中比较常用的快捷键,各位同学可以在闲暇之余了解一下. Ct ...

  6. vue源码解析pdf_Vue源码全面解析八 callHook函数(触发生命周期函数)

    首先我们打开'src/core/instance/lifecycle.js'文件,代码如下: export function callHook (vm: Component, hook: string ...

  7. Python(八) 函数、模块

    函数 定义函数 1.意义:函数是实现某个功能的一些代码,提高代码的复用性. 2.定义:用def关键字定义函数, 3.函数组成:函数由函数名.形参.函数体.调用函数(里面会有函数体)组成 4.要使用函数 ...

  8. 回调函数例子_Linux C - C基础篇八(函数)

    概念 函数可以被看作是一个由用户定义的一系列操作的集合.一般来说,函数用一个名字来表示.函数的操作数,称为参数,由一个位于括号中,并且用逗号分隔的参数列表指定,如果该函数没有参数需要传入,则这个列表为 ...

  9. C++ Primer Plus学习(八)——函数进阶

    函数进阶 内联函数 引用变量 默认参数 函数重载 函数模板 总结 C++还提供了许多新的函数特性,使之有别于C语言.新特性包括内联函数.按引用传递变量.默认的参数值.函数重载(多态)以及模板函数. 内 ...

最新文章

  1. 如何更改gridview中任意单元格颜色或者内容。
  2. 无法打开文件“opencv_world410d.obj”
  3. retinaface极坐标
  4. Android应用程序的五大基本组件
  5. PDF.NET数据开发框架实体类操作实例(for PostgreSQL,并且解决自增问题)
  6. 操作系统-银行家算法
  7. 使用JavaScript的FormData向SAP ABAP系统发起登录请求
  8. 【渝粤教育】国家开放大学2019年春季 4有机合成单元反应 参考试题
  9. c语言打印删除空格,新人提问:如何将输出时每行最后一个空格删除
  10. 【C语言】矩阵的最大值(指针专题)
  11. 申通快递:1月快递服务业务收入25.65亿元 同比增长21.27%
  12. shell sort 最后一列排序_十个必知的排序算法|Python实例系列[1]
  13. 北京理工大学计算机学院杨晨,杨旭_北京理工大学计算机学院
  14. Java并发编程实战读书笔记5 ---Executor在android中的应用
  15. js 图片浏览插件原生
  16. cocos2d-js 的 cc.callFunc 参数
  17. .pth.tar文件
  18. Qt中关于emit和moc_*.cpp的自动生成
  19. linux目录和文件
  20. 阿里巴巴稀疏模型训练引擎-DeepRec

热门文章

  1. 笔记本什么牌子好?笔记本电脑性价比排行2020
  2. 光纤光信号闪红灯_光猫的光信号灯是闪烁的红灯可能是什么原因造成的?
  3. 湖北理工学院计算机二级成绩查询,湖北理工学院教务管理系统入口http://jwc.hbpu.edu.cn/...
  4. Linux报bus error(总线错误)解决办法
  5. 涂鸦Zigbee SDK开发系列教程——3.快速入门
  6. 什么是User Story
  7. 关于Windows 10内置应用卸载路径
  8. iOS SDK开发步骤
  9. Vue项目打包部署教程及常见错误-前端开发
  10. Lua中获取字符串长度整理