文章目录

  • ResNet - 残差网络
    • 定义残差块(Residual)
    • ResNet模型
    • 训练模型
    • 小结

ResNet - 残差网络

关于ResNet残差网络,最本质且主要的公式如下:

f(x)=g(x)+xf(x) = g(x) + x f(x)=g(x)+x

可以认为 f(x)f(x)f(x) 是最终残差网络的输出, g(x)g(x)g(x) 是残差网络中两次卷积的输出, xxx 是样本数据集

一个残差块的主要结构如下图所示:

下面我们来先定义一个残差块Residual。

定义残差块(Residual)

import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l"""
定义残差网络
每个残差块的具体逻辑:
1、3*3卷积层操作
2、批量规范化
3、relu激活函数
4、3*3卷积层操作
5、批量规范化
称以上5步的操作为f(x)函数
若未指定1*1卷积操作,则输出返回 x + f(x)
否则返回 conv3(x) + f(x), 其中conv3()代表1*1卷积层
"""
class Residual(nn.Module):  #@savedef __init__(self, input_channels, num_channels,use_1x1conv=False, strides=1):super().__init__()self.conv1 = nn.Conv2d(input_channels, num_channels,kernel_size=3, padding=1, stride=strides)                 #3*3卷积操作self.conv2 = nn.Conv2d(num_channels, num_channels,kernel_size=3, padding=1)                                 #3*3卷积操作if use_1x1conv:self.conv3 = nn.Conv2d(input_channels, num_channels,                         #是否使用1*1卷积层操作kernel_size=1, stride=strides)else:self.conv3 = Noneself.bn1 = nn.BatchNorm2d(num_channels)self.bn2 = nn.BatchNorm2d(num_channels)                                          #两个批量规范化操作def forward(self, X):Y = F.relu(self.bn1(self.conv1(X)))Y = self.bn2(self.conv2(Y))                                                      #对样本数据进行两次卷积操作,得到g(x)if self.conv3: X = self.conv3(X)Y += X                                                                           #加上x,即 f(x) = g(x) + xreturn F.relu(Y)

该代码会生成两种类型的网络:

1.当use_1x1conv=False时,应用ReLU非线性函数之前,将输入添加到输出。

2.当use_1x1conv=True时,添加通过 1×11 \times 11×1 卷积调整通道和分辨率。

下面我们来查看输入和输出形状一致的情况。

#注意,当未使用1*1卷积层时,输入通道数和输出通道数要保持一致,否则会出现 X 与 Y 形状不一致相加出现错误的现象
blk = Residual(3, 3)
X = torch.rand(4, 3, 6, 6)                          #定义X数据集为4个样本数,3个通道,每个图片为 6*6
Y = blk(X)
Y.shape
torch.Size([4, 3, 6, 6])
blk = Residual(3, 6, use_1x1conv=True, strides=2)
blk(X).shape
torch.Size([4, 6, 3, 3])

ResNet模型

定义b1环节模型,包含一个 7×77 \times 77×7 的卷积层、批量规范化层、relu激活函数、最大汇聚层(池化层)。

b1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),nn.BatchNorm2d(64), nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

ResNet使用4个由残差块组成的模块,每个模块使用若干个同样输出通道数的残差块。 第一个模块的通道数同输入通道数一致。 由于之前已经使用了步幅为2的最大汇聚层,所以无须减小高和宽。 之后的每个模块在第一个残差块里将上一个模块的通道数翻倍,并将高和宽减半

下面我们来实现这个模块。注意,我们对第一个模块做了特别处理。


#定义残差块,输入参数分别为输入、输出通道数,残差网络数目
def resent_block(input_channels, num_channels, num_residuals, first_block=False):blk = []                     #定义残差网络列表for i in range(num_residuals):if i == 0 and not first_block:#如果是第一个残差网络,则将宽高减半blk.append(Residual(input_channels, num_channels, use_1x1conv=True, strides=2))else:#后续的残差网络blk.append(Residual(num_channels, num_channels))return blk

接着在ResNet加入所有残差块,这里残差块存在两个残差网络。

b2 = nn.Sequential(*resent_block(64, 64, 2, first_block=True))
b3 = nn.Sequential(*resent_block(64, 128, 2))
b4 = nn.Sequential(*resent_block(128, 256, 2))
b5 = nn.Sequential(*resent_block(256, 512, 2))

最后,在ResNet中加入全局平均汇聚层,以及全连接层输出

net = nn.Sequential(b1, b2, b3, b4, b5,nn.AdaptiveAvgPool2d((1, 1)),nn.Flatten(), nn.Linear(512, 10))

每个模块有4个卷积层(不包括恒等映射的 1×11 \times 11×1 卷积层)。 加上第一个 7×77 \times 77×7 卷积层和最后一个全连接层,共有18层。 因此,这种模型通常被称为ResNet-18。 通过配置不同的通道数和模块里的残差块数可以得到不同的ResNet模型,例如更深的含152层的ResNet-152。 虽然ResNet的主体架构跟GoogLeNet类似,但ResNet架构更简单,修改也更方便。这些因素都导致了ResNet迅速被广泛使用。 下图描述了完整的ResNet-18。

现在我们来测试下网络的结构

X = torch.rand(size=(1, 1, 224, 224))
for layer in net:X = layer(X)print(layer.__class__.__name__, 'output shape:\t', X.shape)
Sequential output shape:  torch.Size([1, 64, 56, 56])
Sequential output shape:     torch.Size([1, 64, 56, 56])
Sequential output shape:     torch.Size([1, 128, 28, 28])
Sequential output shape:     torch.Size([1, 256, 14, 14])
Sequential output shape:     torch.Size([1, 512, 7, 7])
AdaptiveAvgPool2d output shape:  torch.Size([1, 512, 1, 1])
Flatten output shape:    torch.Size([1, 512])
Linear output shape:     torch.Size([1, 10])

训练模型

同之前一样,我们在Fashion-MNIST数据集上训练ResNet。

lr, num_epochs, batch_size = 0.05, 10, 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=96)
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())

可见,训练集的精确度为 0.996,测试集的精确度为 0.919,性能较好。

小结

1.学习嵌套函数(nested function) 是训练神经网络的理想情况。

2.残差映射可以更容易地学习同一函数,例如将权重层中的参数近似为零。

3.利用残差块(residual blocks)可以训练出一个有效的深层神经网络:输入可以通过层间的残余连接更快地向前传播

4.残差网络(ResNet)对随后的深层神经网络设计产生了深远影响。

ResNet - 残差神经网络(CNN卷积神经网络)相关推荐

  1. DenseNet - 稠密神经网络(CNN卷积神经网络)

    文章目录 DenseNet - 稠密神经网络 稠密块体 稠密块中的卷积层 稠密块 过渡层 DenseNet模型 训练模型 小结 DenseNet - 稠密神经网络 ResNet极大地改变了如何参数化深 ...

  2. 深度学习之CNN卷积神经网络

    详解卷积神经网络(CNN) 卷积神经网络(Convolutional Neural Network, CNN)是一种前馈神经网络,它的人工神经元可以响应一部分覆盖范围内的周围单元,对于大型图像处理有出 ...

  3. 【卷积神经网络】卷积神经网络(Convolutional Neural Networks, CNN)基础

    卷积神经网络(Convolutional Neural Networks, CNN),是一种 针对图像 的特殊的 神经网络. 卷积神经网络概述 Why not DNN? 图像数据的维数很高,比如 1, ...

  4. 搭建CNN卷积神经网络(用pytorch搭建)

    手撕卷积神经网络-CNN 卷积:提取特征 池化:压缩特征 heigh X weigh X depth 长度 宽度.深度(也就是特征图个数) 例如输入32x32x3 hxwxc 卷积就是取某个小区域进行 ...

  5. CNN卷积神经网络之RegNet

    CNN卷积神经网络之RegNet 前言 设计思路 AnyNet设计空间 网络结构 实验结果 消融实验结论 前言 <Designing Network Design Spaces> 论文地址 ...

  6. tf2.0先试试图片(七)——CNN卷积神经网络

    之前已经介绍了TenforFlow的基本操作和神经网络,主要是全联接网络的一些概念: tf2.0先试试图片(七)--CNN卷积神经网络 7.0 简介 7.1 全连接网络的问题 7.1.1 局部相关性 ...

  7. CNN卷积神经网络之SENet及代码

    CNN卷积神经网络之SENet 个人成果,禁止以任何形式转载或抄袭! 一.前言 二.SE block细节 SE block的运用实例 模型的复杂度 三.消融实验 1.降维系数r 2.Squeeze操作 ...

  8. CNN卷积神经网络详解

    1.cnn卷积神经网络的概念 卷积神经网络(CNN),这是深度学习算法应用最成功的领域之一,卷积神经网络包括一维卷积神经网络,二维卷积神经网络以及三维卷积神经网络.一维卷积神经网络主要用于序列类的数据 ...

  9. BP神经网络与卷积神经网络(CNN)

    BP神经网络与卷积神经网络(CNN) 1.BP神经网络  1.1 神经网络基础  神经网络的基本组成单元是神经元.神经元的通用模型如图 1所示,其中常用的激活函数有阈值函数.sigmoid函数和双曲正 ...

最新文章

  1. ML之LiR:机器学习经典算法之线性回归算法LiR的简介、使用方法、经典案例之详细攻略
  2. 修改Bootstrap的一些默认样式
  3. JZOJ 3600. 【CQOI2014】通配符匹配
  4. struts.xml中class路径错误报错的问题
  5. XML基础——extensible markup language
  6. HIve内置函数(functions)使用和解析
  7. 6.网络层(4)---IP多播,NAT
  8. 此男因为什么被送进医院?
  9. 西南科技大学OJ题 约瑟夫问题的实现0956
  10. C++面向对象课程设计实例-图书馆借阅系统
  11. 贴片led极性_贴片发光二极管正负极判断方法详解
  12. python爬取通过百度图片搜出来的所有图片
  13. 慕尼黑工业大学计算机博士申请条件,慕尼黑大学博士条件
  14. 如何使用商品历史价格查询网站
  15. LWN:Linux audio plugin APIs综述!
  16. Pytorch实现人脸多属性识别
  17. 使用MATLAB2014a将灰度图转为彩色图
  18. 基于SSH的校园二手物品交易系统
  19. Python代码在Pycharm中不起作用,但在Jupiter Notebook中执行良好
  20. P5551 Chino的树学

热门文章

  1. 能力风暴机器人AS-MF2011小试身手
  2. 信息技术用计算机画画教学设计,小学四年级信息技术用计算机画画教学设计
  3. How to reply when sb say Thank you to you
  4. 大数据处理技术导论(6) | Datawhale组队学习46期
  5. 无纸化测评计算机基础知识,《计算机应用基础》自考计算机无纸化考试
  6. 计算机基础知识学员评价,大学计算机基础课程评价的模式的探讨.doc
  7. PMP证书的有效期是多久?
  8. 2021年中国禽蛋行业现状分析:禽蛋产量同比下降1.7%[图]
  9. 【杂文】话说红颜知已
  10. 电脑BIOS 设置怎样从光盘(USB优盘)启动