ResNet模型代码解析

  • 1 ResNet 图解分析(论文)
    • 1.1 论文中的模型图、解释
      • 1.1.1 残差结构块
      • 1.1.2 残差结构模型——34层
      • 1.1.3 残差结构模型——多种类型
  • 2 ResNet-34 代码分析
    • 2.1 模型代码分析
      • 2.1.1 (BasicBlock)ResNet-34基本块
      • 2.2.2 (Bottleneck)ResNet-更多层基本块
      • 2.2.3 (ResNet)网络总模块
      • 2.2.4 网络多种架构模块
    • 2.2 训练代码分析
    • 2.3 预测代码分析

代码将在图像分类的数据集进行分析。

1 ResNet 图解分析(论文)

1.1 论文中的模型图、解释

1.1.1 残差结构块

————————图示

—————————解释

从上述图中可以看出:x为残差块的输入,然后复制成两部分,一部分输入到层(weight layer)之中,进行层间的运算(相当于将x输入到一个函数中做映射),结果为f(x);另一部分作为分支结构,输出还是原本的x,最后将分别两部分的输出进行叠加:f(x) + x,再通过激活函数。这便是整个残差块的基本结构。

1.1.2 残差结构模型——34层

————————图示

————————解释

对于批标准化(batch-normalization)以及激活函数(activation function)在这里先不做分析,在之后的代码中会有分析。

开头部分首先进行了一次卷积核为7×7,步长为2×2的卷积操作,然后进行了一次最大池化操作。

中间部分,resnet34分成了四块部分,每部分分别为3个残差块、4个残差块、6个残差块、3个残差块,逐一分析:
·
第一部分都是(卷积核大小3×3,卷积核个数64)卷积操作。
第二部分都是(卷积核大小3×3,卷积核个数128)卷积操作。
第三部分都是(卷积核大小3×3,卷积核个数256)卷积操作。
第四部分都是(卷积核大小3×3,卷积核个数512)卷积操作。
·
对于每一部分的第一个残差块的第一次卷积操作,它的步长为2,其余的都是1.
·
对于每一部分的第一个残差块的输入 (即每一条虚线部分) 来说,由于上一部分的通道数与本部分的通道数不一致,所以在其中隐含了利用1×1的卷积操作,形成downsample,增加通道数(详情看代码分析部分),使到本部分的输入的通道数与本部分通道数保持一致,这样才可以进行相同通道上的像素叠加。

最后部分,进行了平均池化,然后经过拉直层,最后进行一个全连接,输出分类概率。

1.1.3 残差结构模型——多种类型

————————图示

在18,34层的ResNet 网络中他们在每一大部分内部的通道数都不会发生变化,然后在50和101,152层中的通道数都会发生变化,在代码中的expansion这个参数设定为1还是4,就是因为在大部分中通道数是否发生变化。


2 ResNet-34 代码分析

2.1 模型代码分析

pytoch官方也给出了ResNet的代码,查询方法:

import torchvision.models.resnet
### 鼠标悬浮resnet字符上,然后按住ctrl + 鼠标左键 即可跳转至官方ResNet的代码

2.1.1 (BasicBlock)ResNet-34基本块

class BasicBlock(nn.Module):  ## ()内为继承nn的模型expansion = 1    ### 这个参数在resnet34层中并没有什么用处  这个参数是为了控制在一大部分中的通道数变化的def __init__(self, in_channel, out_channel, stride=1, downsample=None, **kwargs):"""@param in_channel:  此块输入的通道数@param out_channel: 输出的通道数@param stride: 在第一个卷积层的步长@param downsample: 是否进行下采样@param kwargs: 其他参数(可变长参数)"""### 父类初始化super(BasicBlock, self).__init__()### 自定义操作赋值给变量self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3, stride=stride, padding=1, bias=False) #定义卷积层self.bn1 = nn.BatchNorm2d(out_channel) #定义归一化self.relu = nn.ReLU()# 定义激活函数self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, stride=1, padding=1, bias=False) # 定义卷积层self.bn2 = nn.BatchNorm2d(out_channel)# 定义归一化self.conv3 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, stride=1, padding=1, bias=False) # 定义卷积层self.bn3 = nn.BatchNorm2d(out_channel)# 定义归一化self.downsample = downsample   ##定义下采样部分def forward(self, x):### 在这里就是构造残差块的基本结构identity = x ## 先将最开始的输入 进行赋值到identity  这一部分是为了进行恒等映射if self.downsample is not None:  # 如果downsample 不是空值(下采样)的话  #就在后方进行下采样层相应的操作,因为在上述分析模块部分已经说到,#在虚线部分,会因为通道数不一致,要进行下采样操作,使得通道数一致。identity = self.downsample(x)### 两部分 卷积批标准化 卷积批标准化out = self.conv1(x)out = self.bn1(out)out = self.relu(out)  ## 过一层激活层out = self.conv2(out)out = self.bn2(out)out = out + identity   ## 将此基本结构块的输入与本结构块的最后一层的输出进行叠加,#形成最终的输出out = self.relu(out) # 过激活函数return out  ## 返回此结构块的输出 也就是下一个残差块的基本结构的输入了

2.2.2 (Bottleneck)ResNet-更多层基本块

如果只看resnet-34可以不看这部分。

class Bottleneck(nn.Module):"""注意:原论文中,在虚线残差结构的主分支上,第一个1x1卷积层的步距是2,第二个3x3卷积层步距是1。但在pytorch官方实现过程中是第一个1x1卷积层的步距是1,第二个3x3卷积层步距是2,这么做的好处是能够在top1上提升大概0.5%的准确率。"""expansion = 4  ## 这个就是通道数变化的系数 def __init__(self, in_channel, out_channel, stride=1, downsample=None, groups=1, width_per_group=64):super(Bottleneck, self).__init__()## resnext50_32x4d 和resnext101_32x8d 会使用width = int(out_channel * (width_per_group / 64.)) * groups### 定义三个卷积过程self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=width, kernel_size=1, stride=1, bias=False)  # squeeze channelsself.bn1 = nn.BatchNorm2d(width)# -----------------------------------------self.conv2 = nn.Conv2d(in_channels=width, out_channels=width, groups=groups, kernel_size=3, stride=stride, bias=False, padding=1)self.bn2 = nn.BatchNorm2d(width)# -----------------------------------------self.conv3 = nn.Conv2d(in_channels=width, out_channels=out_channel*self.expansion, kernel_size=1, stride=1, bias=False)  # unsqueeze channelsself.bn3 = nn.BatchNorm2d(out_channel*self.expansion)self.relu = nn.ReLU(inplace=True)##下采样self.downsample = downsampledef forward(self, x):identity = xif self.downsample is not None:identity = self.downsample(x)## 基本块为三次卷积层的过程 并非跟resnet34一致out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out = self.relu(out)out = self.conv3(out)out = self.bn3(out)out += identityout = self.relu(out)return out ## 返回输出

2.2.3 (ResNet)网络总模块

class ResNet(nn.Module):  ##继承自nn.Module函数def __init__(self, block, blocks_num, num_classes=1000, include_top=True, groups=1, width_per_group=64):"""@param block:  传入实例化BasicBlock,就是上一部分代码的基本块@param blocks_num:  块的个数此为一个列表 列表长度为 resNet几大块  对应列表中的每一个数——> 即为每一部分的基本块的块数,例如resnet34 中的blocks_num = [3, 4, 6, 3]   可以看上述图中的resnet34的结构 本人画出的四大块@param num_classes: 几分类@param include_top: 判定条件是否采用适应性平均池化@param groups:@param width_per_group:"""super(ResNet, self).__init__() ## 进行赋值self.include_top = include_topself.in_channel = 64   ### 输入的通道数self.groups = groupsself.width_per_group = width_per_group### 最开始的一大层     先进行卷积核为7*7 卷积操作self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2, padding=3, bias=False)## 然后进行归一化self.bn1 = nn.BatchNorm2d(self.in_channel)## 然后过激活函数  增加非线性表达self.relu = nn.ReLU(inplace=True)## 然后经过最大池化层self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)#### 进行ResNet的四大部分[3, 4, 6, 3]# 构造每一部分函数:_make_layer :为本类的成员函数  在下方self.layer1 = self._make_layer(block, 64, blocks_num[0])self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)### 一个判定条件  为True则是会有自适应平均池化if self.include_top:self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # output size = (1, 1)   ### 自适应平均池化  自适应去需要进行均值的数值self.fc = nn.Linear(512 * block.expansion, num_classes)   ## 全连接层### 遍历每一个模块进行模块的权重的初始化for m in self.modules():## 遍历所有的层if isinstance(m, nn.Conv2d):   ## m 是否为卷积的实例化nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')### 以何凯明大佬命名的初始化的函数### 此函数为构造每一大部分的函数def _make_layer(self, block, channel, block_num, stride=1):"""@param block: 基本块   resnet34层传入的是BasicBlock,更多层数的会传入瓶颈块 Bottleneck@param channel: 每一大部分的通道数@param block_num: 此部分的基本块的数量@param stride:  步长@return:"""downsample = None# 步长不为1   因为需要downsample进行对于不同残差块之间的统一   上一个残差块的输出与本残差块的输出宽高保持一致### 如果步长不为1 也就是每个大部分的第一个基本块的第一个卷积层,### 或者 输入通道与通道不匹配,也就是resnet 更多层会出现 一大部分中通道数发生变化### 会出现通道数不匹配 所以要进行下采样,统一通道数的if stride != 1 or self.in_channel != channel * block.expansion:###block.expansion = 1downsample = nn.Sequential(nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),  ## 进行卷积保持宽高一致nn.BatchNorm2d(channel * block.expansion)  ## 批归一化)#定义一个层列表layers = []## 添加基本块  先添加一个基本块  之后的用循环,因为第一个基本块,可能会出现下采样的情况layers.append(block(self.in_channel, channel, downsample=downsample, stride=stride, groups=self.groups, width_per_group=self.width_per_group))# 注意注意    只有在不同大部分的情况下才进行此操作  因为上一大部分跟下一大部分的channel通道不一样self.in_channel = channel * block.expansion    ## 下一层的输入等于本层的输出### 进行循环 通过小的基本块 构造一个大部分for _ in range(1, block_num):layers.append(block(self.in_channel, channel, groups=self.groups, width_per_group=self.width_per_group))## 返回本大部分return nn.Sequential(*layers)def forward(self, x):#开始一部分x = self.conv1(x)x = self.bn1(x)x = self.relu(x)x = self.maxpool(x)## 中间四大部分x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)if self.include_top:## 判定x = self.avgpool(x)   ### 自适应平均池化x = torch.flatten(x, 1)  ## 拉直层x = self.fc(x) # 全连接层return x  ##返回输出

2.2.4 网络多种架构模块

## 对于不同层的resnet网络 ,所传入的参数设定
def resnet34(num_classes=1000, include_top=True):return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)def resnet50(num_classes=1000, include_top=True):return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)def resnet101(num_classes=1000, include_top=True):return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, include_top=include_top)def resnext50_32x4d(num_classes=1000, include_top=True):groups = 32width_per_group = 4return ResNet(Bottleneck, [3, 4, 6, 3],num_classes=num_classes,include_top=include_top,groups=groups,width_per_group=width_per_group)def resnext101_32x8d(num_classes=1000, include_top=True):groups = 32width_per_group = 8return ResNet(Bottleneck, [3, 4, 23, 3],num_classes=num_classes,include_top=include_top,groups=groups,width_per_group=width_per_group)

2.2 训练代码分析

import os
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from tqdm import tqdm
from model import resnet34
import random
import numpy as np
import pandas as pdsave_path = "./resnet34.pth"   ## 模型保存路径
epochs = 30    #训练轮数### 本部分是为了进行 实验复现
seed = 3
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = Falsedef main():### 查看可用gpu 设备device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")print("using {} device.".format(device))# 定义数据转换格式   Image->tensordata_transform = {"train": transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),"val": transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}# 寻找数据集文件夹  将路径进行拼接data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root pathimage_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set pathassert os.path.exists(image_path), "{} path does not exist.".format(image_path)## 读取文件夹中的图片文件train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),transform=data_transform["train"])train_num = len(train_dataset)    #3306 张图片# {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}flower_list = train_dataset.class_to_idx   ##定义列表 类与所对应的数字分类的字典# 将key value 进行调换顺序cla_dict = dict((val, key) for key, val in flower_list.items())# write dict into json file   将字典写入json文件json_str = json.dumps(cla_dict, indent=4)with open('class_indices.json', 'w') as json_file:json_file.write(json_str)batch_size = 128#### 确定数据加载器的进程个数nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workersprint('Using {} dataloader workers every process'.format(nw))## 训练的加载器train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size,shuffle=True,num_workers=nw)## 读取测试集文件夹validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),transform=data_transform["val"])val_num = len(validate_dataset)## 测试集数据加载器validate_loader = torch.utils.data.DataLoader(validate_dataset,batch_size=batch_size, shuffle=False,num_workers=nw)## 输出用于训练以及测试的数据集数据的个数print("using {} images for training, {} images for validation.".format(train_num, val_num))# 实例化网络模型net = resnet34()# load pretrain weights# change fc layer structure##### 重写网络中的全连接层部分  将输出改写为5## 其实在最开始num_classes 赋值为5就可以in_channel = net.fc.in_featuresnet.fc = nn.Linear(in_channel, 5)### 如果有模型的话就进行加载模型if os.path.exists(save_path):print("---Loading_Model---")net.load_state_dict(torch.load(save_path, map_location=device))net.to(device)# define loss functionloss_function = nn.CrossEntropyLoss()# construct an optimizerparams = [p for p in net.parameters() if p.requires_grad]optimizer = optim.Adam(params, lr=0.0001)best_acc = 0.0   ### 记录最好的准确度train_steps = len(train_loader)   ###  将数据分为多少批次    3306/batch_sizeprint("train_steps", train_steps)rows = []ones = []for epoch in range(epochs):# trainnet.train()running_loss = 0.0   ## 初始化在一个epoch中的损失值train_bar = tqdm(train_loader)   ## 进度条库for step, data in enumerate(train_bar):   ## 用数据加载器读取数据images, labels = data   ## 特征目标值optimizer.zero_grad()   ## 梯度归零logits = net(images.to(device))    ##调用网络  将数据喂入到网络中loss = loss_function(logits, labels.to(device))    ## 求损失loss.backward()  ### 反向传播optimizer.step()  ## 损失优化# print statisticsrunning_loss += loss.item()   ## 累加损失train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1, epochs, loss)# validatenet.eval()### 测试部分acc = 0.0  # accumulate accurate number / epoch    ## 定义准确度 每一轮训练要归零with torch.no_grad():   ## 不产生梯度val_bar = tqdm(validate_loader)for val_data in val_bar:val_images, val_labels = val_dataoutputs = net(val_images.to(device))# loss = loss_function(outputs, test_labels)predict_y = torch.max(outputs, dim=1)[1]acc += torch.eq(predict_y, val_labels.to(device)).sum().item()val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1,epochs)val_accurate = acc / val_numones.append(running_loss / train_steps)ones.append(val_accurate)rows.append(ones)ones=[]print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %(epoch + 1, running_loss / train_steps, val_accurate))# 这一轮训练的比之前最好的准确度还高的话就保存if val_accurate > best_acc:best_acc = val_accuratetorch.save(net.state_dict(), save_path)print('Finished Training')print(rows)if __name__ == '__main__':main()

2.3 预测代码分析

## 预测代码类似训练代码 调用模型就可以了,然后传入自己的图片进行分类
import os
import jsonimport torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
from model import resnet34
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'def main():device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")data_transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])# load imageimg_path = "../tulip.jpg"assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)img = Image.open(img_path)plt.imshow(img)# [N, C, H, W]img = data_transform(img)# expand batch dimensionimg = torch.unsqueeze(img, dim=0)# read class_indictjson_path = './class_indices.json'assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)json_file = open(json_path, "r")class_indict = json.load(json_file)# create modelmodel = resnet34(num_classes=5).to(device)# load model weightsweights_path = "./resNet34.pth"assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path)model.load_state_dict(torch.load(weights_path, map_location=device))# predictionmodel.eval()with torch.no_grad():# predict classoutput = torch.squeeze(model(img.to(device))).cpu()predict = torch.softmax(output, dim=0)predict_cla = torch.argmax(predict).numpy()print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)], predict[predict_cla].numpy())plt.title(print_res)print(print_res)plt.show()if __name__ == '__main__':main()

经典网络模型ResNet(残差网络)相关推荐

  1. ResNet残差网络及变体详解(符代码实现)

    本文通过分析深度网络模型的缺点引出ResNet残差网络,并介绍了几种变体,最后用代码实现ResNet18. 文章目录 前言 模型退化 残差结构 ResNet网络结构 Pre Activation Re ...

  2. (pytorch-深度学习系列)ResNet残差网络的理解-学习笔记

    ResNet残差网络的理解 ResNet伴随文章 Deep Residual Learning for Image Recognition 诞生,该文章是MSRA何凯明团队在2015年ImageNet ...

  3. 【五一创作】使用Resnet残差网络对图像进行分类(猫十二分类,模型定义、训练、保存、预测)(一)

    使用Resnet残差网络对图像进行分类 (猫十二分类,模型定义.训练.保存.预测)(一) 目录 一.项目简介 二.环境说明 1.安装库 2.导入需要的库 三.分类过程 (1).解压数据集 (2).相关 ...

  4. ResNet 残差网络、残差块

    在深度学习中,为了增强模型的学习能力,网络的层数会不断的加深,于此同时,也伴随着一些比较棘手的问题,主要包括: ①模型复杂度上升,网络训练困难 ②出现梯度消失/梯度爆炸问题 ③网络退化,即增加层数并不 ...

  5. ResNet残差网络Pytorch实现——对花的种类进行训练

    ResNet残差网络Pytorch实现--对花的种类进行训练 上一篇:[结合各个残差块] ✌✌✌✌ [目录] ✌✌✌✌ 下一篇:[对花的种类进行单数据预测] 大学生一枚,最近在学习神经网络,写这篇文章 ...

  6. 目标检测学习笔记2——ResNet残差网络学习、ResNet论文解读

    ResNet残差网络学习.ResNet论文解读 一.前言 为什么会提出ResNet? 什么是网络退化现象? 那网络退化现象是什么造成的呢? ResNet要如何解决退化问题? 二.残差模块 三.残差模块 ...

  7. ResNet残差网络

    (二十七)通俗易懂理解--Resnet残差网络 - 梦里寻梦的文章 - 知乎 https://zhuanlan.zhihu.com/p/67860570

  8. Halcon 深度学习自定义网络模型-ResNet通用网络产生器

    Halcon 深度学习自定义网络模型-ResNet通用网络产生器 备注: 版本要求:halcon21.05++ Python下的ResNet网络模型源码: import torch import to ...

  9. 【五一创作】使用Resnet残差网络对图像进行分类(猫十二分类,模型定义、训练、保存、预测)(二)

    使用Resnet残差网络对图像进行分类 (猫十二分类,模型定义.训练.保存.预测)(二) 目录 (6).数据集划分 (7).训练集增强 (8).装载数据集 (9).初始化模型 (10).模型训练 (1 ...

最新文章

  1. 论文阅读:FFDNet:Toward a Fast and Flexible Solution for CNN based Image Denoising
  2. angular2 安装
  3. python 各层级目录下的import方法
  4. uoj#38. 【清华集训2014】奇数国(线段树+数论)
  5. spring cloud分布式整合zipkin的链路跟踪
  6. 奇异值分解与低秩矩阵近似
  7. 白鹭引擎生成html,初识Egret白鹭引擎 之 创建舞台
  8. php修改学生信息代码_PHP连接MySQL数据库添加图书功能
  9. editplus怎么在前后插入字符
  10. python怎么计算指数_如何在Python中使用SciPy计算值和指数值的立方根?
  11. 初学C++遇到的引用头文件问题
  12. winform窗体——布局方式
  13. 固态硬盘 游戏测试软件,TxBENCH(SSD固态硬盘检测工具)
  14. autocad2014点击保存闪退_autocad2014启动闪退 AutoCAD启动时闪退怎么办
  15. 【Android】Gallery实现选中图片变大,两侧没选中图片变小
  16. mac parallels desk 网络初始化失败
  17. JAVA计算机毕业设计校园闲置物品信息管理系统Mybatis+源码+数据库+lw文档+系统+调试部署
  18. python英语词汇读音_40行Python代码区分英语单词和汉语拼音
  19. 电脑开机后黑屏的解决办法
  20. Neutron OVS-DVR

热门文章

  1. 2013年最损的话,搞笑得经典,快进…
  2. java毕业设计创达内部管理系统Mybatis+系统+数据库+调试部署
  3. 创建型模式:原型模式
  4. ESP8266-Arduino编程实例-HC-SR04超声波传感器驱动
  5. iphone8投屏电脑 苹果投屏电视的方法
  6. 计算机毕业设计(3)python毕设作品之小说电子书阅读系统
  7. js资源按需加载和预加载
  8. 制作Centos7自动安装镜像(四)
  9. 2.5-2.7 1×1 卷积 Inception 吴恩达 第四门课 卷积神经网络 第二周 深度卷积网络
  10. js中把字符串分割为数组,把数组转为字符串