PyTorch手写字体识别MNIST
手写字体识别MNIST
1.准备工作
可以看这个老师的视频进行学习,讲解的非常仔细:视频学习
2.项目代码
2.1 导入模块
# 1.加载相关库
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets,transforms
2.2 定义超参数
每批处理数据为16,训练轮数为10
# 2.定义超参数
BATCH_SIZE=16 #每批处理的数据
DEVICE=torch.device("cuda" if torch.cuda.is_available() else "cpu") #是否用GPU还是CPU训练
EPOCHS=10 #训练数据集的轮次
2.3 构建pipeline,对图像做处理
# 3.构建pipeline,对图像做处理
pipeline=transforms.Compose([transforms.ToTensor(),#将图片转换成tensortransforms.Normalize((0.1307,),(0.3081)) #正则化:降低模型复杂度
])
2.4 下载、加载数据集(需要联网)
如图所示:
# 4.下载、加载数据集
from torch.utils.data import DataLoader# 下载数据集
train_set=datasets.MNIST("data",train=True,download=True,transform=pipeline)test_set=datasets.MNIST("data",train=False,download=True,transform=pipeline)# 加载数据
train_loader=DataLoader(train_set,batch_size=BATCH_SIZE,shuffle=True)test_loader=DataLoader(test_set,batch_size=BATCH_SIZE,shuffle=True)
2.5 构建网络模型(两层)
# 5.构建网络模型
class Digit(nn.Module):def __init__(self):super().__init__()self.conv1=nn.Conv2d(1,10,5) #1:灰度图片的通道 10:输出通道 5:kernelself.conv2=nn.Conv2d(10,20,3) #10:输入通道 20:输出通道 3:kernelself.fc1=nn.Linear(20*10*10,500) #20*10*10:输入通道 500:输出通道self.fc2=nn.Linear(500,10) #输入通道 10:输出通道def forward(self,x):input_size=x.size(0) #atch_size*1*28*28 只取batch_sizex=self.conv1(x) #输入:batch*1*28*28 输出:batch*10*24*24 (28-5+1=24)x=F.relu(x) #保持shape不变,输出:batch*10*24*24x=F.max_pool2d(x,2,2) #输入:batch*10*24*24 输出:batch*10*12*12x=self.conv2(x) #输入:batch*10*12*12 输出:batch*20*10*10 (12-3+1=10)x=F.relu(x)x=x.view(input_size,-1) #拉平,-1 自动计算维度 20*10*10=2000x=self.fc1(x) #输入:batch*2000 输出:batch*500x=F.relu(x) #保持shape不变x=self.fc2(x) #输入:batch*500 输出:batch*10output=F.log_softmax(x,dim=1) #计算分类后,每个数字的概率值,那个数字的概率最大那么就输出数字几return output
2.6 定义优化器,训练方法和测试方法
# 6.定义优化器
model=Digit().to(DEVICE)optimizer=optim.Adam(model.parameters())# 7.定义训练方法
def train_model(model,device,train_loader,optimizer,epoch):model.train() #模型训练for batch_index,(data,target) in enumerate(train_loader):data,target=data.to(device),target.to(device) #部署到DEVICE上去optimizer.zero_grad() #梯度初始化为0output=model(data) #训练后的结果loss=F.cross_entropy(output,target) #计算损失
# pred=output.max(1,keepdim=True) #pred=output.argmax(dim=1) #找到概率值最大的下标loss.backward() #反向传播optimizer.step() #参数优化if batch_index%3000==0:print("Train Epoch:{} \t Loss:{:.6f}".format(epoch,loss.item()))# 8.定义测试方法
def test_model(model,device,test_loader):model.eval() #模型验证correct=0.0 #正确率test_loss=0.0 #测试损失with torch.no_grad(): #不会计算梯度,也不会进行反向传播for data,target in test_loader:data,target=data.to(device),target.to(device) #部署到DEVICE上output=model(data) #测试数据test_loss+=F.cross_entropy(output,target).item() #计算测试损失pred=output.max(1,keepdim=True)[1] #[0]值 [1]索引 找到概率值最大的下标
# pred=torch.max(output,dimm=1)
# pred=output.argmax(dim=1)correct+=pred.eq(target.view_as(pred)).sum().item() #累计正确的值test_loss/=len(test_loader.dataset)print("Test -- Average loss:{:.4f},Accuracy:{:.3f}\n".format(test_loss,100.0*correct/len(test_loader.dataset)))
2.7 调用以上方法
# 9.调用方法
for epoch in range(1,EPOCHS+1):train_model(model,DEVICE,train_loader,optimizer,epoch)test_model(model,DEVICE,test_loader)
3 运行结果
从结果中我们看到,在Epoch 8中准确率达到了99.090,说明这个模型不错。当然,我们也可以再改进一下网络模型,还请多多指教,一起学习哈!!!
PyTorch手写字体识别MNIST相关推荐
- 手写字体识别 --MNIST数据集
Matlab 手写字体识别 忙过这段时间后,对于上次读取的Matlab内部数据实现的识别,我回味了一番,觉得那个实在太小.所以打算把数据换成[MNIST数据集][1]. 基础思想还是相同的,使用Tre ...
- pytorch实现手写字体识别(Mnist数据集)
1.加载数据集 一个快速体验学习的小tip在google的云jupyter上做实验,速度快的飞起. import torch from torch.nn import Linear, ReLU imp ...
- pytorch CNN手写字体识别
## """CNN手写字体识别"""import torch import torch.nn as nn from torch.autogr ...
- pytorch rnn 实现手写字体识别
pytorch rnn 实现手写字体识别 构建 RNN 代码 加载数据 使用RNN 训练 和测试数据 构建 RNN 代码 import torch import torch.nn as nn from ...
- 使用Pytorch实现手写数字识别(Mnist数据集)
目标 知道如何使用Pytorch完成神经网络的构建 知道Pytorch中激活函数的使用方法 知道Pytorch中torchvision.transforms中常见图形处理函数的使用 知道如何训练模型和 ...
- 人工智能入门第一课:手写字体识别及可视化项目(手写画板)(mnist)
人工智能入门第一课:手写字体识别及可视化项目(手写画板)(mnist),使用技术(Django+js+tensorflow+html+bootstrap+inspinia框架) 直接上图,项目效果 1 ...
- MNIST手写字体识别入门编译过程遇到的问题及解决
MNIST手写字体识别入门编译过程遇到的问题及解决 以MNIST手写字体识别作为神经网络及各种网络模型的作为练手,将遇到的问题在这里记录与交流. 激活tensorflow环境后,运行spyder或者j ...
- 【PyTorch学习笔记_04】--- PyTorch(开始动手操作_案例1:手写字体识别)
手写字体识别的流程 定义超参数(自己定义的参数) 构建transforms, 主要是对图像做变换 下载,加载数据集MNIST 构建网络模型(重要,自己定义) 定义训练方法 定义测试方法 开始训练模型, ...
- matlab文字bp识别,MNIST手写字体识别(CNN+BP两种实现)-Matlab程序
[实例简介] MNIST手写字 Matlab程序,包含BP和CNN程序.不依赖任何库,包含MNIST数据,BP网络可达到98.3%的识别率,CNN可达到99%的识别率.CNN比较耗时,关于CNN的程序 ...
最新文章
- 怎么用计算机实现矩阵摹乘法,基于距离矩阵摹乘法的生鲜产品配送路径优化
- 树莓派wiringPi常用的函数介绍
- tensorflow调用问题解决
- java jsonobject_Java实现QQ登录
- Python实现图像直方图均衡化算法
- 雷军在线求饶:小米5G手机价格厚道,求别骂、求好评、求别带节奏
- linux下qt制作日历,基于QT的多功能日历设计与开发.doc
- 蓝桥杯---特别数的和(C语言)
- 安卓3.0之后的网络访问问题
- 将linux文件拷贝到windows,Windows与Linux系统拷贝文件之pscp的使用分享
- iOS后台如何保持socket长连接和数据传输
- easydarwin 安装_EasyDarwin流媒体服务器
- 深度学习与自然语言处理 | 斯坦福CS224n · 课程带学与全套笔记解读(NLP通关指南·完结)
- 初学ansys:模态分析及谐响应分析
- JavaScript高级—正则表达式(正则表达式在 JavaScript 中的使用、正则表达式中的特殊字符、正则表达式中的替换)
- C语言之简单英语词典实现
- windows家庭中文版升级至专业版
- angular实现国密算法sm2、sm3和sm4的ts版,基于sm-crypto库实现,前后端实现
- 请教一下水卡校验算法
- C# 可为 null 的类型