【pytorch torchvision源码解读系列—1】Alexnet
最近开始学习一个新的深度学习框架PyTorch。
框架中有一个非常重要且好用的包:torchvision,顾名思义这个包主要是关于计算机视觉cv的。这个包主要由3个子包组成,分别是:torchvision.datasets、torchvision.models、torchvision.transforms。
具体介绍可以参考官网:https://pytorch.org/docs/master/torchvision
具体代码可以参考github:https://github.com/pytorch/vision
torchvision.models这个包中包含alexnet、densenet、inception、resnet、squeezenet、vgg等常用经典的网络结构,并且提供了预训练模型,可以通过简单调用来读取网络结构和预训练模型。
今天我们来解读一下Alexnet的源码实现。如果对AlexNet不是很了解 可以查看这里的论文笔记https://blog.csdn.net/sinat_33487968/article/details/83543406
如何使用呢?
import torchvision
model = torchvision.models.Alexnet(pretrained=True)
这样就可以获得网络的结构了,pretrained参数的意思是是否预训练,如果为True就会从网上下载好已经训练参数的模型。改参数默认是False。
import torch.utils.model_zoo as model_zoo__all__ = ['AlexNet', 'alexnet']model_urls = {'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth',
}
首先是导入必要的库,其中model_zoo是和导入预训练模型相关的包,另外all变量定义了可以从外部import的函数名或类名。这也是前面为什么可以用torchvision.models.alexnet()来调用的原因。model_urls这个字典是预训练模型的下载地址。
接下来就是Alexnet这个类
class AlexNet(nn.Module):def __init__(self, num_classes=1000):super(AlexNet, self).__init__()self.feature = nn.Sequential(nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),nn.ReLU(inplace=True), # inplace为True,将会改变输入的数据 ,否则不会改变原输入,只会产生新的输出nn.MaxPool2d(kernel_size=3, stride=2),nn.Conv2d(64, 192, kernel_size=5, padding=2),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),nn.Conv2d(192, 384, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(384, 256, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(256, 256, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),)self.classifer = nn.Sequential(nn.Dropout(),nn.Linear(256 * 6 * 6, 4096),nn.ReLU(inplace=True),nn.Dropout(),nn.Linear(4096, 4096),nn.ReLU(inplace=True),nn.Linear(4096, num_classes),)def forward(self, x):x = self.feature(x)x = x.view(x.size(0), 256 * 6 * 6) # reshapex = self.classifer(x)return x
AlexNet网络是通过AlexNet这个类实例化的。首先还是继承PyTorch中网络的基类:torch.nn.Module,其次主要的是重写初始化__init__和forward方法。在初始化__init__中主要是定义一些层的参数。forward方法中主要是定义数据在层之间的流动顺序,也就是层的连接顺序。基本上就是五层卷积加上三层全连接(不算relu和max max pooling)。注意到ReLU的inplace为True,将会改变输入的数据 ,否则不会改变原输入,只会产生新的输出。而 x = x.view(x.size(0), 256 * 6 * 6) 的意思是reshape卷积层得到的结果,为了匹配后面的全连接层。
具体结构可以参照下图:
最后呈现上源码
import torch.nn as nn
import torch.utils.model_zoo as model_zoo__all__ = ['Alexnet', 'alexnet']model_urls = {'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth',
}class AlexNet(nn.Module):def __init__(self, num_classes=1000):super(AlexNet, self).__init__()self.feature = nn.Sequential(nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),nn.ReLU(inplace=True), # inplace为True,将会改变输入的数据 ,否则不会改变原输入,只会产生新的输出nn.MaxPool2d(kernel_size=3, stride=2),nn.Conv2d(64, 192, kernel_size=5, padding=2),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),nn.Conv2d(192, 384, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(384, 256, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(256, 256, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),)self.classifer = nn.Sequential(nn.Dropout(),nn.Linear(256 * 6 * 6, 4096),nn.ReLU(inplace=True),nn.Dropout(),nn.Linear(4096, 4096),nn.ReLU(inplace=True),nn.Linear(4096, num_classes),)def forward(self, x):x = self.feature(x)x = x.view(x.size(0), 256 * 6 * 6) # reshapex = self.classifer(x)return xdef alexnet(pretrained = False,**kwargs):r"""AlexNet model architecture from the"One werid trick..."<https://arxiv.org/abs/1404.5997>_papper.Args:pretrained(bool):if True,returns a model pre-trained on ImagetNet"""model = AlexNet(**kwargs)if pretrained:model.load_state_dict(model_zoo.load_url(model_urls['alexnet']))return modelif __name__ == '__main__':alexnet()
【pytorch torchvision源码解读系列—1】Alexnet相关推荐
- 【注意力机制集锦】Channel Attention通道注意力网络结构、源码解读系列一
Channel Attention网络结构.源码解读系列一 SE-Net.SK-Net与CBAM 1 SENet 原文链接:SENet原文 源码链接:SENet源码 Squeeze-and-Excit ...
- Alamofire源码解读系列(九)之响应封装(Response)
本篇主要带来Alamofire中Response的解读 前言 在每篇文章的前言部分,我都会把我认为的本篇最重要的内容提前讲一下.我更想同大家分享这些顶级框架在设计和编码层次究竟有哪些过人的地方?当然, ...
- Alamofire源码解读系列(五)之结果封装(Result)
本篇讲解Result的封装 前言 有时候,我们会根据现实中的事物来对程序中的某个业务关系进行抽象,这句话很难理解.在Alamofire中,使用Response来描述请求后的结果.我们都知道Alamof ...
- py-faster-rcnn源码解读系列
转载自: py-faster-rcnn源码解读系列(一)--train_faster_rcnn_alt_opt.py - sunyiyou9的博客 - 博客频道 - CSDN.NET http://b ...
- Hadoop源码解读系列目录
Hadoop源码解读系列 1.hadoop源码|common模块-configuration详解 2.hadoop源码|core模块-序列化与压缩详解 3.hadoop源码|core模块-远程调用与N ...
- Alamofire源码解读系列(十二)之请求(Request)
本篇是Alamofire中的请求抽象层的讲解 前言 在Alamofire中,围绕着Request,设计了很多额外的特性,这也恰恰表明,Request是所有请求的基础部分和发起点.这无疑给我们一个Req ...
- faster rcnn fpn_Faster-RCNN详解和torchvision源码解读(三):特征提取
我们使用ResNet-50-FPN提取特征 model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True) ...
- Alamofire源码解读系列(七)之网络监控(NetworkReachabilityManager)
Alamofire源码解读系列(七)之网络监控(NetworkReachabilityManager) 本篇主要讲解iOS开发中的网络监控 前言 在开发中,有时候我们需要获取这些信息: 手机是否联网 ...
- spring源码解读系列(八):观察者模式--spring监听器详解
一.前言 在前面的文章spring源码解读系列(七)中,我们继续剖析了spring的核心refresh()方法中的registerBeanPostProcessors(beanFactory)(完成B ...
最新文章
- Java项目:诚途旅游系统(java+JSP+Spring+SSM+Mysql)
- html json 访问工程,SpringBoot:Web项目中如何优雅的同时处理Json和Html请求的异常...
- OpenCv Java Mat的基本使用-行列式计算(6)
- Linux学习笔记重新梳理20180702 之 yum软件包管理器
- 使用LINQ更新集合中的所有对象
- opencv 图像 抠图 算法_人工智能 | 不用绿幕也能实时抠图,商汤等提出只需单张图像、单个模型的新方法MODNet...
- 使用Github(基本概念实战操作)
- 新松机器人刻蚀机_中国最大机器人产业基地新松智慧园在沈阳启用
- 计算机基础远程教育答案,浙大远程教育2013年计算机作业答案-1-计算机基础知识题.docx...
- NYOJ259 - 茵茵的第一课
- 系统学习机器学习之神经网络(五) --ART
- ecshop快速购买
- autofs rhel7
- pipe()函数详解
- C语言:L1-078 吉老师的回归 (15 分)
- 字符编码常识及问题解析
- UI设计之什么是设计
- 为什么安装好mysql打不开_MySQL安装完成之后怎么启动? mysql安装完成后怎么
- web端用canvas把航拍图片实际场景渲染在高德卫星地图上面
- weblogic打补丁详细流程