MobileNetV3代码和网络
MobileNetV3实现
参考博客:
https://zhuanlan.zhihu.com/p/323346888
https://blog.csdn.net/Chunfengyanyulove/article/details/91358187
来自paddelocr中的MobileNetV3,实现网络的类如下。
ConvBNLayer
ConvBNLayer类实际上就是实现了一个卷积层+bn层,如名所示。当我们写其他模块需要用到各种卷积层如1x1卷积和depth-wise卷积时,就可以通过传入各种相关参数给它实现我们需要的层。self.conv就是卷积层,通过包括groups在内的参数可以控制卷积层是什么样的。在forward()中,输入x,走过self.conv、self.bn后再根据self.if_act,self.act决定激活函数是什么样子。
class ConvBNLayer(nn.Layer):def __init__(self,in_channels,out_channels,kernel_size,stride,padding,groups=1,if_act=True,act=None):super(ConvBNLayer, self).__init__()self.if_act = if_actself.act = actself.conv = nn.Conv2D(in_channels=in_channels,out_channels=out_channels,kernel_size=kernel_size,stride=stride,padding=padding,groups=groups,bias_attr=False)self.bn = nn.BatchNorm(num_channels=out_channels, act=None)def forward(self, x):x = self.conv(x)x = self.bn(x)if self.if_act:if self.act == "relu":x = F.relu(x)elif self.act == "hardswish":x = F.hardswish(x)else:print("The activation function({}) is selected incorrectly.".format(self.act))exit()return x
ResidualUnit
这里的ResidualUnit即是论文中的bneck模块,是核心模块,也是网络的基本模块。实现了通道分离卷积+SE通道注意力机制+残差结构。
ResidualUnit调用ConvBNLayer类,实现各种需要的层:self.expand_conv, self.bottleneck_conv, self.linear_conv,并且调用SEModule类实现了self.mid_se层
- self.expand_conv:1x1卷积层,用来降低通道数
- self.bottleneck_conv:可分离卷积中的depth-wise卷积
- self.linear_conv:通道分离卷积中的point-wise卷积,1x1卷积
- self.mid_se:通道注意力机制,SE模块
在forward()中,输入x首先通过1x1卷积降低通道数,再进行通道分离的卷积,此时,如果self.if_se=True,将特征进行通道注意力。接着,将特征通过1x1卷积,这是通道分离卷积的第二个步骤,可以改变通道数,并增加通道之间的信息交互。最后,如果self.if_shortcut=True,将此模块的输入特征inputs与目前的特征x相加形成残差结构。
图中1x1卷积看起来是升维了,而且图中是depth-wise卷积后立即进行通道注意力计算,再进行point-wise卷积。
class ResidualUnit(nn.Layer):def __init__(self,in_channels,mid_channels,out_channels,kernel_size,stride,use_se,act=None):super(ResidualUnit, self).__init__()self.if_shortcut = stride == 1 and in_channels == out_channelsself.if_se = use_se# kernel_size=1说明它是用来降维的1x1卷积self.expand_conv = ConvBNLayer(in_channels=in_channels,out_channels=mid_channels,kernel_size=1,stride=1,padding=0,if_act=True,act=act)# groups=mid_channels说明它是通道可分离卷积中的depth-wise卷积self.bottleneck_conv = ConvBNLayer(in_channels=mid_channels,out_channels=mid_channels,kernel_size=kernel_size,stride=stride,padding=int((kernel_size - 1) // 2),groups=mid_channels,if_act=True,act=act)if self.if_se:self.mid_se = SEModule(mid_channels)# 用1x 1卷积实现fc层self.linear_conv = ConvBNLayer(in_channels=mid_channels,out_channels=out_channels,kernel_size=1,stride=1,padding=0,if_act=False,act=None)def forward(self, inputs):# 先对输入特征进行降维,得到mid-channel的特征x = self.expand_conv(inputs)# 深度可分离卷积x = self.bottleneck_conv(x)# 如果if_se=true,加入激励模块if self.if_se:x = self.mid_se(x)x = self.linear_conv(x)if self.if_shortcut:x = paddle.add(inputs, x)return x
SE
图中是 首先进行global average pooling,即squeeze,再用1x1卷积降低通道数,进行relu,再恢复通道数,sigmoid得到factor,最后channel-wise乘以原特征。
class SEModule(nn.Layer):def __init__(self, in_channels, reduction=4):super(SEModule, self).__init__()self.avg_pool = nn.AdaptiveAvgPool2D(1)self.conv1 = nn.Conv2D(in_channels=in_channels,out_channels=in_channels // reduction,kernel_size=1,stride=1,padding=0)self.conv2 = nn.Conv2D(in_channels=in_channels // reduction,out_channels=in_channels,kernel_size=1,stride=1,padding=0)def forward(self, inputs):outputs = self.avg_pool(inputs)outputs = self.conv1(outputs)outputs = F.relu(outputs)outputs = self.conv2(outputs)outputs = F.hardsigmoid(outputs, slope=0.2, offset=0.5)return inputs * outputs
MobileNetV3
输入首先经过一个3x3,stride=2的卷积层conv1,再经过由基础模块bneck组成的block_list,再经过1x1卷积conv2,最后通过kernel_size=2, stride=2的最大池化进行2倍下采样。
large和small的整体结构一致,区别就是基本单元bneck的个数(large15个,small11个)以及内部参数上,主要是通道数目。
self.conv1: stride=2的3x3卷积层,激活函数为hardswish。它的out_channels由make_divisible()得到
block_list: 首先初始化一个为[]的blocklist,根据cfg中的超参,调用ResidualUnit实现基本模块,依次向blocklist添加基本模块
self.conv2: 调用ConvBNLayer,实现1x1卷积
self.pool:最大池化,kernel_size=2, stride=2,得到2倍下采样的特征
class MobileNetV3(nn.Layer):def __init__(self,in_channels=3,model_name='large',scale=0.5,disable_se=False,**kwargs):"""the MobilenetV3 backbone network for detection module.Args:params(dict): the super parameters for build network"""super(MobileNetV3, self).__init__()self.disable_se = disable_seif model_name == "large":cfg = [# k, exp, c, se, nl, s,[3, 16, 16, False, 'relu', 1],[3, 64, 24, False, 'relu', 2],[3, 72, 24, False, 'relu', 1],[5, 72, 40, True, 'relu', 2],[5, 120, 40, True, 'relu', 1],[5, 120, 40, True, 'relu', 1],[3, 240, 80, False, 'hardswish', 2],[3, 200, 80, False, 'hardswish', 1],[3, 184, 80, False, 'hardswish', 1],[3, 184, 80, False, 'hardswish', 1],[3, 480, 112, True, 'hardswish', 1],[3, 672, 112, True, 'hardswish', 1],[5, 672, 160, True, 'hardswish', 2],[5, 960, 160, True, 'hardswish', 1],[5, 960, 160, True, 'hardswish', 1],]cls_ch_squeeze = 960elif model_name == "small":cfg = [# k, exp, c, se, nl, s,[3, 16, 16, True, 'relu', 2],[3, 72, 24, False, 'relu', 2],[3, 88, 24, False, 'relu', 1],[5, 96, 40, True, 'hardswish', 2],[5, 240, 40, True, 'hardswish', 1],[5, 240, 40, True, 'hardswish', 1],[5, 120, 48, True, 'hardswish', 1],[5, 144, 48, True, 'hardswish', 1],[5, 288, 96, True, 'hardswish', 2],[5, 576, 96, True, 'hardswish', 1],[5, 576, 96, True, 'hardswish', 1],]cls_ch_squeeze = 576else:raise NotImplementedError("mode[" + model_name +"_model] is not implemented!")supported_scale = [0.35, 0.5, 0.75, 1.0, 1.25]assert scale in supported_scale, \"supported scale are {} but input scale is {}".format(supported_scale, scale)inplanes = 16# conv1self.conv = ConvBNLayer(in_channels=in_channels,out_channels=make_divisible(inplanes * scale),kernel_size=3,stride=2,padding=1,groups=1,if_act=True,act='hardswish')self.stages = []self.out_channels = []block_list = []i = 0inplanes = make_divisible(inplanes * scale)for (k, exp, c, se, nl, s) in cfg:se = se and not self.disable_sestart_idx = 2 if model_name == 'large' else 0if s == 2 and i > start_idx:self.out_channels.append(inplanes)self.stages.append(nn.Sequential(*block_list))block_list = []block_list.append(ResidualUnit(in_channels=inplanes,mid_channels=make_divisible(scale * exp),out_channels=make_divisible(scale * c),kernel_size=k,stride=s,use_se=se,act=nl))inplanes = make_divisible(scale * c)i += 1block_list.append(ConvBNLayer(in_channels=inplanes,out_channels=make_divisible(scale * cls_ch_squeeze),kernel_size=1,stride=1,padding=0,groups=1,if_act=True,act='hardswish'))self.stages.append(nn.Sequential(*block_list))self.out_channels.append(make_divisible(scale * cls_ch_squeeze))for i, stage in enumerate(self.stages):self.add_sublayer(sublayer=stage, name="stage{}".format(i))def forward(self, x):x = self.conv(x)out_list = []for stage in self.stages:x = stage(x)out_list.append(x)return out_list
网络输出head
图中在1x1conv和最大池化后还有一个无bn层的1x1conv+HS激活以及一个1x1conv,但是本文的网络实现到最大池化为止,没有后面的。以此作为backbone。由于我们进行的是text recognition任务,后面还有其他模块。
MobileNetV3代码和网络相关推荐
- 不要怂,就是GAN (生成式对抗网络) (五):无约束条件的 GAN 代码与网络的 Graph...
GAN 这个领域发展太快,日新月异,各种 GAN 层出不穷,前几天看到一篇关于 Wasserstein GAN 的文章,讲的很好,在此把它分享出来一起学习:https://zhuanlan.zhihu ...
- 从java代码到网络编程
学习目录 前言 一.重温计网 二.从代码到网络 1.InetAddress类 2.Datagram类 2.1DatagramPacket类 2.2DatagramSocket类 三.UDP一发一收模型 ...
- Unity代码调用网络摄像头WebCamTexture
Unity代码调用网络摄像头WebCamTexture 初始代码 后续功能&代码更新 注意事项 初始代码 编辑器模式或Android环境和实机测试都没有问题. using System.Col ...
- 13行MATLAB代码实现网络爬虫 爬取NASA画廊星图
13行MATLAB代码实现网络爬虫 爬取NASA画廊星图 2021/04/18上传 2021/04/21更新:修改N的输入方式,增加对png格式图片的下载支持,增加了自动处理几种错误情况的代码,能够 ...
- C++ 简化 推箱子 小游戏 完整代码 参考网络资料 命令行运行 仅供初学者参考交流
C++ 简化 推箱子 小游戏 完整代码 参考网络资料 命令行运行 仅供初学者参考交流 说明:学做了4关推箱子, 仅供初学者参考可用g++ 编译,可以将内容复制到TXT文件,将后缀改为".cp ...
- 恶意代码分析实战 11 恶意代码的网络特征
11.1 Lab14-01 问题 恶意代码使用了哪些网络库?它们的优势是什么? 使用WireShark进行动态分析. 使用另外的机器进行分析对比可知,User-Agent不是硬编码. 请求的URL值得 ...
- 深入浅出 TCP/IP 协议栈丨手写代码实现网络协议栈
TCP/IP 协议栈是一系列网络协议的总和,是构成网络通信的核心骨架,它定义了电子设备如何连入因特网,以及数据如何在它们之间进行传输.TCP/IP 协议采用4层结构,分别是应用层.传输层.网络层和链路 ...
- 【神经网络】(16) MobileNetV3 代码复现,网络解析,附Tensorflow完整代码
各位同学好,今天和大家分享一下如何使用 Tensorflow 构建 MobileNetV3 轻量化网络模型. MobileNetV3 做了如下改动(1)更新了V2中的逆转残差结构:(2)使用NAS搜索 ...
- 从Openvswitch代码看网络包的旅程
我们知道,Openvwitch可以创建虚拟交换机,而网络包可以通过虚拟交换机进行转发,并通过流表进行处理,具体的过程如何呢? 一.内核模块Openvswitch.ko的加载 OVS是内核态和用户态配合 ...
最新文章
- 用0到9十个数字,每个数字使用一次,构成两个五位数a和b,并且a+20295=b.求a,b
- 013_Slider滑块
- Docker镜像原理学习理解
- queued frame 造成图形性能卡顿
- 条件随机场(Conditional random fields,CRFs)文献阅读指南
- WebRTC Linux ADM 实现中的符号延迟加载机制
- 事务默认的传播属性和事务默认的隔离级别
- Android中添加自己的模块 【转】
- 该如何弥补 GitHub 功能缺陷?
- 【ROS】机器人编程实践
- Intellij idea 添加浏览器
- android安卓字体下载,安卓Android简中综艺手机字体
- Caffe教程:训练自己的网络结构来分类。
- 阿里云CentOS7服务器搭建邮件服务器,端口:465
- 徐培成JAVA基础04
- HTML5框架 iframe用法 实现嵌套 好玩用法
- Foxmail邮箱提示错误:ssl连接错误,errorCode:5解决方法
- mysql sqlyog 乱码_SQLyog中文乱码的解决方法
- python3 获取整分钟数的时间,如间隔半小时
- MobPush 创建推送