A Brief Analysis of the DeepLab v3+ Architecture Code--PyTorch Version
Reposted from: https://www.cnblogs.com/ywheunji/p/10479019.html. Will be removed on request.
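Compared with DeepLab v3, DeepLab v3+ adds a decoder module to recover precise object boundaries. The comparison is shown in the figure.
Like DeepLab v3, DeepLab v3+ uses ASPP (Atrous Spatial Pyramid Pooling), a multi-scale structure of parallel atrous convolutions. The ASPP output is then upsampled and concatenated with low-level features from an earlier convolutional layer (layer1 of the backbone in the code below), and a few more convolutions followed by a final upsampling produce the result.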
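To make this data flow concrete, here is a small pure-Python sketch (no PyTorch needed) that tracks only the channel counts through the head. The numbers 2048, 256 and 48 are taken from the DeepLabv3_plus class below; `head_channels` is a hypothetical helper written for illustration, not part of the original code.

```python
def head_channels(backbone_ch=2048, aspp_ch=256, n_aspp_branches=4,
                  low_level_ch=48, n_classes=21):
    """Channel count after each stage of the DeepLab v3+ head (no tensors involved)."""
    # four atrous branches plus one image-level (global average pooling) branch
    aspp_concat = aspp_ch * (n_aspp_branches + 1)
    # a 1x1 conv projects the concatenation back to aspp_ch channels
    encoder_out = aspp_ch
    # the decoder concatenates the upsampled encoder output with the reduced low-level features
    decoder_concat = encoder_out + low_level_ch
    return {
        "backbone": backbone_ch,
        "aspp_concat": aspp_concat,
        "encoder_out": encoder_out,
        "decoder_concat": decoder_concat,
        "logits": n_classes,
    }


print(head_channels())
# {'backbone': 2048, 'aspp_concat': 1280, 'encoder_out': 256, 'decoder_concat': 304, 'logits': 21}
```

The 1280 and 304 values match the input channels hard-coded in `self.conv1` and `self.last_conv`; changing the number of ASPP branches or the width of the low-level projection means updating both of those layers.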
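DeepLab v3:
With the proposed encoder-decoder structure, the resolution of the encoder's output features can be chosen freely by controlling the atrous convolution, trading accuracy off against runtime (existing encoder-decoder structures do not offer this).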
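One way to see how atrous convolution controls the encoder resolution: in the ResNet code below, each stage either downsamples (stride 2) or keeps its resolution and uses dilation instead. This sketch multiplies the strides from the two configurations in `ResNet.__init__` to recover the output stride, and lists the dilations that `_make_MG_unit` assigns to the three blocks of layer4. The names `output_stride`, `layer4_dilations` and `CONFIGS` are illustrative assumptions, not part of the original code (`math.prod` needs Python 3.8+).

```python
from math import prod

# Stem of the ResNet in this post: conv1 (stride 2) followed by maxpool (stride 2).
STEM_STRIDE = 2 * 2

# (per-stage strides, per-stage dilations, multi-grid), copied from ResNet.__init__ below
CONFIGS = {
    16: ([1, 2, 2, 1], [1, 1, 1, 2], [1, 2, 4]),
    8: ([1, 2, 1, 1], [1, 1, 2, 2], [1, 2, 1]),
}


def output_stride(strides):
    """Total downsampling factor of the backbone: stem stride times stage strides."""
    return STEM_STRIDE * prod(strides)


def layer4_dilations(dilations, multi_grid):
    """Dilation of each layer4 block, as built by _make_MG_unit (grid value * stage dilation)."""
    return [g * dilations[3] for g in multi_grid]


for os_, (strides, dilations, grid) in CONFIGS.items():
    print(os_, output_stride(strides), layer4_dilations(dilations, grid))
```

Going from os=16 to os=8 doubles the final feature map on each side (four times as many positions), which is where the accuracy versus runtime trade-off comes from.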
ASPP can be used to capture contextual information at multiple scales.
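The multiple scales come from the dilation rates. A k x k kernel with dilation d spans k + (k - 1)(d - 1) input pixels per axis, so the 3x3 ASPP branches with rates 6, 12 and 18 (os=16) cover windows of 13, 25 and 37 feature-map pixels, while `padding = dilation` (as in `ASPP_module`) keeps, for example, a 32x32 map at 32x32. A pure-Python sketch; the helper names are illustrative:

```python
def effective_kernel(k, d):
    """Number of input pixels (per axis) spanned by a k x k kernel with dilation d."""
    return k + (k - 1) * (d - 1)


def conv_out(size, k, stride=1, padding=0, dilation=1):
    """Output length along one axis, using the nn.Conv2d output-size formula."""
    return (size + 2 * padding - dilation * (k - 1) - 1) // stride + 1


# ASPP rates for os=16, taken from DeepLabv3_plus.__init__ below
for rate in (6, 12, 18):
    span = effective_kernel(3, rate)
    out = conv_out(32, 3, padding=rate, dilation=rate)
    print(f"rate {rate:2d}: spans {span} pixels, 32 -> {out}")
```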
PSPNet pools the feature map at several different scales to aggregate multi-scale context.
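For comparison, PSPNet's pyramid pooling pools one feature map into several grid sizes (bins of 1, 2, 3 and 6 in the PSPNet paper) and then upsamples and concatenates the results. The pure-Python sketch below shows only the pooling step on a toy 6x6 grid. It follows the idea of `nn.AdaptiveAvgPool2d`, which the code in this post uses with a 1x1 output for the image-level ASPP branch, but it is an illustration under those assumptions, not PSPNet's implementation; all names are hypothetical.

```python
def ceil_div(a, b):
    """Integer ceiling of a / b for positive integers."""
    return -(-a // b)


def adaptive_avg_pool2d(grid, out_h, out_w):
    """Average-pool a 2-D list of numbers into an out_h x out_w grid of bins.

    Bin i covers rows floor(i*H/out_h) .. ceil((i+1)*H/out_h) - 1 (same for columns),
    so neighbouring bins may overlap when H is not divisible by out_h.
    """
    h, w = len(grid), len(grid[0])
    pooled = []
    for i in range(out_h):
        r0, r1 = (i * h) // out_h, ceil_div((i + 1) * h, out_h)
        row = []
        for j in range(out_w):
            c0, c1 = (j * w) // out_w, ceil_div((j + 1) * w, out_w)
            cells = [grid[r][c] for r in range(r0, r1) for c in range(c0, c1)]
            row.append(sum(cells) / len(cells))
        pooled.append(row)
    return pooled


def pyramid_pool(grid, bins=(1, 2, 3, 6)):
    """Pool one feature map at several grid sizes (the pooling half of a pyramid pooling module)."""
    return {b: adaptive_avg_pool2d(grid, b, b) for b in bins}


grid = [[r * 6 + c for c in range(6)] for r in range(6)]  # toy 6x6 "feature map"
levels = pyramid_pool(grid)
print(levels[1], levels[2])  # [[17.5]] [[7.0, 10.0], [25.0, 28.0]]
```

The 1x1 level is a global summary of the whole map, while the finer levels keep coarse spatial layout; ASPP reaches a similar multi-scale effect with atrous rates instead of pooling grids.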
This implementation of DeepLab v3+ uses ResNet-101 as its backbone:
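Where the name ResNet-101 comes from, using the block counts passed in `ResNet101()` below: each Bottleneck has three weighted convolutions, and the stem conv1 plus the ImageNet classifier's fc layer add two more (the fc layer is not built in this backbone). Note that layer4 is built by `_make_MG_unit` from the three multi-grid entries, so `layers[3]` is never read; it happens to be 3 as well. The helper names are illustrative.

```python
LAYERS = [3, 4, 23, 3]        # Bottleneck blocks per stage, as passed in ResNet101()
PLANES = [64, 128, 256, 512]  # planes argument of layer1..layer4
EXPANSION = 4                 # Bottleneck.expansion


def resnet_depth(layers, convs_per_block=3):
    """Weighted layers: 3 convs per Bottleneck, plus the stem conv1 and the ImageNet fc layer."""
    return sum(layers) * convs_per_block + 2


def stage_channels(planes, expansion=EXPANSION):
    """Output channels of each stage (planes * expansion)."""
    return [p * expansion for p in planes]


print(resnet_depth(LAYERS))     # 101
print(stage_channels(PLANES))   # [256, 512, 1024, 2048]
```

The 256-channel layer1 output is what `self.conv2` reduces to 48 channels in the decoder, and the 2048-channel layer4 output is the input of every ASPP branch.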
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo

try:
    # synchronized BatchNorm from the original repository (multi-GPU training)
    from modeling.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
    BatchNorm2d = SynchronizedBatchNorm2d
except ImportError:
    # fallback so the file also runs standalone: ordinary BatchNorm
    BatchNorm2d = nn.BatchNorm2d


class Bottleneck(nn.Module):  # basic building block of ResNet
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = BatchNorm2d(planes)
        # padding == dilation keeps the spatial size of a stride-1 3x3 atrous conv
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               dilation=dilation, padding=dilation, bias=False)
        self.bn2 = BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):
    # the ResNet backbone
    def __init__(self, nInputChannels, block, layers, os=16, pretrained=False):
        self.inplanes = 64
        super(ResNet, self).__init__()
        if os == 16:
            strides = [1, 2, 2, 1]
            dilations = [1, 1, 1, 2]
            blocks = [1, 2, 4]
        elif os == 8:
            strides = [1, 2, 1, 1]
            dilations = [1, 1, 2, 2]
            blocks = [1, 2, 1]
        else:
            raise NotImplementedError

        # Modules
        self.conv1 = nn.Conv2d(nInputChannels, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], dilation=dilations[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], dilation=dilations[1])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], dilation=dilations[2])
        # layer4 is the multi-grid unit: one block per entry of `blocks`
        self.layer4 = self._make_MG_unit(block, 512, blocks=blocks, stride=strides[3], dilation=dilations[3])

        self._init_weight()

        if pretrained:
            self._load_pretrained_model()

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, dilation, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            # every block in the stage uses the stage's dilation
            layers.append(block(self.inplanes, planes, dilation=dilation))

        return nn.Sequential(*layers)

    def _make_MG_unit(self, block, planes, blocks=[1, 2, 4], stride=1, dilation=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                BatchNorm2d(planes * block.expansion),
            )

        layers = []
        # multi-grid: block i gets dilation blocks[i] * dilation
        layers.append(block(self.inplanes, planes, stride, dilation=blocks[0] * dilation, downsample=downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, len(blocks)):
            layers.append(block(self.inplanes, planes, stride=1, dilation=blocks[i] * dilation))

        return nn.Sequential(*layers)

    def forward(self, input):
        x = self.conv1(input)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        low_level_feat = x  # stride-4, 256-channel features used by the decoder
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return x, low_level_feat

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _load_pretrained_model(self):
        pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/resnet101-5d3b4d8f.pth')
        model_dict = {}
        state_dict = self.state_dict()
        for k, v in pretrain_dict.items():
            # copy only tensors that exist here with the same shape
            # (skips fc, and conv1 when nInputChannels != 3)
            if k in state_dict and v.shape == state_dict[k].shape:
                model_dict[k] = v
        state_dict.update(model_dict)
        self.load_state_dict(state_dict)


def ResNet101(nInputChannels=3, os=16, pretrained=False):
    model = ResNet(nInputChannels, Bottleneck, [3, 4, 23, 3], os, pretrained=pretrained)
    return model


class ASPP_module(nn.Module):  # one ASPP branch
    def __init__(self, inplanes, planes, dilation):
        super(ASPP_module, self).__init__()
        if dilation == 1:
            # rate 1: plain 1x1 conv
            kernel_size = 1
            padding = 0
        else:
            # 3x3 atrous conv; padding == dilation keeps the spatial size
            kernel_size = 3
            padding = dilation
        self.atrous_convolution = nn.Conv2d(inplanes, planes, kernel_size=kernel_size,
                                            stride=1, padding=padding, dilation=dilation, bias=False)
        self.bn = BatchNorm2d(planes)
        self.relu = nn.ReLU()

        self._init_weight()

    def forward(self, x):
        x = self.atrous_convolution(x)
        x = self.bn(x)

        return self.relu(x)

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()


class DeepLabv3_plus(nn.Module):  # the DeepLab v3+ model itself
    def __init__(self, nInputChannels=3, n_classes=21, os=16, pretrained=False, freeze_bn=False, _print=True):
        if _print:
            print("Constructing DeepLabv3+ model...")
            print("Backbone: Resnet-101")
            print("Number of classes: {}".format(n_classes))
            print("Output stride: {}".format(os))
            print("Number of Input Channels: {}".format(nInputChannels))
        super(DeepLabv3_plus, self).__init__()

        # Atrous conv: first extract the feature maps with ResNet-101
        self.resnet_features = ResNet101(nInputChannels, os, pretrained=pretrained)

        # ASPP: choose the dilation rates for the given output stride
        if os == 16:
            dilations = [1, 6, 12, 18]
        elif os == 8:
            dilations = [1, 12, 24, 36]
        else:
            raise NotImplementedError
        # four atrous conv branches with different rates, i.e. different receptive fields
        self.aspp1 = ASPP_module(2048, 256, dilation=dilations[0])
        self.aspp2 = ASPP_module(2048, 256, dilation=dilations[1])
        self.aspp3 = ASPP_module(2048, 256, dilation=dilations[2])
        self.aspp4 = ASPP_module(2048, 256, dilation=dilations[3])

        self.relu = nn.ReLU()
        # image-level branch: global average pooling followed by a 1x1 conv
        self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
                                             nn.Conv2d(2048, 256, 1, stride=1, bias=False),
                                             BatchNorm2d(256),
                                             nn.ReLU())

        # 1x1 conv fusing the concatenated branches (5 x 256 = 1280 channels)
        self.conv1 = nn.Conv2d(1280, 256, 1, bias=False)
        self.bn1 = BatchNorm2d(256)

        # adopt [1x1, 48] for channel reduction.
        self.conv2 = nn.Conv2d(256, 48, 1, bias=False)
        self.bn2 = BatchNorm2d(48)
        # final 3x3 conv block of the decoder in the architecture diagram (256 + 48 = 304 input channels)
        self.last_conv = nn.Sequential(nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False),
                                       BatchNorm2d(256),
                                       nn.ReLU(),
                                       nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
                                       BatchNorm2d(256),
                                       nn.ReLU(),
                                       nn.Conv2d(256, n_classes, kernel_size=1, stride=1))
        if freeze_bn:
            self._freeze_bn()

    # forward pass
    def forward(self, input):
        x, low_level_features = self.resnet_features(input)
        x1 = self.aspp1(x)
        x2 = self.aspp2(x)
        x3 = self.aspp3(x)
        x4 = self.aspp4(x)
        x5 = self.global_avg_pool(x)
        x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)
        # concatenate the four ASPP branches and the global pooling branch
        x = torch.cat((x1, x2, x3, x4, x5), dim=1)
        # 1x1 projection, then upsample to 1/4 of the input resolution
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = F.interpolate(x, size=(int(math.ceil(input.size()[-2] / 4)),
                                   int(math.ceil(input.size()[-1] / 4))), mode='bilinear', align_corners=True)

        low_level_features = self.conv2(low_level_features)
        low_level_features = self.bn2(low_level_features)
        low_level_features = self.relu(low_level_features)

        # concatenate with the low-level features, then interpolate back to the input size
        x = torch.cat((x, low_level_features), dim=1)
        x = self.last_conv(x)
        x = F.interpolate(x, size=input.size()[2:], mode='bilinear', align_corners=True)

        return x

    def _freeze_bn(self):
        # note: a later model.train() call puts these layers back into training mode
        for m in self.modules():
            if isinstance(m, BatchNorm2d):
                m.eval()

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

def get_1x_lr_params(model):
    """
    This generator returns the parameters of the backbone (model.resnet_features),
    meant for the base learning rate. Parameters with requires_grad=False
    (e.g. frozen layers) are skipped.
    """
    b = [model.resnet_features]
    for i in range(len(b)):
        for k in b[i].parameters():
            if k.requires_grad:
                yield k


def get_10x_lr_params(model):
    """
    This generator returns the parameters of the newly added head (ASPP, image-level
    pooling branch, decoder and the final pixel classifier), meant for 10x the base learning rate
    """
    b = [model.aspp1, model.aspp2, model.aspp3, model.aspp4, model.global_avg_pool,
         model.conv1, model.bn1, model.conv2, model.bn2, model.last_conv]
    for j in range(len(b)):
        for k in b[j].parameters():
            if k.requires_grad:
                yield k


if __name__ == "__main__":
    model = DeepLabv3_plus(nInputChannels=3, n_classes=21, os=16, pretrained=True, _print=True)
    model.eval()  # eval mode: BatchNorm on the 1x1 pooled branch cannot train with batch size 1
    image = torch.randn(1, 3, 512, 512)
    with torch.no_grad():
        output = model(image)
    print(output.size())  # torch.Size([1, 21, 512, 512])
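Why does the forward pass upsample the ASPP output to `math.ceil(H / 4)`? This pure-Python check applies the standard output-size formula of `nn.Conv2d` and `nn.MaxPool2d` (default floor mode, as used here) to one spatial axis through the stem and the four stages. Only the first 3x3 conv of each stage is modeled, since the 1x1 convs and stride-1 3x3 convs with `padding == dilation` keep the size; the helper names are illustrative.

```python
import math


def conv_out(size, k, stride=1, padding=0, dilation=1):
    """One-axis output length of nn.Conv2d / nn.MaxPool2d (default floor mode)."""
    return (size + 2 * padding - dilation * (k - 1) - 1) // stride + 1


def spatial_trace(size, os=16):
    """Spatial size (one axis) after each backbone stage for a square input."""
    strides = {16: [1, 2, 2, 1], 8: [1, 2, 1, 1]}[os]  # from ResNet.__init__
    s = conv_out(size, 7, stride=2, padding=3)  # conv1
    s = conv_out(s, 3, stride=2, padding=1)     # maxpool
    trace = {"stem": s}
    for name, st in zip(["layer1", "layer2", "layer3", "layer4"], strides):
        # the first Bottleneck's 3x3 conv carries the stride; with stride 1 and
        # padding == dilation the size is unchanged, so dilation 1 suffices here
        s = conv_out(s, 3, stride=st, padding=1)
        trace[name] = s
    # size that DeepLabv3_plus.forward upsamples the ASPP output to
    trace["decoder_target"] = math.ceil(size / 4)
    return trace


print(spatial_trace(513))
# {'stem': 129, 'layer1': 129, 'layer2': 65, 'layer3': 33, 'layer4': 33, 'decoder_target': 129}
```

For an odd size such as 513, the layer1 output is 129 = ceil(513 / 4), so the upsampled ASPP features and the low-level features line up for `torch.cat`; rounding down would give 128 and the concatenation would fail.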