手写字体识别MNIST

1.准备工作

可以看这个老师的视频进行学习,讲解的非常仔细:视频学习

2.项目代码

2.1 导入模块

# 1.加载相关库
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets,transforms

2.2 定义超参数
每批处理数据为16,训练轮数为10

# 2.定义超参数
BATCH_SIZE=16  #每批处理的数据
DEVICE=torch.device("cuda" if torch.cuda.is_available() else "cpu")  #是否用GPU还是CPU训练
EPOCHS=10  #训练数据集的轮次

2.3 构建pipeline,对图像做处理

# 3.构建pipeline,对图像做处理
pipeline=transforms.Compose([transforms.ToTensor(),#将图片转换成tensortransforms.Normalize((0.1307,),(0.3081))  #正则化:降低模型复杂度
])

2.4 下载、加载数据集(需要联网)
如图所示:

# 4.下载、加载数据集
from torch.utils.data import DataLoader# 下载数据集
train_set=datasets.MNIST("data",train=True,download=True,transform=pipeline)test_set=datasets.MNIST("data",train=False,download=True,transform=pipeline)# 加载数据
train_loader=DataLoader(train_set,batch_size=BATCH_SIZE,shuffle=True)test_loader=DataLoader(test_set,batch_size=BATCH_SIZE,shuffle=True)

2.5 构建网络模型(两层)

# 5.构建网络模型
class Digit(nn.Module):def __init__(self):super().__init__()self.conv1=nn.Conv2d(1,10,5)  #1:灰度图片的通道 10:输出通道  5:kernelself.conv2=nn.Conv2d(10,20,3)  #10:输入通道  20:输出通道  3:kernelself.fc1=nn.Linear(20*10*10,500)   #20*10*10:输入通道 500:输出通道self.fc2=nn.Linear(500,10)  #输入通道 10:输出通道def forward(self,x):input_size=x.size(0)  #atch_size*1*28*28 只取batch_sizex=self.conv1(x)  #输入:batch*1*28*28  输出:batch*10*24*24 (28-5+1=24)x=F.relu(x)  #保持shape不变,输出:batch*10*24*24x=F.max_pool2d(x,2,2)  #输入:batch*10*24*24  输出:batch*10*12*12x=self.conv2(x)  #输入:batch*10*12*12 输出:batch*20*10*10 (12-3+1=10)x=F.relu(x)x=x.view(input_size,-1)  #拉平,-1 自动计算维度  20*10*10=2000x=self.fc1(x)  #输入:batch*2000  输出:batch*500x=F.relu(x)  #保持shape不变x=self.fc2(x)  #输入:batch*500 输出:batch*10output=F.log_softmax(x,dim=1)  #计算分类后,每个数字的概率值,那个数字的概率最大那么就输出数字几return  output

2.6 定义优化器,训练方法和测试方法

# 6.定义优化器
model=Digit().to(DEVICE)optimizer=optim.Adam(model.parameters())# 7.定义训练方法
def train_model(model,device,train_loader,optimizer,epoch):model.train()  #模型训练for batch_index,(data,target) in enumerate(train_loader):data,target=data.to(device),target.to(device)  #部署到DEVICE上去optimizer.zero_grad()  #梯度初始化为0output=model(data) #训练后的结果loss=F.cross_entropy(output,target)  #计算损失
#         pred=output.max(1,keepdim=True)  #pred=output.argmax(dim=1)  #找到概率值最大的下标loss.backward()  #反向传播optimizer.step()  #参数优化if batch_index%3000==0:print("Train Epoch:{} \t Loss:{:.6f}".format(epoch,loss.item()))# 8.定义测试方法
def test_model(model,device,test_loader):model.eval()  #模型验证correct=0.0  #正确率test_loss=0.0  #测试损失with torch.no_grad():  #不会计算梯度,也不会进行反向传播for data,target in test_loader:data,target=data.to(device),target.to(device)  #部署到DEVICE上output=model(data) #测试数据test_loss+=F.cross_entropy(output,target).item()  #计算测试损失pred=output.max(1,keepdim=True)[1]  #[0]值  [1]索引  找到概率值最大的下标
#           pred=torch.max(output,dimm=1)
#           pred=output.argmax(dim=1)correct+=pred.eq(target.view_as(pred)).sum().item()  #累计正确的值test_loss/=len(test_loader.dataset)print("Test -- Average loss:{:.4f},Accuracy:{:.3f}\n".format(test_loss,100.0*correct/len(test_loader.dataset)))

2.7 调用以上方法

# 9.调用方法
for epoch in range(1,EPOCHS+1):train_model(model,DEVICE,train_loader,optimizer,epoch)test_model(model,DEVICE,test_loader)

3 运行结果

从结果中我们看到,在Epoch 8中准确率达到了99.090,说明这个模型不错。当然,我们也可以再改进一下网络模型,还请多多指教,一起学习哈!!!

PyTorch手写字体识别MNIST相关推荐

  1. 手写字体识别 --MNIST数据集

    Matlab 手写字体识别 忙过这段时间后,对于上次读取的Matlab内部数据实现的识别,我回味了一番,觉得那个实在太小.所以打算把数据换成[MNIST数据集][1]. 基础思想还是相同的,使用Tre ...

  2. pytorch实现手写字体识别(Mnist数据集)

    1.加载数据集 一个快速体验学习的小tip在google的云jupyter上做实验,速度快的飞起. import torch from torch.nn import Linear, ReLU imp ...

  3. pytorch CNN手写字体识别

    ## """CNN手写字体识别"""import torch import torch.nn as nn from torch.autogr ...

  4. pytorch rnn 实现手写字体识别

    pytorch rnn 实现手写字体识别 构建 RNN 代码 加载数据 使用RNN 训练 和测试数据 构建 RNN 代码 import torch import torch.nn as nn from ...

  5. 使用Pytorch实现手写数字识别(Mnist数据集)

    目标 知道如何使用Pytorch完成神经网络的构建 知道Pytorch中激活函数的使用方法 知道Pytorch中torchvision.transforms中常见图形处理函数的使用 知道如何训练模型和 ...

  6. 人工智能入门第一课:手写字体识别及可视化项目(手写画板)(mnist)

    人工智能入门第一课:手写字体识别及可视化项目(手写画板)(mnist),使用技术(Django+js+tensorflow+html+bootstrap+inspinia框架) 直接上图,项目效果 1 ...

  7. MNIST手写字体识别入门编译过程遇到的问题及解决

    MNIST手写字体识别入门编译过程遇到的问题及解决 以MNIST手写字体识别作为神经网络及各种网络模型的作为练手,将遇到的问题在这里记录与交流. 激活tensorflow环境后,运行spyder或者j ...

  8. 【PyTorch学习笔记_04】--- PyTorch(开始动手操作_案例1:手写字体识别)

    手写字体识别的流程 定义超参数(自己定义的参数) 构建transforms, 主要是对图像做变换 下载,加载数据集MNIST 构建网络模型(重要,自己定义) 定义训练方法 定义测试方法 开始训练模型, ...

  9. matlab文字bp识别,MNIST手写字体识别(CNN+BP两种实现)-Matlab程序

    [实例简介] MNIST手写字 Matlab程序,包含BP和CNN程序.不依赖任何库,包含MNIST数据,BP网络可达到98.3%的识别率,CNN可达到99%的识别率.CNN比较耗时,关于CNN的程序 ...

最新文章

  1. 怎么用计算机实现矩阵摹乘法,基于距离矩阵摹乘法的生鲜产品配送路径优化
  2. 树莓派wiringPi常用的函数介绍
  3. tensorflow调用问题解决
  4. java jsonobject_Java实现QQ登录
  5. Python实现图像直方图均衡化算法
  6. 雷军在线求饶:小米5G手机价格厚道,求别骂、求好评、求别带节奏
  7. linux下qt制作日历,基于QT的多功能日历设计与开发.doc
  8. 蓝桥杯---特别数的和(C语言)
  9. 安卓3.0之后的网络访问问题
  10. 将linux文件拷贝到windows,Windows与Linux系统拷贝文件之pscp的使用分享
  11. iOS后台如何保持socket长连接和数据传输
  12. easydarwin 安装_EasyDarwin流媒体服务器
  13. 深度学习与自然语言处理 | 斯坦福CS224n · 课程带学与全套笔记解读(NLP通关指南·完结)
  14. 初学ansys:模态分析及谐响应分析
  15. JavaScript高级—正则表达式(正则表达式在 JavaScript 中的使用、正则表达式中的特殊字符、正则表达式中的替换)
  16. C语言之简单英语词典实现
  17. windows家庭中文版升级至专业版
  18. angular实现国密算法sm2、sm3和sm4的ts版,基于sm-crypto库实现,前后端实现
  19. 请教一下水卡校验算法
  20. C# 可为 null 的类型

热门文章

  1. 很棒的Mobile Image Gallery Web App
  2. word页眉横线去除方法
  3. Centos系统常用软件
  4. vis---network网状拓扑图展示
  5. 苹果将要推出苹果电视
  6. 暴力破解移动硬盘密码
  7. perl和python各自擅长什么领域?
  8. 终于把鸿蒙说明白了,关于安卓系统、AOSP(安卓开源项目)和鸿蒙系统比较
  9. 拼音字母缩写在线翻译源码
  10. 2019年网页设计趋势前瞻,先睹为快!