1.什么是ResNet

ResNet由何凯明、孙剑博士等人于2015年提出,主要解决深层网络难以训练的问题。由于梯度消失和梯度爆炸的存在(初始归一化和中间层归一化一定程度解决了这个问题,使得具有十数层的网络收敛,以进行反向传播的梯度下降),当面对更深层次的网络时,退化问题就会暴露,随着网络深度的增加,精度达到饱和,然后迅速退化。这种退化不是由过度拟合引起的,而是由于更高的训练误差。

训练精度的下降表明并非所有系统都同样容易优化。让我们考虑一个较浅的架构和一个较深的对应架构,并在其上添加更多层。通过构造深层模型,存在一种解决方案:添加的层是identity mapping,其他层是从已学习的较浅模型复制的。该构造解的存在表明,较深的模型不应比较浅的模型产生更高的训练误差。 但是现有的求解器无法实现收敛

作者通过深度残差网络来解决退化问题,不希望每个堆叠层直接适合所需的底层映射,而是明确地让这些层适合剩余映射。将所需的底层映射表示为H(x),我们让堆叠的非线性层拟合另一个映射F(x):= H(x) - x,则原始的映射变为F(x) + x,可以通过有“shortcut connections”的前馈网络实现,如下图

正如之前所讨论的那样,如果添加的层可以构造为identity mapping,则较深的模型的训练误差应不大于较浅的模型。退化问题表明,求解器可能难以通过多个非线性层近似单位映射。利用残差学习重构,如果单位映射是最优的,则求解器可以简单地将多个非线性层的权重推向零,以接近单位映射。如果最优函数更接近单位映射而不是零映射,则求解器应更容易找到参考单位映射的扰动,而不是将函数学习为新函数。

对每几层展开残差学习,如图2,形式上可以构建一个块儿,定义如下:

y = F(x, {Wi}) + x

或者

y = F(x, {Wi}) + Ws·x

使用线性投影Ws匹配维度。网络结构如下图

更多内容请阅读原始文献

基于Pytorch的实现

1. resnet网络pytorch实现

import torch
import torch.nn as nnclass BasicBlock(nn.Module):expansion = 1def __init__(self, in_channels, out_channels, stride=1):super().__init__()#residual functionself.residual_function = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=(3,3), stride=stride, padding=1, bias=False),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True),nn.Conv2d(out_channels, out_channels*BasicBlock.expansion, kernel_size=(3,3), padding=1, bias=False),nn.BatchNorm2d(out_channels*BasicBlock.expansion))#shortcutself.shortcut = nn.Sequential()# the shortcut output dimension is not the same with residual function# use 1*1 convolution to match the dimensionif stride != 1 or in_channels != BasicBlock.expansion * out_channels:self.shortcut = nn.Sequential(nn.Conv2d(in_channels, out_channels*BasicBlock.expansion, kernel_size=(1,1), stride=stride, bias=False),nn.BatchNorm2d(out_channels*BasicBlock.expansion))def forward(self,x):return nn.ReLU(inplace=True)(self.residual_function(x)+self.shortcut(x)) #(inplace=True)修改原来的值,相当于地址传递class BottleNeck(nn.Module):expansion=4def __init__(self, in_channels, out_channels, stride=1):super().__init__()self.residual_function = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=(1,1), bias=False),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True),nn.Conv2d(out_channels, out_channels, stride=stride, kernel_size=(3,3), padding=1, bias=False),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True),nn.Conv2d(out_channels, out_channels*BottleNeck.expansion, kernel_size=(1,1), bias=False),nn.BatchNorm2d(out_channels*BottleNeck.expansion))self.shortcut = nn.Sequential()  #输出等于输入if stride != 1 or in_channels != BottleNeck.expansion * out_channels:self.shortcut = nn.Sequential(nn.Conv2d(in_channels, out_channels * BottleNeck.expansion, kernel_size=(1,1), stride=stride, bias=False),nn.BatchNorm2d(out_channels*BottleNeck.expansion))#print("shortcut", BottleNeck.expansion * out_channels)def forward(self,x):#print(self.residual_function(x).size())#print(self.shortcut(x).size())return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))class ResNet(nn.Module):def __init__(self, block, num_block, num_classes=10):super().__init__()self.in_channels = 16self.conv1 = nn.Sequential(nn.Conv2d(3, 16, kernel_size=(3,3), padding=1, bias=False),nn.BatchNorm2d(16),nn.ReLU(inplace=True))self.conv2_x = self._make_layer(block, 16, num_block[0], 1)self.conv3_x = self._make_layer(block, 32, num_block[1], 2)self.conv4_x = self._make_layer(block, 64, num_block[2], 2)#self.conv5_x = self._make_layer(block, 512, num_block[3], 2)self.avg_pool = nn.AdaptiveAvgPool2d((1,1))self.fc = nn.Linear(64*block.expansion, num_classes)def _make_layer(self, block, out_channels, num_blocks, stride):strides = [stride] + [1]*(num_blocks-1) #[stride, 1, 1, 1, ......]layers = []for stride in strides:layers.append(block(self.in_channels, out_channels, stride))self.in_channels = out_channels*block.expansionreturn nn.Sequential(*layers) #通过nn.Sequential函数将列表通过非关键字参数的形式传入def forward(self, x):output = self.conv1(x)output = self.conv2_x(output)output = self.conv3_x(output)output = self.conv4_x(output)#output = self.conv5_x(output)output = self.avg_pool(output)output = output.view(output.size(0),-1)output = self.fc(output)return outputdef ResNet_cifar10():return ResNet(BottleNeck, [2, 2, 2]) #6n+2层

2. 训练函数

def train(model, criterion,  optimizer, epochs):start = time.time()best_acc = 0.0for epoch in range(epochs):start_epoch = time.time()print("Epoch {}/{}".format(epoch, epochs-1))print("-" * 16)model.train()running_loss = 0.0running_correct = 0.0for index, data in enumerate(trainloader):inputs, labels = datainputs = inputs.to(device)labels = labels.to(device)optimizer.zero_grad()outputs = model(inputs)_, preds = torch.max(outputs, 1) #返回每一行中最大值的那个元素,且返回其索引, preds是索引loss = criterion(outputs, labels) #每批量的平均损失loss.backward()optimizer.step()running_loss += loss.item()*inputs.size(0)running_correct += torch.sum(preds == labels.data)epoch_loss = running_loss / len(trainset)epoch_acc = running_correct / len(trainset)writer.add_scalar('loss', epoch_loss, epoch)writer.add_scalar('acc', epoch_acc, epoch)if epoch_acc > best_acc:best_acc = epoch_accprint('train Loss: {:.4f} Acc: {:.4f}'.format(epoch_loss, epoch_acc))time_epoch = time.time() - start_epochprint('This epoch waste {:.0f}m {:.0f}s'.format(time_epoch // 60, time_epoch % 60))print('##' * 12)time_elapsed = time.time() - startprint('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))print('Best val Acc: {:4f}'.format(best_acc))return model

3. 主要运行代码

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
import torchvision
import torch.utils.data
import torchvision.transforms as transforms
import time
from tensorboardX import SummaryWriter
import MyResNetdef imshow(img):mean = torch.as_tensor([0.4914, 0.4822, 0.4465])std = torch.as_tensor([0.2023, 0.1994, 0.2010])mean = mean.view(-1, 1, 1)std = std.view(-1, 1, 1)#print(mean)img.mul_(std).add_(mean)npimg = img.numpy()plt.imshow((np.transpose(npimg, (1, 2, 0))*255).astype('uint8'))plt.show()if __name__ == '__main__':transform_train = transforms.Compose([transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])transform_test = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=False, transform=transform_train)trainloader = torch.utils.data.DataLoader(trainset, batch_size=64,shuffle=True, num_workers=1)testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=False, transform=transform_test)testloader = torch.utils.data.DataLoader(testset, batch_size=64,shuffle=False, num_workers=1)classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')# # get some random training images# dataiter = iter(trainloader)# images, labels = dataiter.next()# # show images# imshow(torchvision.utils.make_grid(images))print("训练集的长度:{}".format(len(trainset)))print("测试集的长度:{}".format(len(testset)))device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")print(device)writer = SummaryWriter('D:/MY_pyCharm_project/pythonProject/resnetForCifar-100/runs')net = MyResNet.ResNet_cifar10()net.to(device)criterion = nn.CrossEntropyLoss()optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)model = train(net, criterion, optimizer, epochs=3)input = torch.rand(32, 3, 32, 32)  # 示例输入input = input.to(device)writer.add_graph(model, (input,))writer.close()

4. 程序介绍

我们使用cifar10数据集(从官方数据库中下载,第一次运行可设置download=True),利用resnet对cifar10数据完成分类,并使用tensorboard进行训练过程和模型的可视化,使用交叉熵损失作为损失函数,使用SGD优化器。学习率设置为0.1,动量为0.9,权重衰减设置为0.0001,最后进行三轮训练。各种超参数可以自由调整(上述代码中都是随意设置的,本着运行速度和测试代码的初衷)

我们没有使用文献 Deep Residual Learning for Image Recognition中的cifar10分类实验的网络结构,这里进行了简单的修改,如果有兴趣,可以参考原文。

运行环境torch.__version__= '1.10.0'

参考文献

Deep Residual Learning for Image Recognition

https://github.com/weiaicunzai/pytorch-cifar100

写在最后

欢迎大家学习交流,共同进步QQ:2634110636

ResNet 阅读学习和实现相关推荐

  1. 最值得阅读学习的 10 个 C 语言开源项目代码

    本文转载于: 最值得阅读学习的 10 个 C 语言开源项目代码 从扩展思路的角度来说,一个程序员应该好好读过这样一些代码: 一个操作系统内核 一个编译器(如:gcc,lua) 一个解释器(如:pyth ...

  2. PyTorch实战使用Resnet迁移学习

    PyTorch实战使用Resnet迁移学习 项目结构 项目任务 项目代码 网络模型测试 项目结构 数据集存放在flower_data文件夹 cat_to_name.json是makejson文件运行生 ...

  3. ResNet网络学习笔记。

    ResNet网络学习 看b站 霹雳吧啦Wz 的视频总结的学习笔记! 视频的地址 大佬的Github代码 1.ResNet详解 ResNet 网络是在2015年由微软实验室提出,斩获当年 ImageNe ...

  4. Java源码阅读学习后的浅析和感悟(JDK篇)(持续更新)

    目录 Java源码阅读学习后的浅析和感悟(JKD篇) - 为什么阅读源码 集合框架类 - 为什么会要引入集合 - 集合结构图(部分) ArrayList集合源码分析 - 扩容机制 - 关键方法解释(D ...

  5. ResNet的学习笔记~

    1 前言 今天在学习ResNet~ 一直在学习和研究ResNet,不过有些东西一直没有弄懂,觉得还是需要通过实验来验证一下, 今天在学习CS231n时,Yang助教又讲到ResNet,这里我们再来复习 ...

  6. ResNet 小白学习笔记

    写在前面 直接看论文对我这个小白来说太不友好了,几次放弃 /(ㄒoㄒ)/~~ .幸好找到了一个通俗易懂的视频:6.1 ResNet网络结构,BN以及迁移学习详解,以下笔记大部分基于视频内容,再补充了一 ...

  7. 十个最值得阅读学习的C开源项目代码

    为什么80%的码农都做不了架构师?>>>    1. Webbench Webbench是一个在linux下使用的非常简单的网站压测工具.它使用fork()模拟多个客户端同时访问我们 ...

  8. 转载_最值得阅读学习的10个C语言开源项目代码

    "源代码面前,了无秘密",阅读优秀代码无疑是开发人员得以窥见软件堂奥而登堂入室的捷径.本文选取10个C语言优秀开源项目的代码作为范本,分别给予点评,免去东搜西罗之苦,点赞!那么问题 ...

  9. 卷积神经网络之ResNet网络模型学习

    Deep Residual Learning for Image Recognition 微软亚洲研究院的何凯明等人 论文地址 https://arxiv.org/pdf/1512.03385v1.p ...

最新文章

  1. 《C++游戏编程入门(第4版)》——2.4 使用带else子句的if语句序列
  2. Windows 技术篇 - 如何查看cpu支持的指令集、型号、属性等详细信息,使用cpu-z工具查看处理器、内存、显卡、主板、缓存、SPD信息方法
  3. 使用QueueUserAPC线程注入,
  4. Phpstorm界面不停的indexing,不停的闪烁
  5. 以系统化视角反观产品运营,解读提升用户转化的“四部曲”
  6. php将权限写入session,PHP由session文件夹权限不够引起的报错
  7. 【MySQL】源码安装MySQL
  8. [转]将c#中datagridview中的数据导出到excel中
  9. java目录删除_java删除文件及目录
  10. 如何实现SSID白名单管控
  11. 答题小程序/刷题微信小程序/考试小程序2.0版本(新增代理,团购,题目导入,数据导出等功能,THINKPHP后台)
  12. HRM人力资源管理平台技术总结
  13. A级学科计算机技术,东南大学a类学科排名!附东大a类学科名单
  14. 点击验证码时候自动刷新功能
  15. 解决vue页面四周有白边的问题
  16. Android 实现图文混排
  17. 核心概念——节点/边/Combo——内置Combo——内置Combo总览
  18. mezzanine-一个功能强大且易于扩展性的Django框架构建的内容管理平台
  19. 狄利克雷分布公式_关于狄利克雷分布的理解
  20. 利用随机森林算法实现Bank风险预测

热门文章

  1. glew、glfw、glad、freeglut的教程与区别
  2. 别人的18岁,恐怕会碾压你的38岁
  3. CodeForces - 766E  (树形dp+二进制)
  4. Linux上 journal 可以删除吗?
  5. TextRank方法的优化——MMR(最大边界相关算法)
  6. 关于Symantec(赛门铁克)认证服务
  7. Github上更新自己Fork的代码
  8. 静态资源交替成功失败500
  9. buddypress主题_BuddyPress入门指南:提示和资源
  10. 浙江新2014挂历制作,供应温州挂历印刷公司