实时语义分割网络 BiSeNet(附代码解读)
实时语义分割网络 BiSeNet
- BiSeNet
- Contributions
- BackGround
- BiSeNet 结构
- Loss function
- Experimental Results
- 采用数据集
- 一些实验结果
BiSeNet
Contributions
- 提出了一种包含空间路径(SP)和上下文路径(CP)的双边分割网络(BiSeNet), 将空间信息保存和接受域提供的功能解耦成两条路径。
- 提出了特征融合模块(FFM)和注意细化模块(ARM),以在可接受的成本下进一步提高精度。
- 在cityscape、CamVid和COCO-Stuff的基准上取得了令人印象深刻的成绩。具体来说,在105fps的城市景观测试数据集上,我们得到了68.4%的结果
BackGround
作者对比了当前用于三种用于加速模型的实时语义分割算法:
- 图(a)左侧所示,通过裁剪图片降低尺寸和计算量,但是会丢失大量边界信息和可视化精度。
- 图(a)右侧所示,通过修建/减少卷积过程中的通道数目,提高推理速度。其中的红色方框部分,是作者提及的ENet建议放弃模型的最后阶段(downsample操作),模型的接受域不足以覆盖较大的对象导致的识别能力较差。
- 图形(b)所示为U型的编码,解码结构,通过融合骨干网的细节,u型结构提高了空间分辨率,填补了一些缺失的细节,但是作者认为在u型结构中,一些丢失的空间信息无法轻易回复,不是根本的解决方案。
BiSeNet 结构
Spatial Path
绿色部分表示空间路径,每一层包括一个stride = 2的卷积,接着是批处理归一化和ReLU激活函数,总共三层,故而提取的特征图尺寸是原始图像的1/8。
'''code'''
class SpatialPath(nn.Module):def __init__(self, *args, **kwargs):super(SpatialPath, self).__init__()self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3)self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0)self.init_weight()def forward(self, x):feat = self.conv1(x) # (N, 3, H, W)feat = self.conv2(feat) # (N, 64, H/2, W/2)feat = self.conv3(feat) # (N, 64, H/4, W/4)feat = self.conv_out(feat) # (N, 128, H/8, W/8)return feat
Context Path
第二个虚线框部分是上下文路径,用于提取上下文信息,利用轻量级模型和全局平均池进行下采样。作者在轻量级模型的尾部添加一个全局平均池,提供具有全局上下文信息的最大接收字段, 并且使用U型结构来融合最后两个阶段的特征,这是一种不完整的U型结构。作者使用了Xception作为上下文路径的主干。
'''code'''
class ContextPath(nn.Module):def __init__(self, *args, **kwargs):super(ContextPath, self).__init__()self.resnet = Resnet18()self.arm16 = AttentionRefinementModule(256, 128) # 先看下面的ARM的代码self.arm32 = AttentionRefinementModule(512, 128)self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)self.up32 = nn.Upsample(scale_factor=2.) # 上采样 X2self.up16 = nn.Upsample(scale_factor=2.)self.init_weight()def forward(self, x):feat8, feat16, feat32 = self.resnet(x) '''feat8 : (N, 128, H/8, W/8)feat16 : (N, 256, H/16, W/16)feat32 : (N, 512, H/32, W/32)'''avg = torch.mean(feat32, dim=(2, 3), keepdim=True) # 全局平均池化avg = self.conv_avg(avg) # (N, 128, 1, 1)feat32_arm = self.arm32(feat32) # (N, 128, 1, 1)feat32_sum = feat32_arm + avg # (N, 128, H/32, W/32) feat32_up = self.up32(feat32_sum) # (N, 128, H/16, W/16)feat32_up = self.conv_head32(feat32_up)feat16_arm = self.arm16(feat16) # (N, 256, H/16, W/16)feat16_sum = feat16_arm + feat32_up # (N, 256, H/16, W/16)feat16_up = self.up16(feat16_sum) # (N, 256, H/8, W/8)feat16_up = self.conv_head16(feat16_up)return feat16_up, feat32_up # x8, x16
AttentionRefinementModule
看起来有点像通道注意力机制。
'''ARM Code'''
class AttentionRefinementModule(nn.Module): # ARM def __init__(self, in_chan, out_chan, *args, **kwargs):super(AttentionRefinementModule, self).__init__()self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1) # 一个拥有Conv,BN,RELU的blockself.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size= 1, bias=False)self.bn_atten = BatchNorm2d(out_chan)self.sigmoid_atten = nn.Sigmoid()self.init_weight()def forward(self, x): # 假设输入为(N, 3, 14, 14)feat = self.conv(x) # (N, 64, 14, 14)atten = torch.mean(feat, dim=(2, 3), keepdim=True) # (N, 64, 1, 1)atten = self.conv_atten(atten) # (N, 64, 1, 1)atten = self.bn_atten(atten) # (N, 64, 1, 1)atten = self.sigmoid_atten(atten) # (N, 64, 1, 1)out = torch.mul(feat, atten) # (N, 64, 14, 14)return out
FeatureFusionModule
FFM:作者认为空间路径可以编码丰富的空间信息和细节信息,而上下文路径提供大的接受场,主要对上下文信息进行编码,也就是说空间路径的输出是低水平而上下文路径的输出是高水平的,2条路径的特征在特征表示的层次上是不同的。因此提出了一个特征融合模块用于融合这些特征。作者仿照SENet(通道注意力机制对特征进行重新加权),即特征的选择和组合。
'''code'''
class FeatureFusionModule(nn.Module):def __init__(self, in_chan, out_chan, *args, **kwargs):super(FeatureFusionModule, self).__init__()self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)self.conv1 = nn.Conv2d( out_chan, out_chan//4,kernel_size = 1,stride = 1,padding = 0,bias = False)self.conv2 = nn.Conv2d( out_chan//4,out_chan,kernel_size = 1,stride = 1,padding = 0,bias = False)self.relu = nn.ReLU(inplace=True)self.sigmoid = nn.Sigmoid()self.init_weight()def forward(self, fsp, fcp):fcat = torch.cat([fsp, fcp], dim=1) # concat the feature from Spatial Path and Context Pathfeat = self.convblk(fcat) # (N, C, H, W)atten = torch.mean(feat, dim=(2, 3), keepdim=True) # (N, C, 1, 1)atten = self.conv1(atten) # (N, C/4, 1, 1)atten = self.relu(atten)atten = self.conv2(atten) # (N, C, 1, 1)atten = self.sigmoid(atten)feat_atten = torch.mul(feat, atten) # (N, C, H, W)feat_out = feat_atten + featreturn feat_out
BiSeNet V1
class BiSeNetV1(nn.Module):def __init__(self, n_classes, output_aux=True, *args, **kwargs):super(BiSeNetV1, self).__init__()self.cp = ContextPath()self.sp = SpatialPath()self.ffm = FeatureFusionModule(256, 256)self.conv_out = BiSeNetOutput(256, 256, n_classes, up_factor=8)''' BiSeNetOutput: input: (c, h, w) output:(n_classes, 8h, 8w) '''self.output_aux = output_auxif self.output_aux:self.conv_out16 = BiSeNetOutput(128, 64, n_classes, up_factor=8)self.conv_out32 = BiSeNetOutput(128, 64, n_classes, up_factor=16)self.init_weight()def forward(self, x):H, W = x.size()[2:]feat_cp8, feat_cp16 = self.cp(x) # (N, 256, H/8, W/8), (N, 256, H/16, W/16)feat_sp = self.sp(x) # (N, 256, H/8, W/8)feat_fuse = self.ffm(feat_sp, feat_cp8) # (N, 256, H/8, W/8)feat_out = self.conv_out(feat_fuse) # (N, 19, H, W)if self.output_aux:feat_out16 = self.conv_out16(feat_cp8) # (N, 19, H, W)feat_out32 = self.conv_out32(feat_cp16) # (N, 19, H, W)return feat_out, feat_out16, feat_out32# feat_out = feat_out.argmax(dim=1)return feat_out
Loss function
- 利用主损失函数监督整个双组网的输出。
- 增加了两个具体的辅助损耗函数来监督上下文路径的输出。
- 所有的损失函数均为softmax损失。
- 通过参数α来平衡主损失函数和辅助损失函数的权重。(文中的α等于1。该接头损耗使得优化器在优化模型时更加舒适。)
class OhemCELoss(nn.Module):def __init__(self, thresh, ignore_lb=255):super(OhemCELoss, self).__init__()self.thresh = -torch.log(torch.tensor(thresh, requires_grad=False, dtype=torch.float)) # .cuda()self.ignore_lb = ignore_lbself.criteria = nn.CrossEntropyLoss(ignore_index=ignore_lb, reduction='none')def forward(self, logits, labels):n_min = labels[labels != self.ignore_lb].numel() // 16 loss = self.criteria(logits, labels).view(-1)loss_hard = loss[loss > self.thresh]if loss_hard.numel() < n_min:loss_hard, _ = loss.topk(n_min)return torch.mean(loss_hard)criteria_pre = OhemCELoss(0.7)
criteria_aux = [OhemCELoss(0.7) for _ in range(2)] # to cal the loss of feat_out16 and feature out 32
Experimental Results
采用数据集
- Cityscapes
- CamVid
- COCO-Stuff
一些实验结果
- U型结构的消融实验。
- 空间路径的消融实验: 使用空间路径将性能从66.01%提高到67.42%,下图显示BiSeNet可以比只是用U-Shape获得更详细的空间信息,例如一些交通标志。
- Else:
实时语义分割网络 BiSeNet(附代码解读)相关推荐
- Real_time实时语义分割网络 SegNet, ENet, ICNet, BiSeNet,ShelfNet
1. SegNet 论文地址:A Deep Convolutional Encoder-Decoder Architecture for Image Segmentation 本不应该将segnet作 ...
- 北航、旷视联合,打造最强实时语义分割网络
来源:AI科技评论 编辑:Camel 导语:MSFNet在Cityscapes测试集上达到77.1%mIoU/41FPS(注意是1024*2048),在Camvid测试集上达到75.4 mIoU/97 ...
- 语义分割网络-BiSenet
Sementic Segmentation-BiSenet 语义分割网络-BiSenet(Sementic Segmentation-BiSenet) 介绍 思路来源 关于感受野 关于空间信息 网络框 ...
- 【论文阅读--实时语义分割】BiSeNet V2: Bilateral Network with Guided Aggregation
摘要 低层细节和高层语义对于语义分割任务都是必不可少的.然而,为了加快模型推理的速度,目前的方法几乎总是牺牲低级细节,这导致了相当大的精度下降.我们建议将这些空间细节和分类语义分开处理,以实现高精度和 ...
- CVPR 2020|图网络引导的实时语义分割网络搜索 (GAS)
论文链接:https://arxiv.org/abs/1909.06793 之后代码将会开源:https://github.com/L-Lighter/LightNet 作者:林培文*,孙鹏*,程光亮 ...
- 【语义分割系列】ICNET(实时语义分割)理论以及代码实现
git地址:https://github.com/Tramac/awesome-semantic-segmentation-pytorch 包括: FCN ENet PSPNet ICNet Deep ...
- BiSeNet:用于实时语义分割的双边分割网络-7min精简论文阅读系列-Leon
BiSeNet: Bilateral Segmentation Network for Real-time Semantic Segmentation BiSeNet:用于实时语义分割的双边分割网络 ...
- 实时语义分割算法大盘点
本文转载自计算机视觉工坊 语义分割论文 语义图像分割是计算机视觉中发展最快的领域之一,有着广泛的应用.在许多领域,如机器人和自动驾驶汽车,语义图像分割是至关重要的,因为它提供了必要的上下文,以采取行动 ...
- CFPNet:用于实时语义分割的通道特征金字塔
论文地址:CFPNet: Channel-wise Feature Pyramid for Real-Time Semantic Segmentation 代码地址: https://github.c ...
最新文章
- Linux Kernel TCP/IP Stack — L2 Layer — Linux VLAN device for 802.1.q(虚拟局域网)
- 7-5 顺序存储的二叉树的最近的公共祖先问题(25 分)
- sping拦截器配置顺序影响事务正常运行
- android 根据文件Uri获取图片url
- java instanceof 动态_Java关键字instanceof用法及实现策略
- 初步使用github,并上传下载文件
- python读取大文件csv_实现读取csv文件,文件里面是有限个百分数成绩(99.6、76.8等等...
- 2006最新版个人所得税计算器
- Kubernetes中的nodePort,targetPort,port的区别和意义(转)
- jupyter 设置主题Error:Could not find a version that satisfies the requirement jupyterthemes from version
- $.ajax.submit,jQuery中的AjaxSubmit使用讲解
- web项目中添加图标(unicode引用方式)
- 《自卑与超越》的读后感作文1600字
- 业绩爆发,押注“泛半导体”,TCL押对了吗?
- 贝叶斯决策类条件概率密度估计:最大似然和贝叶斯参数估计
- qt linux 程序设置字体,QT 程序更换字体方法之一
- BUUCTF_Misc题目题解记录
- 学日语小技巧 让Office Word效劳
- 华为防火墙虚拟系统实验
- HI3516DV300笔记(四)修改uboot环境变量