学习笔记,仅供参考,有错必纠


文章目录

  • 理论
    • 卷积神经网络CNN
      • 局部感受野和权值共享
      • 卷积计算
      • 池化Pooling
      • Padding
    • LeNET-5
  • 代码
    • 初始设置
    • 导包
    • 载入数据
    • 模型

理论

卷积神经网络CNN

卷积神经网络是近年发展起来,并广泛应用于图像处理,NLP等领域的一种多层神经网络。

局部感受野和权值共享

CNN通过局部感受野和权值共享减少了神经网络需要训练的参数个数,从而解决了传统BP权值太多,计算量太大,需要大量样本进行训练的问题.

卷积计算

卷积核也叫滤波器,不同的卷积核 对 同样的图片做卷积之后会提取出不同的信息. 以下图的卷积核为例,我们可以对示例Image进行卷积操作.


需要注意的是,卷积核里的参数不是人为设定的,而是算法优化得到的.

池化Pooling

Pooling常用的三种方式:

  • max-pooling
  • mean-pooling
  • stochastic pooling

Padding

  • SAME PADDING

给平面外部补0,卷积窗口采样后可能会得到一个跟原来大小相同的平面.

  • VALID PADDING

不会超出平面外部,卷积窗口采样后得到一个比原来平面小的平面。

LeNET-5

LeNET-5是最早的卷积神经网络之一. 下图为LeNET-5的网络结构.

我们可以看到通过对第3层进行卷积后,第4层得到了16幅图. 那么第4层的16幅图是如何计算的呢,操作如下图所示.

代码

初始设置

# 支持多行输出
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = 'all' #默认为'last'

导包

# 导入常用的包
import numpy as np
from torch import nn,optim
from torch.autograd import Variable
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch

载入数据

# 载入数据
train_dataset = datasets.MNIST(root = './data/', # 载入的数据存放的位置train = True, # 载入训练集数据transform = transforms.ToTensor(), # 将载入进来的数据变成Tensordownload = True) # 是否下载数据
test_dataset = datasets.MNIST(root = './data/', # 载入的数据存放的位置train = False, # 载入测试集数据transform = transforms.ToTensor(), # 将载入进来的数据变成Tensordownload = True) # 是否下载数据
# 批次大小
batch_size = 64# 装载训练集
train_loader = DataLoader(dataset=train_dataset,batch_size=batch_size,shuffle=True)# 装载训练集
test_loader = DataLoader(dataset=test_dataset,batch_size=batch_size,shuffle=True)

模型

这里我们使用具有多层网络结构的模型,并加入Dropout操作.

# 定义网络结构
class Net(nn.Module):def __init__(self):super(Net, self).__init__()# 定义卷积层和池化# in_channels:int, 因为是黑白图片,所以输入通道设置为1,如果为彩色图像则这里为3# out_channels:int, 这里的输出通道数也为生成的特征图的数量,这里我们设置为32# kernel_size:int, 卷积核大小,我们设置为5# stride=1, 步长我们设置为1# padding=0, 我们设置padding为2,也就是在图片的外围补2圈0,这里我们要按照自己的需求自己计算# 如果想要卷积后的大小和原始图像大小相同,则卷积核大小为3*3则填充1圈0,5*5填充2圈,7*7填充3圈.# 因为卷积不是非线性操作,所以我们在卷积后增加非线性激活函数nn.ReLU()# 在卷积后,我们增加一个2*2的池化操作self.conv1 = nn.Sequential(nn.Conv2d(1, 32, 5, 1, 2), nn.ReLU(), nn.MaxPool2d(2, 2))# 再定义一个卷积和池化self.conv2 = nn.Sequential(nn.Conv2d(32, 64, 5, 1, 2), nn.ReLU(), nn.MaxPool2d(2, 2))# 全连接# 全连接的输入为64个大小为(7*7)的特征图# 输出为1000self.fc1 = nn.Sequential(nn.Linear(64*7*7, 1000), nn.Dropout(p = 0.4), nn.ReLU())# 全连接self.fc2 = nn.Sequential(nn.Linear(1000, 10),nn.Softmax(dim = 1))def forward(self, x):# ([64, 1, 28, 28])# 卷积要求的数据格式就是4维的([图片数量, 图片通道数, 图片维度1, 图片维度2])x = self.conv1(x)x = self.conv2(x)# 进入全连接层时,需要reshape# ([64, 64, 7, 7]) -> ([64, 64*7*7])x = x.view(x.size()[0], -1)x = self.fc1(x)x = self.fc2(x)return x
LR = 0.0003
# 定义模型
model = Net()
# 定义代价函数为交叉熵代价函数
mse_loss = nn.CrossEntropyLoss()
# 定义优化器Adam
optimizer = optim.Adam(model.parameters(), LR)

在自定义训练和测试函数中,我们分别增加两个方法,model.train()model.eval() ,这model.train()方法可以使训练集中的Dropout在训练模型时发挥作用,而model.eval()则可以使模型在测试过程中不工作.

def train():model.train()for i,data in enumerate(train_loader):# 获得一个批次的数据和标签inputs, labels = data# 获得模型预测结果(64,10)out = model(inputs)# 计算loss,交叉熵代价函数out(batch,C), labels(batch)loss = mse_loss(out, labels)# 梯度清0optimizer.zero_grad()# 计算梯度loss.backward()# 修改权值optimizer.step()def test():model.eval()# 计算训练集准确率correct = 0for i,data in enumerate(train_loader):# 获得一个批次的数据和标签inputs, labels = data# 获得模型预测结果(64,10)out = model(inputs)# 获得最大值,以及最大值所在的位置_, predicted = torch.max(out, 1)# 预测正确的数量correct += (predicted == labels).sum()print("Train acc:{0}".format(correct.item()/len(train_dataset)))# 计算测试集准确率correct = 0for i,data in enumerate(test_loader):# 获得一个批次的数据和标签inputs, labels = data# 获得模型预测结果(64,10)out = model(inputs)# 获得最大值,以及最大值所在的位置_, predicted = torch.max(out, 1)# 预测正确的数量correct += (predicted == labels).sum()print("Test acc:{0}".format(correct.item()/len(test_dataset)))
for epoch in range(5):print('epoch:',epoch)train()test()
epoch: 0
Train acc:0.9728166666666667
Test acc:0.9755
epoch: 1
Train acc:0.9827666666666667
Test acc:0.983
epoch: 2
Train acc:0.9863
Test acc:0.9863
epoch: 3
Train acc:0.98665
Test acc:0.9842
epoch: 4
Train acc:0.99075
Test acc:0.9896

PyTorch基础(part7)--CNN相关推荐

  1. PyTorch学习笔记(四):PyTorch基础实战

    PyTorch实战:以FashionMNIST时装分类为例: 往期学习资料推荐: 1.Pytorch实战笔记_GoAI的博客-CSDN博客 2.Pytorch入门教程_GoAI的博客-CSDN博客 本 ...

  2. 基于pytorch使用实现CNN 如何使用pytorch构建CNN卷积神经网络

    基于pytorch使用实现CNN 如何使用pytorch构建CNN卷积神经网络 所用工具 文件结构: 数据: 代码: 结果: 改进思路 拓展 本文是一个基于pytorch使用CNN在生物信息学上进行位 ...

  3. 深入浅出Pytorch:02 PyTorch基础知识

    深入浅出Pytorch 02 PyTorch基础知识 内容属性:深度学习(实践)专题 航路开辟者:李嘉骐.牛志康.刘洋.陈安东 领航员:叶志雄 航海士:李嘉骐.牛志康.刘洋.陈安东 开源内容:http ...

  4. 第02章 PyTorch基础知识

    文章目录 第02章 Pytorch基础知识 2.1 张量 2.2 自动求导 2.3 并行计算简介 2.3.1 为什么要做并行计算 2.3.2 CUDA是个啥 2.3.3 做并行的方法 补充:通过股票数 ...

  5. 深度学习之Pytorch基础教程!

    ↑↑↑关注后"星标"Datawhale 每日干货 & 每月组队学习,不错过 Datawhale干货 作者:李祖贤,Datawhale高校群成员,深圳大学 随着深度学习的发展 ...

  6. python cnn_使用python中pytorch库实现cnn对mnist的识别

    使用python中pytorch库实现cnn对mnist的识别 1 环境:Anaconda3 64bit https://www.anaconda.com/download/ 2 环境:pycharm ...

  7. 【深度学习】基础知识--CNN:图像分类(上)

    作者信息: 华校专,曾任阿里巴巴资深算法工程师.智易科技首席算法研究员,现任腾讯高级研究员,<Python 大战机器学习>的作者. 编者按: 算法工程师必备系列更新啦!继上次推出了算法工程 ...

  8. 【深度学习】深度学习之Pytorch基础教程!

    作者:李祖贤,Datawhale高校群成员,深圳大学 随着深度学习的发展,深度学习框架开始大量的出现.尤其是近两年,Google.Facebook.Microsoft等巨头都围绕深度学习重点投资了一系 ...

  9. PyTorch基础(part5)--交叉熵

    学习笔记,仅供参考,有错必纠 文章目录 原理 代码 初始设置 导包 载入数据 模型 原理 交叉熵(Cross-Entropy) Loss=−(t∗ln⁡y+(1−t)ln⁡(1−y))Loss =-( ...

最新文章

  1. 数据驱动的云托管服务最佳范式
  2. FPGA之道(25)VHDL数据类型转换函数与数据对象的属性
  3. oracle.jobs中failures,Oracle job详解
  4. 将excel的数据导入到mysql数据表
  5. 云炬Android开发笔记 使用新版本Android studio快速Build低版本项目的仓库代码(标红部分)
  6. Ubuntu瘦身与扩容运动
  7. 原则 principles
  8. npm ERR! code E404 npm ERR! 404 Not Found - GET https://registry.npmjs.com/@mlamp%2fuser-info-dropdo
  9. python 编程服务_Python编写Windows Service服务程序
  10. VC++多线程工作笔记0007---线程间同步机制2
  11. 解决首次在eclipse中使用maven构建hadoop等项目时报Missing artifact sun.jdk:tools:jar:1.5.0的问题...
  12. 全向移动机器人参数校准对比及流程分析
  13. 解决servlet中post请求和get请求中文乱码现象
  14. WebGIS开发快速入门
  15. PDF如何导出成图片,操作教程
  16. javascript判断文本语言类型
  17. Let‘s Go Rust 系列之定时器 Ticker Timer
  18. 蓄水池采样算法的python实现_常用算法-蓄水池抽样算法
  19. CMD查看局域网在线IP
  20. 怎么把英文翻译成中文?手机中英翻译的简单方法

热门文章

  1. Wine cannot find the ncurses library (libncurses.so.5)
  2. gensim出现segmentation Fault解决方案
  3. IntelliJ IDEA内存优化最佳实践(转)
  4. debian下面的apt-fast安装
  5. php从大到小排列数字,php输入几个数从大到小排序
  6. Android学习资源网站
  7. 【洛谷4001】 [ICPC-Beijing 2006]狼抓兔子(最小割)
  8. 请按正确方法给UPS电源充电
  9. Win7+Ubuntu双系统结构下,Ubuntu克隆至新硬盘,启动成功
  10. devexpress PivotGrid Grand Total