立体匹配 -- PSM-Net 网络模型代码剖析
- 只熟悉流程跑通代码不重要,重要的是理解网络的思想。
- GC-Net提出了3D-CNN编解码的形式做’cost volum ’ 后处理的过程,PSM-Net 加入图像金字塔的模块结合3D-CNN 输出图像视差图。
一. 特征提取模块
- 作者用 3层 33的小卷积核代替 77 的大卷积核,将图像降维1/2size. 虽然拥有同样大小的感受野,但深层的小 conv filter 显然有更少的参数,降低了计算成本。
def convbn(in_planes, out_planes, kernel_size, stride, pad, dilation):return nn.Sequential(nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=dilation if dilation > 1 else pad, dilation = dilation, bias=False),nn.BatchNorm2d(out_planes))
- 第一层 stride=2是为了降维输入图像 size,第二、三层是为了扩大感受野。
self.firstconv = nn.Sequential(convbn(3, 32, 3, 2, 1, 1),nn.ReLU(inplace=True),convbn(32, 32, 3, 1, 1, 1),nn.ReLU(inplace=True),convbn(32, 32, 3, 1, 1, 1),nn.ReLU(inplace=True))
conv1_x
、conv2_x
、conv3_x
、conv4_x
是提取二值特征的残差层。conv3_x
、conv4_x
使用了空洞卷积增大感受野,输出特征图的size是原图的1/4.
- 是不是又看到了大家熟悉的_make_layer ,别慌!一条条来!
self.layer1 = self._make_layer(BasicBlock, 32, 3, 1,1,1)self.layer2 = self._make_layer(BasicBlock, 64, 16, 2,1,1) self.layer3 = self._make_layer(BasicBlock, 128, 3, 1,1,1)self.layer4 = self._make_layer(BasicBlock, 128, 3, 1,1,2)
def _make_layer(self, block, planes, blocks, stride, pad, dilation):downsample = Noneif stride != 1 or self.inplanes != planes * block.expansion:downsample = nn.Sequential(nn.Conv2d(self.inplanes, planes * block.expansion,kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(planes * block.expansion),)layers = []layers.append(block(self.inplanes, planes, stride, downsample, pad, dilation))self.inplanes = planes * block.expansionfor i in range(1, blocks):layers.append(block(self.inplanes, planes,1,None,pad,dilation))return nn.Sequential(*layers)
conv1_x
残差模块由3个 3332的卷积层构成
self.layer1 = self._make_layer(BasicBlock, 32, 3, 1,1,1)
conv2_x
残差模块由16个 3364的卷积层构成,加深二值特征的深度。32->64.stride=2 使 1/2feature size ->1/4 feature size.
self.layer2 = self._make_layer(BasicBlock, 64, 16, 2,1,1)
conv3_x
、conv4_x
文中的描述是这两层使用了空洞卷积扩大感受野。
注意:我看代码conv3_x
的dilation=1,以为是作者笔误,这里不是的,空洞卷积要连续使用,就是说dilation=n要连接 dilation=n-1…dilation=1,才能发挥不做pooling损失信息的情况下,加大了感受野,让每个卷积输出都包含较大范围的信息,且不改变feature_size的效果
self.layer3 = self._make_layer(BasicBlock, 128, 3, 1,1,1)self.layer4 = self._make_layer(BasicBlock, 128, 3, 1,1,2)
SPP特征金字塔模块
- SPP用以结合全局和局部上下文信息。
SPP模块使用自适应平均池化把特征压缩到4个尺度的平均池化:6464,32 32,1616,88上,并紧跟一个1*1的卷积层来减少特征维度,之后低维度的特征图通过双线性插值的方法进行上采样以恢复到原始图片的尺寸。不同级别的特征图都结合成最终的SPP特征图。 - 说明一下采用特征金字塔结构的好处:扩大感受野,使像素级特征扩展至多尺度区域特征。这样可以结合全局和局部信息,使cost volume更加完善。
- 前三个平均池化与第四个不太一样,先说1.2.3尺度的平均池化
self.branch1 = nn.Sequential(nn.AvgPool2d((64, 64), stride=(64,64)),convbn(128, 32, 1, 1, 0, 1),nn.ReLU(inplace=True))self.branch2 = nn.Sequential(nn.AvgPool2d((32, 32), stride=(32,32)),convbn(128, 32, 1, 1, 0, 1),nn.ReLU(inplace=True))self.branch3 = nn.Sequential(nn.AvgPool2d((16, 16), stride=(16,16)),convbn(128, 32, 1, 1, 0, 1),nn.ReLU(inplace=True))self.branch4 = nn.Sequential(nn.AvgPool2d((8, 8), stride=(8,8)),convbn(128, 32, 1, 1, 0, 1),nn.ReLU(inplace=True))self.lastconv = nn.Sequential(convbn(320, 128, 3, 1, 1, 1),nn.ReLU(inplace=True),nn.Conv2d(128, 32, kernel_size=1, padding=0, stride = 1, bias=False))
- 拿第一层 6464的平均池化举例。6464池化压缩特征,紧跟一个1* 1*32的conv2d卷积改变channel。
def convbn(in_planes, out_planes, kernel_size, stride, pad, dilation):return nn.Sequential(nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=dilation if dilation > 1 else pad, dilation = dilation, bias=False),nn.BatchNorm2d(out_planes))
self.branch1 = nn.Sequential(nn.AvgPool2d((64, 64), stride=(64,64)),convbn(128, 32, 1, 1, 0, 1),nn.ReLU(inplace=True))
- 这个是我打印出的特征维度,第一行是unary features的维度。第二行是 第一个自适应池化层降维后的输出。
CNN-out torch.Size([1, 128, 96, 312])
output_branch1-before torch.Size([1, 32, 1, 4])
- 紧跟其后就是一层上采样,采用双线性插值。
output_branch1 = F.upsample(output_branch1, (output_skip.size()[2],output_skip.size()[3]),mode='bilinear')
- 上采样 输出,因为最后要对不同池化层的输出特征图进行合并,所以要将每层池化的输出上采样到相同的(F,H,W)
output_branch1 torch.Size([1, 32, 96, 312])
- 将上采样以及原始输出按列拼凑,torch.cat(_,1) 这个“1”是按列加的意思。
output_feature = torch.cat((output_raw, output_skip, output_branch4, output_branch3, output_branch2, output_branch1), 1)
- 最后对320个filter size的输出进行降维,采用一层3* 3的conv2d和一层1* 1的conv2d(注意stride和padding)
self.lastconv = nn.Sequential(convbn(320, 128, 3, 1, 1, 1),nn.ReLU(inplace=True),nn.Conv2d(128, 32, kernel_size=1, padding=0, stride = 1, bias=False))
三.形成cost volum
- cost volume的维度是(1,64,48,96,312),64的前32个通道是左视图的,后32个通道是右视图的;48代表视差等级的维度,即视差为0-47px(设置最大视差为191px,后面会上采样)
cost = Variable(torch.FloatTensor(refimg_fea.size()[0], refimg_fea.size()[1]*2, self.maxdisp//4, refimg_fea.size()[2], refimg_fea.size()[3]).zero_()).cuda()for i in range(self.maxdisp//4):if i > 0 :cost[:, :refimg_fea.size()[1], i, :,i:] = refimg_fea[:,:,:,i:]cost[:, refimg_fea.size()[1]:, i, :,i:] = targetimg_fea[:,:,:,:-i]else:cost[:, :refimg_fea.size()[1], i, :,:] = refimg_feacost[:, refimg_fea.size()[1]:, i, :,:] = targetimg_fea
四.3D CNN生成视差图
- 只说编解码结构(stackhourglass),由多个重复的带有中间层监督的由精到粗再由粗到精的过程构成。这个堆叠的沙漏结构有三个主要的沙漏网络,每个都会生成一个视差图。这样三个沙漏结构就会由三个输出和三个损失。训练过程中,总的损失是由三个损失值的加权求和得到的。在测试过程中,最终的视差图是由三个输出中的最后一个得到的。
- 降维过程
- 不知道你们有没有注意到,每个模块的连接处,都有一个降维的部分。
- 而且都是采用两个卷积层。可以自己算下输出尺寸。
self.dres0 = nn.Sequential(convbn_3d(64, 32, 3, 1, 1),nn.ReLU(inplace=True),convbn_3d(32, 32, 3, 1, 1),nn.ReLU(inplace=True))self.dres1 = nn.Sequential(convbn_3d(32, 32, 3, 1, 1),nn.ReLU(inplace=True),convbn_3d(32, 32, 3, 1, 1))
- 注意结合结构图看,这是一个残差结构
cost0 = self.dres0(cost)cost0 = self.dres1(cost0) + cost0
- 编解码结构
先分析第一个hourglass结构。
根据结构图,1/4的size先压缩到1/8的size,
每次压缩都要有一层s=1的卷积链接。
(扩大感受野?为什么一直在扩大感受野,我觉得已经够了,hhhhh…)
self.conv1 = nn.Sequential(convbn_3d(inplanes, inplanes*2, kernel_size=3, stride=2, pad=1),nn.ReLU(inplace=True))self.conv2 = convbn_3d(inplanes*2, inplanes*2, kernel_size=3, stride=1, pad=1)
- 之后继续降维(1/8->1/16)
self.conv3 = nn.Sequential(convbn_3d(inplanes*2, inplanes*2, kernel_size=3, stride=2, pad=1),nn.ReLU(inplace=True))self.conv4 = nn.Sequential(convbn_3d(inplanes*2, inplanes*2, kernel_size=3, stride=1, pad=1),nn.ReLU(inplace=True))
- 和我上一篇博客GC-Net一样,采用编解码的形式,所以这里要上采样解码。这个转置卷积层,只有一层上采样,没有承接。
- 下采样提高速度和增大感受野的同时,也使细节丢失。将高分辨率的特征图与下采样层级联。高分辨率的图像使用转置卷积nn.ConvTranspose3d()得到。
- 1/16->1/8
self.conv5 = nn.Sequential(nn.ConvTranspose3d(inplanes*2, inplanes*2, kernel_size=3, padding=1, output_padding=1, stride=2,bias=False),nn.BatchNorm3d(inplanes*2)) #+conv2
- 1/8->1/4
self.conv6 = nn.Sequential(nn.ConvTranspose3d(inplanes*2, inplanes, kernel_size=3, padding=1, output_padding=1, stride=2,bias=False),nn.BatchNorm3d(inplanes)) #+x
- 仔细对比较那个结构图和这段代码,第一个编解码过程。只输入了一个cost volume.postsqu、presqu都是None.
out1, pre1, post1 = self.dres2(cost0, None, None)
- 注意得到post的过程,第五个卷积层和pre级联 (self.conv5(out)+pre)
def forward(self, x ,presqu, postsqu):out = self.conv1(x) #in:1/4 out:1/8pre = self.conv2(out) #in:1/8 out:1/8if postsqu is not None:pre = F.relu(pre + postsqu, inplace=True)else:pre = F.relu(pre, inplace=True)out = self.conv3(pre) #in:1/8 out:1/16out = self.conv4(out) #in:1/16 out:1/16if presqu is not None:post = F.relu(self.conv5(out)+presqu, inplace=True) #in:1/16 out:1/8else:post = F.relu(self.conv5(out)+pre, inplace=True) out = self.conv6(post) #in:1/8 out:1/4return out, pre, post
- 第一个编解码输出与经两层卷积的cost volume级联,级联高分辨率的特征图,避免丢失特征细节。
out1 = out1+cost0
- 第二个沙漏结构
- 注意到这根红线和绿线,红线是上一个沙漏结构传进来的
post
,绿线是pre
out2, pre2, post2 = self.dres3(out1, pre1, post1)
post
与第二层卷积层的输出级联。pre
与第四层上采样层的输出级联
if postsqu is not None:pre = F.relu(pre + postsqu, inplace=True)
if presqu is not None:post = F.relu(self.conv5(out)+presqu, inplace=True) #in:1/16 out:1/8
- 由于上采样和下采样操作和第一个结构重复,不多赘述。个人理解这么连接和GC-Net思路一样,担心下采样和上采样会丢失特征图细节信息,所以与初始特征图级联。
五.生成视差图
- 三个沙漏结构,三个输出,注意输出的结果间相互级联。
self.classif1 = nn.Sequential(convbn_3d(32, 32, 3, 1, 1),nn.ReLU(inplace=True),nn.Conv3d(32, 1, kernel_size=3, padding=1, stride=1,bias=False))
cost1 = self.classif1(out1)cost2 = self.classif2(out2) + cost1cost3 = self.classif3(out3) + cost2
- 视差回归的方式来估算连续的视差图,根据由
softmax
操作得到预测代价Cd来计算每一个视差值d的可能性。预测视差值d’由每一个视差值*其对应的可能性求和得到。
softmax - 这里有两种输出形式,一种是最后一个沙漏结构的输出最为最终结果(测试过程),一种是输出三个结构的预测结果计算loss(训练过程)。
1.测试过程 - 首先将特征图还原回原图大小。
cost3 = F.upsample(cost3, [self.maxdisp,left.size()[2],left.size()[3]], mode='trilinear')
- 压缩维度,这时候filter,通道已经被压缩成了1了,此时经过聚合的代价体的维度[1,192,384,1248]
cost3 = torch.squeeze(cost3,1)
- 在第二个维度(视差),进行softmax操作,将每个视察的概率计算出来。
pred3 = F.softmax(cost3,dim=1)
- 视差回归,利用的视差回归函数和GC-Net不一样。
pred3 = disparityregression(self.maxdisp)(pred3)
- 返回最后一个视差图
if self.training:return pred1, pred2, pred3else:return pred3
2.训练过程
- 因为训练过程需要计算loss,论文采用的是分别计算每个输出视差图的loss,最后进行加权和,加权和比例分配为(0.5,0.7,1)
- 可以看到论文采用
smooth_l1_loss
做loss函数,disp_true
是网站提供的标签 - 训练阶段返回
output1, output2, output3
分别是第一、二、 三沙漏结构输出的视差图。
if args.model == 'stackhourglass':output1, output2, output3 = model(imgL,imgR)output1 = torch.squeeze(output1,1)output2 = torch.squeeze(output2,1)output3 = torch.squeeze(output3,1)loss = 0.5*F.smooth_l1_loss(output1[mask], disp_true[mask], size_average=True) + 0.7*F.smooth_l1_loss(output2[mask], disp_true[mask], size_average=True) + F.smooth_l1_loss(output3[mask], disp_true[mask], size_average=True)
- 优化器
optimizer = optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999))
立体匹配 -- PSM-Net 网络模型代码剖析相关推荐
- x264代码剖析(一):图文详解x264在Windows平台上的搭建
x264代码剖析(一):图文详解x264在Windows平台上的搭建 X264源码下载地址:http://ftp.videolan.org/pub/videolan/x264/ 平台:win7 PC. ...
- x264代码剖析(十四):核心算法之宏块编码函数x264_macroblock_encode()
x264代码剖析(十四):核心算法之宏块编码函数x264_macroblock_encode() 宏块编码函数x264_macroblock_encode()是完成变换与量化的主要函数,而x264_m ...
- HDFS集中式的缓存管理原理与代码剖析--转载
原文地址:http://yanbohappy.sinaapp.com/?p=468 Hadoop 2.3.0已经发布了,其中最大的亮点就是集中式的缓存管理(HDFS centralized cache ...
- x264代码剖析(四):vs2010编译x264错误集锦
x264代码剖析(四):vs2010编译x264错误集锦 支持VC++平台的x264的最新版本是x264-20091006,接下来就以该版本为例分析编译运行x264过程中遇到的问题以及解决办法. 1. ...
- 微信跳一跳刷分代码剖析
转载地址:http://blog.csdn.net/u013780605/article/details/78945239?ref=myrecommend 感谢学霸提供了这一途径,感谢原作者无私奉献. ...
- K8S网络模型原理剖析和实践-杜军-专题视频课程
K8S网络模型原理剖析和实践-38人已学习 课程介绍 由华为云高级工程师杜军讲解,将为您介绍Kubernetes的网络模型与实现机制,让您对Kubernetes网络有一个系统的理解和认 ...
- IDDPM原理和代码剖析
前言 Improved Denoising Diffusion Probabilistic Models(IDDPM) 是上一篇 Denoising Diffusion Probabilistic M ...
- windows下tomcat8启动脚本代码剖析--catalina.bat
Windows下,Tomcat可以以服务形式启动.停止,也可以执行脚本启动(startup.bat).停止(shutdown.bat).执行startup.bat时会调用catalina.bat,ca ...
- x264代码剖析(二):如何编译运行x264以及x264代码基本框架
x264代码剖析(二):如何编译运行x264以及x264代码基本框架 x264工程在x265出现之前一直在更新,但是自x264-20091007(含)不再支持VC++平台,也就是说支持VC++平台的x ...
- x264代码剖析(九):x264_encoder_encode()函数之x264_slice's'_write()函数
x264代码剖析(九):x264_encoder_encode()函数之x264_slice's'_write()函数 x264_encoder_encode()函数的核心函数就是x264_slice ...
最新文章
- 软件开发的比喻:园艺
- 配置java环境变量
- PooledByteBuf源码分析
- pom文件报错_maven-resources-plugin修改了我的文件
- 美国国会针对中国的网络间谍行动展开辩论
- while read line的问题
- 2021年广西高考成绩查询方法,2021年广西高考成绩查询网站查分网址:https://www.gxeea.cn/...
- char、varchar、nchar、nvarchar的区别
- 新计算机无法 盘启动不了,U盘无法被电脑识别导致制作U盘启动盘失败怎么办?...
- dj鲜生-05-配置-静态目录-模板目录-后台语言时区
- margin系列之负值
- [APIO2017]商旅——分数优化+floyd+SPFA判负环+二分答案
- 股市常胜将军都懂得适时休息
- 加密 数字_数字卢布不会具有BTC这样的加密资产的优势
- 单线程与线程池的性能对比
- 知识图谱-命名实体-关系-免费标注工具-快速打标签-Python3
- WindwosAndroid浏览器内核版本检测
- 录屏状态监听之防录屏 - iOS
- 关于JavaScript的一些使用心得
- Python爬虫:调用百度翻译接口实现中英翻译功能