本节将介绍利用CNN进行手写体识别
首先呢,我们需要下载数据来进行训练。下载的代码如下:
注意:数据集下载一次就好,

DOWNLOAD_MNIST = Truetrain_data=torchvision.datasets.MNIST(#下载数据的代码root='./mnist',train=True,transform=torchvision.transforms.ToTensor(),  #(网上数据改为tensor),0-1之间,并复制到train_data中download=DOWNLOAD_MNIST#没有下载就=true,下载了就用false)

然后我们简单介绍一下卷及神经网络CNN的结构:

整体流程是:卷积(Conv2d) -> 激励函数(ReLU) -> 池化, 向下采样 (MaxPooling) -> 再卷积(Conv2d) -> 再激励函数(ReLU) -> 再池化, 向下采样 (MaxPooling) -> 展平多维的卷积成的特征图 -> 接入全连接层 (Linear) -> 输出。
全部代码如下:(给出了大部分的详细注释)

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.utils.data as Data
import torchvision
import matplotlib.pyplot as pltEPOCH=1  #训练整批数据的次数
BATCH_SIZE = 50#批训练的数据个数
LR = 0.001     # 学习率
DOWNLOAD_MNIST = False  train_data=torchvision.datasets.MNIST(#下载数据的代码root='./mnist',train=True,transform=torchvision.transforms.ToTensor(),  #(网上数据改为tensor),0-1之间,并复制到train_data中download=DOWNLOAD_MNIST#没有下载就=true,下载了就用false)#plt测试一下下载的照片
# print(train_data.train_data.size())
# print(train_data.train_labels.size())
# plt.imshow(train_data.train_data[0].numpy(),cmap='gray')#呈现出第一张图片
# plt.title('%i'%train_data.train_labels[0])
# plt.show()#批训练(50,1,28*28)
train_loader=Data.DataLoader(dataset=train_data,batch_size=BATCH_SIZE,shuffle=True)#测试集,train说明不是traindata,是testdata
test_data=torchvision.datasets.MNIST(root='./mnist/',train=False)# 为了节约时间, 我们测试时只测试前2000个
test_x = Variable(torch.unsqueeze(test_data.test_data, dim=1),volatile=True).type(torch.FloatTensor)[:2000]/255.
test_y = test_data.test_labels[:2000]#建立CNN网络
class CNN(nn.Module):def __init__(self):super(CNN,self).__init__()self.conv1=nn.Sequential(#卷积层1,包括以下三个内容nn.Conv2d(#卷基层,过滤器。in_channels=1,#图片的高度、层数。(因为输入的是二维图片,所以高度是1)out_channels=16,#16个filter的个数,同时进行扫描。输出的高度kernel_size=5,#kernel的宽、高都是5,5*5的扫描区域stride=1,#步长padding=2,#像素旁边一圈加上0的数据#如果 stride=1,padding=(kernel_size-1)/2=(5-1)/2=2),nn.ReLU(),#激活函数,加了一层卷积层nn.MaxPool2d(kernel_size=2),#池化层,筛选重要的部分)self.conv2=nn.Sequential(#卷积层2nn.Conv2d(16,32,5,1,2),nn.ReLU(),#激活函数,加了一层卷积层nn.MaxPool2d(2),#池化层,筛选重要的部分)self.out=nn.Linear(32*7*7,10)#def forward(self,x):#展平的过程x=self.conv1(x)x=self.conv2(x)            #(batch,32,7,7)x=x.view(x.size(0),-1)     #(batch,32*7*7)output=self.out(x)return outputcnn=CNN()optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)   # optimize all cnn parameters
loss_func = nn.CrossEntropyLoss() #训练过程
for epoch in range(EPOCH):for step, (x, y) in enumerate(train_loader):   # 分配 batch data, normalize x when iterate train_loaderb_x=Variable(x)b_y=Variable(y)output = cnn(b_x)               # cnn outputloss = loss_func(output, b_y)   # cross entropy lossoptimizer.zero_grad()           # clear gradients for this training steploss.backward()                 # backpropagation, compute gradientsoptimizer.step()                # apply gradientsif step%50==0:test_output=cnn(test_x)pred_y=torch.max(test_output,1)[1].data.squeeze()accuracy=sum(pred_y==test_y)/test_y.size(0)#最后取10组数据检查一下预测值到底对不对
test_output = cnn(test_x[:10])
pred_y = torch.max(test_output, 1)[1].data.numpy().squeeze()
print(pred_y, 'prediction number')
print(test_y[:10].numpy(), 'real number')

PyTorch学习(八)CNN手写体识别相关推荐

  1. 机器学习|卷积神经网络(CNN) 手写体识别 (MNIST)入门

    人工智能,机器学习,监督学习,神经网络,无论哪一个都是非常大的话题,都覆盖到可能就成一本书了,所以这篇文档只会包含在 RT-Thread 物联网操作系统,上面加载 MNIST 手写体识别模型相关的部分 ...

  2. 【深度学习框架】|PyTorch|完成一个手写体识别任务

  3. [Python人工智能] 六.TensorFlow实现分类学习及MNIST手写体识别案例

    从本专栏开始,作者正式开始研究Python深度学习.神经网络及人工智能相关知识.前一篇文章讲解了Tensorboard可视化的基本用法,并绘制整个神经网络及训练.学习的参数变化情况:本篇文章将通过Te ...

  4. python-机器学习-手写数字识别

    机器学习简单的来说,分为监督式学习和无监督式学习: 对于监督式学习就是需要人为的来告诉计算机这是什么,需要我们给他一个标签(答案). 无监督式学习就是不需要我们给出标签(答案). 图像识别(Image ...

  5. tensorflow学习笔记(七):CNN手写体(MNIST)识别

    文章目录 一.CNN简介 二.主要函数 三.CNN的手写体识别 1.MNIST数据集简介 2.网络描述 3.项目实战 一.CNN简介 一般的卷积神经网络由以下几个层组成:卷积层,池化层,非线性激活函数 ...

  6. 基于CNn的MINIST手写体识别

    深度学习的上机作业: 基于CNN卷积神经网络的MINIST手写体识别 版本:python-3.9,tensorflow-2.9 目录 MINIST数据集 训练CNN卷积神经网络 使用训练好的模型进行预 ...

  7. [深度学习-实践]BP神经网络的Helloworld(手写体识别和Fashion_mnist)

    前言 原理部分请看这里 [深度学习-原理]BP神经网络 Tensorflow2 实现一个简单的识别衣服的例子 数据集Fashion_mnist, 此数据集包含10类型的衣服 ('T-shirt/top ...

  8. python神经网络案例——CNN卷积神经网络实现mnist手写体识别

    分享一个朋友的人工智能教程.零基础!通俗易懂!风趣幽默!还带黄段子!大家可以看看是否对自己有帮助:点击打开 全栈工程师开发手册 (作者:栾鹏) python教程全解 CNN卷积神经网络的理论教程参考 ...

  9. 【从线性回归到 卷积神经网络CNN 循环神经网络RNN Pytorch 学习笔记 目录整合 源码解读 B站刘二大人 绪论(0/10)】

    深度学习 Pytorch 学习笔记 目录整合 数学推导与源码详解 B站刘二大人 目录传送门: 线性模型 Linear-Model 数学原理分析以及源码详解 深度学习 Pytorch笔记 B站刘二大人( ...

最新文章

  1. C语言结构体例子 (一)
  2. java在线学习系统源码_Java在线考试系统源码
  3. MongoDB Wiredtiger存储引擎实现原理
  4. java 运行时路径_如何在运行时检查当前Java类路径(重复)
  5. 正方形与圆的爱恨纠缠...
  6. hibernate脏数据_Hibernate性能提示:脏收集效果
  7. Eclipse: select at least one project
  8. IOS 创建简单表视图
  9. 如何动态在maven插件中加载项目及第三方类
  10. 【ElasticSearch】ElasticSearch 7.x 默认不在支持指定索引类型 Failed to parse mapping [_doc]: Root mapping definitio
  11. 通过源码理解反射与注解是什么东西?
  12. css清除浮动的几种方法_清除浮动的几种方法
  13. cwrsync从linux同步文件数据到windows
  14. unity资源商店出现“抱歉,此链接不再有效”怎么办
  15. 逆向分析工具IDA与开源工具Ghidra、Cutter对比测评
  16. managed DLL 和 normal DLL
  17. 主语从句、宾语从句、表语从句、同位语从句
  18. 【2019春招】平安科技开发实习生面经
  19. pe修复linux驱动,【CTF习题】BrokenDrivers(驱动修复及内核调试)
  20. 在线网页版鸡乐盒html源码

热门文章

  1. 如何将功能测试用例转为自动化脚本?
  2. 哈希树 (HashTree)
  3. python 折线图 百分比_Excel柱状图折线图组合怎么做 Excel百分比趋势图制作教程...
  4. JDK1.0到12各版本新特性
  5. footer.php置底,详解CSS五种方式实现Footer置底
  6. 第四章 HTML5 新增与修改的标签 <header> <footer> <section>
  7. Android案例(sd卡存储)
  8. 《我的眼睛--图灵识别》第三章:基础:颜色识别
  9. 不同VLAN之间互相通信
  10. 计算机应用基础试题省开8207,江苏省2015年“专转本”计算机应用基础统一考试试题.doc...