MobileNetV3实现

参考博客:
https://zhuanlan.zhihu.com/p/323346888
https://blog.csdn.net/Chunfengyanyulove/article/details/91358187

来自paddelocr中的MobileNetV3,实现网络的类如下。

ConvBNLayer

ConvBNLayer类实际上就是实现了一个卷积层+bn层,如名所示。当我们写其他模块需要用到各种卷积层如1x1卷积和depth-wise卷积时,就可以通过传入各种相关参数给它实现我们需要的层。self.conv就是卷积层,通过包括groups在内的参数可以控制卷积层是什么样的。在forward()中,输入x,走过self.conv、self.bn后再根据self.if_act,self.act决定激活函数是什么样子。

class ConvBNLayer(nn.Layer):def __init__(self,in_channels,out_channels,kernel_size,stride,padding,groups=1,if_act=True,act=None):super(ConvBNLayer, self).__init__()self.if_act = if_actself.act = actself.conv = nn.Conv2D(in_channels=in_channels,out_channels=out_channels,kernel_size=kernel_size,stride=stride,padding=padding,groups=groups,bias_attr=False)self.bn = nn.BatchNorm(num_channels=out_channels, act=None)def forward(self, x):x = self.conv(x)x = self.bn(x)if self.if_act:if self.act == "relu":x = F.relu(x)elif self.act == "hardswish":x = F.hardswish(x)else:print("The activation function({}) is selected incorrectly.".format(self.act))exit()return x

ResidualUnit

这里的ResidualUnit即是论文中的bneck模块,是核心模块,也是网络的基本模块。实现了通道分离卷积+SE通道注意力机制+残差结构。

ResidualUnit调用ConvBNLayer类,实现各种需要的层:self.expand_conv, self.bottleneck_conv, self.linear_conv,并且调用SEModule类实现了self.mid_se层

  • self.expand_conv:1x1卷积层,用来降低通道数
  • self.bottleneck_conv:可分离卷积中的depth-wise卷积
  • self.linear_conv:通道分离卷积中的point-wise卷积,1x1卷积
  • self.mid_se:通道注意力机制,SE模块

在forward()中,输入x首先通过1x1卷积降低通道数,再进行通道分离的卷积,此时,如果self.if_se=True,将特征进行通道注意力。接着,将特征通过1x1卷积,这是通道分离卷积的第二个步骤,可以改变通道数,并增加通道之间的信息交互。最后,如果self.if_shortcut=True,将此模块的输入特征inputs与目前的特征x相加形成残差结构。


图中1x1卷积看起来是升维了,而且图中是depth-wise卷积后立即进行通道注意力计算,再进行point-wise卷积。

class ResidualUnit(nn.Layer):def __init__(self,in_channels,mid_channels,out_channels,kernel_size,stride,use_se,act=None):super(ResidualUnit, self).__init__()self.if_shortcut = stride == 1 and in_channels == out_channelsself.if_se = use_se# kernel_size=1说明它是用来降维的1x1卷积self.expand_conv = ConvBNLayer(in_channels=in_channels,out_channels=mid_channels,kernel_size=1,stride=1,padding=0,if_act=True,act=act)# groups=mid_channels说明它是通道可分离卷积中的depth-wise卷积self.bottleneck_conv = ConvBNLayer(in_channels=mid_channels,out_channels=mid_channels,kernel_size=kernel_size,stride=stride,padding=int((kernel_size - 1) // 2),groups=mid_channels,if_act=True,act=act)if self.if_se:self.mid_se = SEModule(mid_channels)# 用1x 1卷积实现fc层self.linear_conv = ConvBNLayer(in_channels=mid_channels,out_channels=out_channels,kernel_size=1,stride=1,padding=0,if_act=False,act=None)def forward(self, inputs):# 先对输入特征进行降维,得到mid-channel的特征x = self.expand_conv(inputs)# 深度可分离卷积x = self.bottleneck_conv(x)# 如果if_se=true,加入激励模块if self.if_se:x = self.mid_se(x)x = self.linear_conv(x)if self.if_shortcut:x = paddle.add(inputs, x)return x

SE

图中是 首先进行global average pooling,即squeeze,再用1x1卷积降低通道数,进行relu,再恢复通道数,sigmoid得到factor,最后channel-wise乘以原特征。

class SEModule(nn.Layer):def __init__(self, in_channels, reduction=4):super(SEModule, self).__init__()self.avg_pool = nn.AdaptiveAvgPool2D(1)self.conv1 = nn.Conv2D(in_channels=in_channels,out_channels=in_channels // reduction,kernel_size=1,stride=1,padding=0)self.conv2 = nn.Conv2D(in_channels=in_channels // reduction,out_channels=in_channels,kernel_size=1,stride=1,padding=0)def forward(self, inputs):outputs = self.avg_pool(inputs)outputs = self.conv1(outputs)outputs = F.relu(outputs)outputs = self.conv2(outputs)outputs = F.hardsigmoid(outputs, slope=0.2, offset=0.5)return inputs * outputs

MobileNetV3

输入首先经过一个3x3,stride=2的卷积层conv1,再经过由基础模块bneck组成的block_list,再经过1x1卷积conv2,最后通过kernel_size=2, stride=2的最大池化进行2倍下采样。

large和small的整体结构一致,区别就是基本单元bneck的个数(large15个,small11个)以及内部参数上,主要是通道数目。

  • self.conv1: stride=2的3x3卷积层,激活函数为hardswish。它的out_channels由make_divisible()得到

  • block_list: 首先初始化一个为[]的blocklist,根据cfg中的超参,调用ResidualUnit实现基本模块,依次向blocklist添加基本模块

  • self.conv2: 调用ConvBNLayer,实现1x1卷积

  • self.pool:最大池化,kernel_size=2, stride=2,得到2倍下采样的特征

class MobileNetV3(nn.Layer):def __init__(self,in_channels=3,model_name='large',scale=0.5,disable_se=False,**kwargs):"""the MobilenetV3 backbone network for detection module.Args:params(dict): the super parameters for build network"""super(MobileNetV3, self).__init__()self.disable_se = disable_seif model_name == "large":cfg = [# k, exp, c,  se,     nl,  s,[3, 16, 16, False, 'relu', 1],[3, 64, 24, False, 'relu', 2],[3, 72, 24, False, 'relu', 1],[5, 72, 40, True, 'relu', 2],[5, 120, 40, True, 'relu', 1],[5, 120, 40, True, 'relu', 1],[3, 240, 80, False, 'hardswish', 2],[3, 200, 80, False, 'hardswish', 1],[3, 184, 80, False, 'hardswish', 1],[3, 184, 80, False, 'hardswish', 1],[3, 480, 112, True, 'hardswish', 1],[3, 672, 112, True, 'hardswish', 1],[5, 672, 160, True, 'hardswish', 2],[5, 960, 160, True, 'hardswish', 1],[5, 960, 160, True, 'hardswish', 1],]cls_ch_squeeze = 960elif model_name == "small":cfg = [# k, exp, c,  se,     nl,  s,[3, 16, 16, True, 'relu', 2],[3, 72, 24, False, 'relu', 2],[3, 88, 24, False, 'relu', 1],[5, 96, 40, True, 'hardswish', 2],[5, 240, 40, True, 'hardswish', 1],[5, 240, 40, True, 'hardswish', 1],[5, 120, 48, True, 'hardswish', 1],[5, 144, 48, True, 'hardswish', 1],[5, 288, 96, True, 'hardswish', 2],[5, 576, 96, True, 'hardswish', 1],[5, 576, 96, True, 'hardswish', 1],]cls_ch_squeeze = 576else:raise NotImplementedError("mode[" + model_name +"_model] is not implemented!")supported_scale = [0.35, 0.5, 0.75, 1.0, 1.25]assert scale in supported_scale, \"supported scale are {} but input scale is {}".format(supported_scale, scale)inplanes = 16# conv1self.conv = ConvBNLayer(in_channels=in_channels,out_channels=make_divisible(inplanes * scale),kernel_size=3,stride=2,padding=1,groups=1,if_act=True,act='hardswish')self.stages = []self.out_channels = []block_list = []i = 0inplanes = make_divisible(inplanes * scale)for (k, exp, c, se, nl, s) in cfg:se = se and not self.disable_sestart_idx = 2 if model_name == 'large' else 0if s == 2 and i > start_idx:self.out_channels.append(inplanes)self.stages.append(nn.Sequential(*block_list))block_list = []block_list.append(ResidualUnit(in_channels=inplanes,mid_channels=make_divisible(scale * exp),out_channels=make_divisible(scale * c),kernel_size=k,stride=s,use_se=se,act=nl))inplanes = make_divisible(scale * c)i += 1block_list.append(ConvBNLayer(in_channels=inplanes,out_channels=make_divisible(scale * cls_ch_squeeze),kernel_size=1,stride=1,padding=0,groups=1,if_act=True,act='hardswish'))self.stages.append(nn.Sequential(*block_list))self.out_channels.append(make_divisible(scale * cls_ch_squeeze))for i, stage in enumerate(self.stages):self.add_sublayer(sublayer=stage, name="stage{}".format(i))def forward(self, x):x = self.conv(x)out_list = []for stage in self.stages:x = stage(x)out_list.append(x)return out_list

网络输出head

图中在1x1conv和最大池化后还有一个无bn层的1x1conv+HS激活以及一个1x1conv,但是本文的网络实现到最大池化为止,没有后面的。以此作为backbone。由于我们进行的是text recognition任务,后面还有其他模块。

MobileNetV3代码和网络相关推荐

  1. 不要怂,就是GAN (生成式对抗网络) (五):无约束条件的 GAN 代码与网络的 Graph...

    GAN 这个领域发展太快,日新月异,各种 GAN 层出不穷,前几天看到一篇关于 Wasserstein GAN 的文章,讲的很好,在此把它分享出来一起学习:https://zhuanlan.zhihu ...

  2. 从java代码到网络编程

    学习目录 前言 一.重温计网 二.从代码到网络 1.InetAddress类 2.Datagram类 2.1DatagramPacket类 2.2DatagramSocket类 三.UDP一发一收模型 ...

  3. Unity代码调用网络摄像头WebCamTexture

    Unity代码调用网络摄像头WebCamTexture 初始代码 后续功能&代码更新 注意事项 初始代码 编辑器模式或Android环境和实机测试都没有问题. using System.Col ...

  4. 13行MATLAB代码实现网络爬虫 爬取NASA画廊星图

    13行MATLAB代码实现网络爬虫 爬取NASA画廊星图 2021/04/18​上传 2021/04/21更新:修改N的输入方式,增加对png格式图片的下载支持,增加了自动处理几种错误情况的代码,能够 ...

  5. C++ 简化 推箱子 小游戏 完整代码 参考网络资料 命令行运行 仅供初学者参考交流

    C++ 简化 推箱子 小游戏 完整代码 参考网络资料 命令行运行 仅供初学者参考交流 说明:学做了4关推箱子, 仅供初学者参考可用g++ 编译,可以将内容复制到TXT文件,将后缀改为".cp ...

  6. 恶意代码分析实战 11 恶意代码的网络特征

    11.1 Lab14-01 问题 恶意代码使用了哪些网络库?它们的优势是什么? 使用WireShark进行动态分析. 使用另外的机器进行分析对比可知,User-Agent不是硬编码. 请求的URL值得 ...

  7. 深入浅出 TCP/IP 协议栈丨手写代码实现网络协议栈

    TCP/IP 协议栈是一系列网络协议的总和,是构成网络通信的核心骨架,它定义了电子设备如何连入因特网,以及数据如何在它们之间进行传输.TCP/IP 协议采用4层结构,分别是应用层.传输层.网络层和链路 ...

  8. 【神经网络】(16) MobileNetV3 代码复现,网络解析,附Tensorflow完整代码

    各位同学好,今天和大家分享一下如何使用 Tensorflow 构建 MobileNetV3 轻量化网络模型. MobileNetV3 做了如下改动(1)更新了V2中的逆转残差结构:(2)使用NAS搜索 ...

  9. 从Openvswitch代码看网络包的旅程

    我们知道,Openvwitch可以创建虚拟交换机,而网络包可以通过虚拟交换机进行转发,并通过流表进行处理,具体的过程如何呢? 一.内核模块Openvswitch.ko的加载 OVS是内核态和用户态配合 ...

最新文章

  1. 用0到9十个数字,每个数字使用一次,构成两个五位数a和b,并且a+20295=b.求a,b
  2. 013_Slider滑块
  3. Docker镜像原理学习理解
  4. queued frame 造成图形性能卡顿
  5. 条件随机场(Conditional random fields,CRFs)文献阅读指南
  6. WebRTC Linux ADM 实现中的符号延迟加载机制
  7. 事务默认的传播属性和事务默认的隔离级别
  8. Android中添加自己的模块 【转】
  9. 该如何弥补 GitHub 功能缺陷?
  10. 【ROS】机器人编程实践
  11. Intellij idea 添加浏览器
  12. android安卓字体下载,安卓Android简中综艺手机字体
  13. Caffe教程:训练自己的网络结构来分类。
  14. 阿里云CentOS7服务器搭建邮件服务器,端口:465
  15. 徐培成JAVA基础04
  16. HTML5框架 iframe用法 实现嵌套 好玩用法
  17. Foxmail邮箱提示错误:ssl连接错误,errorCode:5解决方法
  18. mysql sqlyog 乱码_SQLyog中文乱码的解决方法
  19. python3 获取整分钟数的时间,如间隔半小时
  20. MobPush 创建推送

热门文章

  1. 使用.net开发手机管理软件 (五) OBEX介绍
  2. 使用QT定时器遇到的问题
  3. 一个魔兽宅男的现实生活
  4. EasyClick IOS 自动化测试问题处理
  5. 电脑开机花屏的几种情况和解决方法 --旧时光 美剧 oldtimeblog
  6. 越狱手机如何让APP绕过越狱检测
  7. @Nullable注释用法
  8. WPF 使用RelativeSource绑定
  9. 清微智能恭祝您元宵佳节阖家欢乐!
  10. 【数据科学家】每个数据科学家都应该学习4个必备技能