Reposted from: https://www.cnblogs.com/ywheunji/p/10479019.html (will be removed at the original author's request).

Compared with DeepLab v3, DeepLab v3+ adds a decoder module to recover precise object boundaries (the original post illustrates the comparison with a figure).

DeepLab v3+ keeps the multi-scale atrous-convolution structure (ASPP) used in DeepLab v3. The ASPP output is upsampled, concatenated with low-level features from an earlier convolution stage, and then passed through further convolutions and a final upsampling step to produce the segmentation result.
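Why the branches can simply be concatenated: a 3x3 atrous convolution whose padding equals its dilation rate leaves the spatial size unchanged, so all ASPP branches (and the upsampled global-pooling branch) produce maps of identical height and width. A minimal check of this property, using the same channel counts and dilation rates as the code below (the input shape corresponds to a 512x512 image at output stride 16):

import torch
import torch.nn as nn

x = torch.randn(1, 2048, 32, 32)  # encoder output of a 512x512 input at output stride 16
for d in (1, 6, 12, 18):
    # rate 1 uses a 1x1 conv; the other rates use 3x3 atrous convs with padding = dilation
    k, p = (1, 0) if d == 1 else (3, d)
    branch = nn.Conv2d(2048, 256, kernel_size=k, padding=p, dilation=d, bias=False)
    print(d, branch(x).shape)  # every rate prints torch.Size([1, 256, 32, 32])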

DeepLab v3:

With the proposed encoder-decoder structure, the resolution of the encoder features can be freely controlled through atrous convolution, trading accuracy against running time (existing encoder-decoder structures do not offer this ability). For example, at output stride 16 a 512x512 input yields 32x32 encoder features, while output stride 8 yields 64x64.

Atrous convolution can also be used to capture contextual information at multiple scales.

PSPNet, by comparison, pools the feature map at several grid scales to aggregate multi-scale context; a minimal sketch of this pooling idea follows.
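For contrast only, here is a small sketch of that pyramid-pooling idea. The grid sizes (1, 2, 3, 6) follow the PSPNet paper; the channel numbers and input shape are illustrative assumptions, not code from this post:

import torch
import torch.nn as nn
import torch.nn.functional as F

feat = torch.randn(1, 2048, 64, 64)        # assumed backbone feature map
pooled = [feat]
for grid in (1, 2, 3, 6):                  # pool the map down to several grid sizes
    p = F.adaptive_avg_pool2d(feat, grid)              # (1, 2048, grid, grid)
    p = nn.Conv2d(2048, 512, kernel_size=1)(p)         # per-scale channel reduction (random weights here)
    p = F.interpolate(p, size=feat.shape[2:], mode='bilinear', align_corners=True)
    pooled.append(p)
context = torch.cat(pooled, dim=1)         # (1, 2048 + 4*512, 64, 64): multi-scale context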

The DeepLab v3+ implementation below uses ResNet-101 as its backbone.

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo

# SynchronizedBatchNorm2d comes from the sync-batchnorm package bundled with the
# original repository; fall back to the standard BatchNorm2d if it is unavailable.
try:
    from modeling.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
    BatchNorm2d = SynchronizedBatchNorm2d
except ImportError:
    BatchNorm2d = nn.BatchNorm2d


class Bottleneck(nn.Module):
    # basic residual block of the ResNet backbone
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               dilation=dilation, padding=dilation, bias=False)
        self.bn2 = BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):
    # ResNet backbone; strides and dilations are chosen according to the output stride (os)
    def __init__(self, nInputChannels, block, layers, os=16, pretrained=False):
        self.inplanes = 64
        super(ResNet, self).__init__()
        if os == 16:
            strides = [1, 2, 2, 1]
            dilations = [1, 1, 1, 2]
            blocks = [1, 2, 4]
        elif os == 8:
            strides = [1, 2, 1, 1]
            dilations = [1, 1, 2, 2]
            blocks = [1, 2, 1]
        else:
            raise NotImplementedError

        # Modules
        self.conv1 = nn.Conv2d(nInputChannels, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], dilation=dilations[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], dilation=dilations[1])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], dilation=dilations[2])
        self.layer4 = self._make_MG_unit(block, 512, blocks=blocks, stride=strides[3], dilation=dilations[3])

        self._init_weight()

        if pretrained:
            self._load_pretrained_model()

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, dilation, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def _make_MG_unit(self, block, planes, blocks=[1, 2, 4], stride=1, dilation=1):
        # last stage with multi-grid dilation rates
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, dilation=blocks[0]*dilation, downsample=downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, len(blocks)):
            layers.append(block(self.inplanes, planes, stride=1, dilation=blocks[i]*dilation))

        return nn.Sequential(*layers)

    def forward(self, input):
        x = self.conv1(input)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        low_level_feat = x
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return x, low_level_feat

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _load_pretrained_model(self):
        pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/resnet101-5d3b4d8f.pth')
        model_dict = {}
        state_dict = self.state_dict()
        for k, v in pretrain_dict.items():
            if k in state_dict:
                model_dict[k] = v
        state_dict.update(model_dict)
        self.load_state_dict(state_dict)


def ResNet101(nInputChannels=3, os=16, pretrained=False):
    model = ResNet(nInputChannels, Bottleneck, [3, 4, 23, 3], os, pretrained=pretrained)
    return model


class ASPP_module(nn.Module):
    # one branch of the ASPP module: a 1x1 conv for rate 1, otherwise a 3x3 atrous conv
    def __init__(self, inplanes, planes, dilation):
        super(ASPP_module, self).__init__()
        if dilation == 1:
            kernel_size = 1
            padding = 0
        else:
            kernel_size = 3
            padding = dilation
        self.atrous_convolution = nn.Conv2d(inplanes, planes, kernel_size=kernel_size,
                                            stride=1, padding=padding, dilation=dilation, bias=False)
        self.bn = BatchNorm2d(planes)
        self.relu = nn.ReLU()

        self._init_weight()

    def forward(self, x):
        x = self.atrous_convolution(x)
        x = self.bn(x)

        return self.relu(x)

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()


class DeepLabv3_plus(nn.Module):
    # the full DeepLab v3+ architecture: ResNet-101 encoder, ASPP, and a light decoder
    def __init__(self, nInputChannels=3, n_classes=21, os=16, pretrained=False, freeze_bn=False, _print=True):
        if _print:
            print("Constructing DeepLabv3+ model...")
            print("Backbone: Resnet-101")
            print("Number of classes: {}".format(n_classes))
            print("Output stride: {}".format(os))
            print("Number of Input Channels: {}".format(nInputChannels))
        super(DeepLabv3_plus, self).__init__()

        # Atrous-convolution backbone: feature maps extracted by ResNet-101
        self.resnet_features = ResNet101(nInputChannels, os, pretrained=pretrained)

        # ASPP: dilation rates depend on the output stride
        if os == 16:
            dilations = [1, 6, 12, 18]
        elif os == 8:
            dilations = [1, 12, 24, 36]
        else:
            raise NotImplementedError

        # four atrous-convolution branches with different dilation rates (different receptive fields)
        self.aspp1 = ASPP_module(2048, 256, dilation=dilations[0])
        self.aspp2 = ASPP_module(2048, 256, dilation=dilations[1])
        self.aspp3 = ASPP_module(2048, 256, dilation=dilations[2])
        self.aspp4 = ASPP_module(2048, 256, dilation=dilations[3])

        self.relu = nn.ReLU()

        # global average pooling branch
        self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
                                             nn.Conv2d(2048, 256, 1, stride=1, bias=False),
                                             BatchNorm2d(256),
                                             nn.ReLU())

        self.conv1 = nn.Conv2d(1280, 256, 1, bias=False)
        self.bn1 = BatchNorm2d(256)

        # adopt [1x1, 48] for channel reduction.
        self.conv2 = nn.Conv2d(256, 48, 1, bias=False)
        self.bn2 = BatchNorm2d(48)

        # the last 3x3 convolution block of the decoder in the architecture diagram
        self.last_conv = nn.Sequential(nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False),
                                       BatchNorm2d(256),
                                       nn.ReLU(),
                                       nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
                                       BatchNorm2d(256),
                                       nn.ReLU(),
                                       nn.Conv2d(256, n_classes, kernel_size=1, stride=1))
        if freeze_bn:
            self._freeze_bn()

    # forward pass
    def forward(self, input):
        # encoder: backbone features plus the low-level features from layer1
        x, low_level_features = self.resnet_features(input)
        x1 = self.aspp1(x)
        x2 = self.aspp2(x)
        x3 = self.aspp3(x)
        x4 = self.aspp4(x)
        x5 = self.global_avg_pool(x)
        # F.interpolate replaces the deprecated F.upsample
        x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)

        # concatenate the four ASPP branches and the global pooling branch
        x = torch.cat((x1, x2, x3, x4, x5), dim=1)

        # 1x1 conv, then upsample to 1/4 of the input resolution
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = F.interpolate(x, size=(int(math.ceil(input.size()[-2]/4)),
                                   int(math.ceil(input.size()[-1]/4))), mode='bilinear', align_corners=True)

        low_level_features = self.conv2(low_level_features)
        low_level_features = self.bn2(low_level_features)
        low_level_features = self.relu(low_level_features)

        # concatenate with the low-level features, then interpolate back to the input resolution
        x = torch.cat((x, low_level_features), dim=1)
        x = self.last_conv(x)
        x = F.interpolate(x, size=input.size()[2:], mode='bilinear', align_corners=True)

        return x

    def _freeze_bn(self):
        for m in self.modules():
            if isinstance(m, BatchNorm2d):
                m.eval()

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()


def get_1x_lr_params(model):
    """
    This generator returns all the parameters of the net except for
    the last classification layer. Note that for each batchnorm layer,
    requires_grad is set to False in deeplab_resnet.py, therefore this function does not return
    any batchnorm parameter
    """
    b = [model.resnet_features]
    for i in range(len(b)):
        for k in b[i].parameters():
            if k.requires_grad:
                yield k


def get_10x_lr_params(model):
    """
    This generator returns all the parameters for the last layer of the net,
    which does the classification of pixel into classes
    """
    b = [model.aspp1, model.aspp2, model.aspp3, model.aspp4, model.conv1, model.conv2, model.last_conv]
    for j in range(len(b)):
        for k in b[j].parameters():
            if k.requires_grad:
                yield k


if __name__ == "__main__":
    model = DeepLabv3_plus(nInputChannels=3, n_classes=21, os=16, pretrained=True, _print=True)
    model.eval()
    image = torch.randn(1, 3, 512, 512)
    with torch.no_grad():
        output = model(image)
    print(output.size())
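The two generators at the end are meant to feed an optimizer with two parameter groups, so the pretrained backbone trains with a smaller learning rate than the freshly initialized ASPP and decoder layers. A minimal usage sketch, where the concrete learning rate, momentum and weight-decay values are illustrative assumptions rather than settings from the original post:

import torch.optim as optim

model = DeepLabv3_plus(nInputChannels=3, n_classes=21, os=16, pretrained=False, _print=False)
base_lr = 0.007  # assumed base learning rate
train_params = [
    {'params': get_1x_lr_params(model), 'lr': base_lr},        # pretrained ResNet-101 backbone
    {'params': get_10x_lr_params(model), 'lr': base_lr * 10},  # ASPP + decoder, trained from scratch
]
optimizer = optim.SGD(train_params, lr=base_lr, momentum=0.9, weight_decay=5e-4)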
