背景介绍:MNIST数据集识别黑白的手写数字图片,不适合彩色模型的RGB三通道图片。用深度残差网络学习多通道图片。

简单介绍一下深度残差网络:普通的深度网络随着网络深度的加深,拟合效果可能会越来越好,也可能会变差,换句话说在不停地学习,但是有可能学歪了。本次介绍的深度残差网络最后输出H(x)=x+f(x)。其中x是本层网络的输入,f(x)是本层网络的输出,H(x)是最终得到的结果。由以上公式可以表明,最终结果包含输入x,也就是说不论怎么学习,起码效果不会变差,不会学歪。x和f(x)之间的变换的网络层就被成为残差模块。

有不懂的地方可以看代码下面的解释与讲解

目录

1.残差模块类的构建:

(1)残差模块:

(2)在残差模块中前向传播:

2.残差网络类:

(1)通用残差网络实现:

(2)调用残差模块建立方法:

(3)在整个网络中向前传播:

4.形成特殊网络:


1.残差模块类的构建:

(1)残差模块:

 # 残差模块def __init__(self, filter_num, stride=1):super(BasicBlock, self).__init__()#输入x经过两个卷积层得到f(x),f(x)+x=H(x),对应元素相加得到残差模块H(x)# 第一个卷积单元 卷积核大小3*3是超参数,需要学习,自己制定self.conv1 = layers.Conv2D(filter_num, (3, 3), strides=stride, padding='same')self.bn1 = layers.BatchNormalization()self.relu = layers.Activation('relu')# 第二个卷积单元self.conv2 = layers.Conv2D(filter_num, (3, 3), strides=1, padding='same')self.bn2 = layers.BatchNormalization()#当x与f(x)形状不同的时候,无法进行相加,新建identity(x)卷积层,完成x的形状转换if stride != 1:# 步长不为1,需要通过1x1卷积完成shape匹配self.downsample = Sequential()self.downsample.add(layers.Conv2D(filter_num, (1, 1), strides=stride))else:# shape匹配,直接短接self.downsample = lambda x:x

再重复一遍残差模块是x和f(x)之间的网络变换,包括两个卷积单元。

第一个卷积单元:卷积核的数量由传入参数给定,使用3*3卷积核,步长设定为1,经过卷积变换后形状不变;经过BN层主要对参数进行标准化,对网络有益;最后经过激活层。第二个卷积单元同上,不过不需要激活函数了。

我们上面说过经过残差模块输出f(x)需要与x相加,因此需要保证二者形状相同。如果shape不相同:用1*1的卷积核对矩阵通道数进行变换。在此细说一下x与f(x)的形状:由于padding都是same因此矩阵形状是保持不变的,但是由于卷积层有多个卷积核,则导致最终的矩阵通道维数和卷积核数量一样。因此f(x)和x仅仅在通道维度上不同,则使用1*1卷积核变换。如果shape相同:直接拼接就行。

(2)在残差模块中前向传播:

    def call(self, inputs, training=None):#向前传播# [b, h, w, c],通过第一个卷积单元out = self.conv1(inputs)out = self.bn1(out)out = self.relu(out)# 通过第二个卷积单元out = self.conv2(out)out = self.bn2(out)# 通过identity模块,进行identity转换identity = self.downsample(inputs)# 2条路径输出直接相加;out-f(x),identity-x,实现f(x)+xoutput = layers.add([out, identity])output = tf.nn.relu(output) # 激活函数return output

原始输入x经过两个卷积单元一层一层输出。注意identity模块,需要进行通道数调整,因此输入不是上一个输出,而是原始输入x,要将x的shape修改为f(x)的shape,进行相加。另外注意第二个卷积单元的relu函数从残差模块中提取出来了,放在了H(x)后面,当然这个是不固定的,也可以放在第二个卷积单元内部。

2.残差网络类:

(1)通用残差网络实现:

def __init__(self, layer_dims, num_classes=10): # [2, 2, 2, 2]super(ResNet, self).__init__()# 根网络,预处理    在这个容器中经过卷积层,标准化层,激活函数,池化层减半self.stem = Sequential([layers.Conv2D(64, (3, 3), strides=(1, 1)),layers.BatchNormalization(),layers.Activation('relu'),layers.MaxPool2D(pool_size=(2, 2), strides=(1, 1), padding='same')])# 堆叠4个Block,每个block包含了多个BasicBlock,设置步长不一样self.layer1 = self.build_resblock(64,  layer_dims[0])self.layer2 = self.build_resblock(128, layer_dims[1], stride=2)self.layer3 = self.build_resblock(256, layer_dims[2], stride=2)self.layer4 = self.build_resblock(512, layer_dims[3], stride=2)# 通过Pooling层将高宽降低为1x1self.avgpool = layers.GlobalAveragePooling2D()# 最后连接一个全连接层分类self.fc = layers.Dense(num_classes)

先经过一个容器,不是残差网络的根网络:包括卷积层-BN层-激活层-池化层。

下面是四个残差层,每个残差层都利用build_resblock函数进行残差网络层的建立,并传入参数。

通过池化层完成高宽的转化,最后连接一个全连接层,转化为十个属性的输出,判断结果到底是什么。

(2)调用残差模块建立方法:

#通过该函数一次完成多个残差模块的建立def build_resblock(self, filter_num, blocks, stride=1):# 辅助函数,堆叠filter_num个BasicBlockres_blocks = Sequential()# 只有第一个BasicBlock的步长可能不为1,实现下采样res_blocks.add(BasicBlock(filter_num, stride))for _ in range(1, blocks):#其他BasicBlock步长都为1res_blocks.add(BasicBlock(filter_num, stride=1))return res_blocks

调用残差模块类,传入卷积核数量,步长。blocks表示要建立的残差模块的数量。这些参数都由调用该方法的代码传入。

该方法的调用完成表明残差神经网络构建完成。

(3)在整个网络中向前传播:

    def call(self, inputs, training=None):# 通过根网络x = self.stem(inputs)# 一次通过4个模块x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)# 通过池化层x = self.avgpool(x)# 通过全连接层x = self.fc(x)return x

这一部分注意区别于残差模块中的向前传播,先通过根网络,再逐步通过每一个残差模块,最后经过池化层和全连接层得到输出。

4.形成特殊网络:

def resnet18():# 通过调整模块内部BasicBlock的数量和配置实现不同的ResNetreturn ResNet([2, 2, 2, 2])def resnet34():# 通过调整模块内部BasicBlock的数量和配置实现不同的ResNetreturn ResNet([3, 4, 6, 3])

ResNet18直接调用产生网络的方法,传入参数[2,2,2,2],表明layer_dims是个一维矩阵,元素都是2。ResNet18是指17层卷积层,1层全连接层的网络。每个layer都调用两次建立残差模块的方法,每个残差模块有两个卷积单元也就是两个卷积层,如此一来四个layer,就有16个卷积层,再加上根网络的一个卷积层和最后的一个全连接层刚好是17+1。

机器学习-卷积神经网络之深度残差网络(三)相关推荐

  1. AI(1)认知 人工智能、机器学习、神经网络、深度学习。

    宽为限 紧用功 功夫到 滞塞通 开篇 AI领域是个水很深的新领域,对于非科学研究专业人士来说更是深不可测.选择自己喜欢的学科,兴趣是最好的老师,攻克下去总会有意想不到的收获.AI时代,我们要更加努力! ...

  2. 经典卷积神经网络(二):VGG-Nets、Network-In-Network和深度残差网络

    上一节我们介绍了LeNet-5和AlexNet网络,本节我们将介绍VGG-Nets.Network-In-Network和深度残差网络(residual network). VGG-Nets网络模型 ...

  3. 深度学习之卷积神经网络(12)深度残差网络

    深度学习之卷积神经网络(12)深度残差网络 ResNet原理 ResBlock实现 AlexNet.VGG.GoogleLeNet等网络模型的出现将神经网络的法阵带入了几十层的阶段,研究人员发现网络的 ...

  4. 04.卷积神经网络 W2.深度卷积网络:实例探究(作业:Keras教程+ResNets残差网络)

    文章目录 作业1:Keras教程 1. 快乐的房子 2. 用Keras建模 3. 用你的图片测试 4. 一些有用的Keras函数 作业2:残差网络 Residual Networks 1. 深层神经网 ...

  5. 深度学习 --- 卷积神经网络CNN(LeNet-5网络详解)

    卷积神经网络(Convolutional Neural Network,CNN)是一种前馈型的神经网络,其在大型图像处理方面有出色的表现,目前已经被大范围使用到图像分类.定位等领域中.相比于其他神经网 ...

  6. 深度残差网络_深度残差收缩网络:(三) 网络结构

    1. 回顾一下深度残差网络的结构 在下图中,(a)-(c)分别是三种残差模块,(d)是深度残差网络的整体示意图.BN指的是批标准化(Batch Normalization),ReLU指的是整流线性单元 ...

  7. 基于FPGA的一维卷积神经网络CNN的实现(三)训练网络搭建及参数导出(附代码)

    训练网络搭建 环境:Pytorch,Pycham,Matlab. 说明:该网络反向传播是通过软件方式生成,FPGA内部不进行反向传播计算. 该节通过Python获取训练数据集,并通过Pytorch框架 ...

  8. 深度残差网络的无人机多目标识别

    深度残差网络的无人机多目标识别 人工智能技术与咨询 来源:<图学学报>.作者翟进有等 摘要:传统目标识别算法中,经典的区域建议网络(RPN)在提取目标候选区域时计算量大,时间复杂度较高,因 ...

  9. 深度残差网络ResNet解析

    ResNet在2015年被提出,在ImageNet比赛classification任务上获得第一名,因为它"简单与实用"并存,之后很多方法都建立在ResNet50或者ResNet1 ...

最新文章

  1. 2021全球最具影响力 AI 学者榜单: 中国占比11.1%,位列第二
  2. R语言Wilcoxon Signed-rank统计分布函数(dsignrank, psignrank, qsignrank rsignrank )实战
  3. Hugo + github 搭建个人博客
  4. 【python数据挖掘课程】二十九.数据预处理之字符型转换数值型、标准化、归一化处理
  5. 依赖插件版本冲突问题
  6. 小学奥数 7651 自来水供给 python
  7. 为什么老板给 ta 升职加薪?
  8. java dns 解析域名解析_java网络学习 java dns 域名解析协议实现
  9. IT行业里有这么多聪明人,他们之间的区别在哪里?
  10. ubuntu下Xmodmap映射Esc和Ctrl_L
  11. 网络安全—社会工程学
  12. Shiro 入门教程
  13. 51单片机系统板/开发板原理图以及烧写方法
  14. 电子表整点报时怎么取消_双11红包雨入口在哪 双十一秒杀券怎么抢
  15. websocket实现聊天室(一)
  16. 硬盘加密数据怎么恢复?BitLocker加密文件可恢复吗?BitLocker加密数据怎么恢复?
  17. Android上的Dalvik虚拟机
  18. Day7 零基础python入门100天Udemy训练营-Hangman Game 继续学习import, if else, while loop, for loop
  19. CSDN博文周刊第一期 | 2018年总结:向死而生,为爱而活——忆编程青椒的戎马岁月
  20. Linux NIS服务

热门文章

  1. Java中接口继承接口
  2. PHPMailer实现QQ邮箱发送邮件
  3. 商用密码领域骨干企业格尔软件加入龙蜥社区,共建信息安全底座
  4. Grab Cut与Graph Cut
  5. Opencv中的GrabCut图像分割
  6. C程序设计的抽象思维 pdf
  7. Kaggle项目之PUBG Finish Placement Prediction(一)——探索性分析
  8. UOS安装 .exe 应用
  9. javaWeb-jQuery
  10. MATLAB代码:综合能源系统能源交易模拟与博弈