代码:

# 1 加载必要的库
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets ,transforms# 2 定义超参数
BATCH_SIZE = 64 #每批处理的数据
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") #用gpu还是cpu计算
# DEVICE = torch.device("cpu")
EPOCHS = 10 #训练数据集的轮次# 3 构建pipeline,对图像做处理
pipeline = transforms.Compose([transforms.ToTensor(), #将图片转换成tensortransforms.Normalize((0.1307,), (0.3081,)) #正则化:降低模型复杂度])# 4 下载,加载数据集
from torch.utils.data import DataLoader# 下载数据集
train_set = datasets.MNIST("data", train=True, download=True, transform=pipeline)test_set = datasets.MNIST("data", train=False, download=True, transform=pipeline)#加载数据
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=True)# 插入代码,显示MNIST中的图片
with open("./data/MNIST/raw/train-images-idx3-ubyte", "rb") as f:file = f.read()f.close()image1 = [int(str(item).encode("ascii"), 16) for item in file[16 : 16 + 784]]
print(image1)# 保存图片
import cv2 as cv
import numpy as npimage1_np = np.array(image1, dtype=np.uint8).reshape(28, 28, 1)
print(image1_np.shape)cv.imwrite("digit.jpg", image1_np)# 5 构建网络模型
class Digit(nn.Module):def __init__(self):super().__init__()self.conv1 = nn.Conv2d(1, 10, 5) # 1:灰度图片的通道, 10: 输出通道, 5:kernelself.conv2 = nn.Conv2d(10, 20, 3) #1 0:输入通道, 20:输出通道, 3:kernelself.fc1 = nn.Linear(20*10*10, 500) # 20*20*10:输入通道, 500:输出通道self.fc2 = nn.Linear(500, 10) # 500:输入通道, 10:输出通道def forward(self, x):input_size = x.size(0) # batch_sizex = self.conv1(x) # 输入:batch*1*28*28, 输出:batch*10*24*24 (25 -5 +1 = 24)x = F.relu(x) # 激活函数x = F.max_pool2d(x, 2, 2) # 输入:batch*10*24*24, 输出:batch*10*12*12x = self.conv2(x) # 输入:batch*10*12*12, 输出batch*20*10*10 (12 -3 +1 = 10)x = F.relu(x) # 激活函数x = x.view(input_size, -1) # 拉平(一列)x = self.fc1(x) # 输入:batch*2000, 输出:batch*500x = F.relu(x) #保持shape不变x = self.fc2(x) # 输入:batch*500, 输出:batch*10output = F.log_softmax(x, dim=1) # 计算分类后,每个数字的概率值return output# 6 定义优化器
model = Digit().to(DEVICE)optimizer = optim.Adam(model.parameters())   # 7 定义训练方法
def train_model(model, device, train_loader, optimizer, epoch):#模型训练model.train()for batch_index, (data, traget) in enumerate(train_loader):#部署到DEVICE上去data, traget = data.to(device), traget.to(device)# 梯度初始化为0optimizer.zero_grad()# 训练后的结果output = model(data)# 计算损失loss = F.cross_entropy(output, traget)#反向传播loss.backward()# 参数优化optimizer.step()if batch_index % 3000 == 0:print("Train Epoch : {} \t Loss : {:.6f}".format(epoch,loss.item()))                  # 8 定义测试方法
def test_model(model, device, test_loader):#模型验证model.eval()#正确率correct = 0.0#测试损失test_loss = 0.0with torch.no_grad(): # 不会计算梯度,也不会进行反向传播for data, target in test_loader:# 部署到device上data, target = data.to(device), target.to(device)# 测试数据output = model(data)# 计算测试损失test_loss += F.cross_entropy(output, target).item()# 找到概率值最大的下标pred = output.max(1, keepdim=True)[1] # 返回 (值, 索引)#pred = torch.max(output, dim=1)#pred = output.argmax(dim=1)# 累计正确的值correct += pred.eq(target.view_as(pred)).sum().item()test_loss /= len(test_loader.dataset)print("Test -- Average loss : {:.4f}, Accuracy : {:.3f}\n".format(test_loss, 100.0 * correct / len(test_loader.dataset))) # 9 gpu调用方法 7 、 8
import time
t1 = time.time()
for epoch in range(1, EPOCHS + 1):train_model(model, DEVICE, train_loader, optimizer, epoch)test_model(model, DEVICE, test_loader)
t2 = time.time()

输出结果:

计算调用GPU后训练模型所用时间:

# 计算时长
spend_time = t2 - t1 if torch.cuda.is_available() == True:print("调用GPU累计计算时长:{}".format(spend_time))
else:print("调用CPU累计计算时长:{}".format(spend_time))

结果:

PyTorch MNIST手写字体识别相关推荐

  1. MNIST手写字体识别入门编译过程遇到的问题及解决

    MNIST手写字体识别入门编译过程遇到的问题及解决 以MNIST手写字体识别作为神经网络及各种网络模型的作为练手,将遇到的问题在这里记录与交流. 激活tensorflow环境后,运行spyder或者j ...

  2. matlab文字bp识别,MNIST手写字体识别(CNN+BP两种实现)-Matlab程序

    [实例简介] MNIST手写字 Matlab程序,包含BP和CNN程序.不依赖任何库,包含MNIST数据,BP网络可达到98.3%的识别率,CNN可达到99%的识别率.CNN比较耗时,关于CNN的程序 ...

  3. pytorch CNN手写字体识别

    ## """CNN手写字体识别"""import torch import torch.nn as nn from torch.autogr ...

  4. pytorch MNIST 手写数字识别 + 使用自己的测试集 + 数据增强后再训练

    文章目录 1. MNIST 手写数字识别 2. 聚焦数据集扩充后的模型训练 3. pytorch 手写数字识别基本实现 3.1完整代码及 MNIST 测试集测试结果 3.1.1代码 3.1.2 MNI ...

  5. pytorch应用于MNIST手写字体识别

    前言 手写字体MNIST数据集是一组常见的图像,其常用于测评和比较机器学习算法的性能,本文使用pytorch框架来实现对该数据集的识别,并对结果进行逐步的优化. 一.数据集 MNIST数据集是由28x ...

  6. linux手写数字识别,OpenCV 3.0中的SVM训练 mnist 手写字体识别

    前言: SVM(支持向量机)一种训练分类器的学习方法 mnist 是一个手写字体图像数据库,训练样本有60000个,测试样本有10000个 LibSVM 一个常用的SVM框架 OpenCV3.0 中的 ...

  7. TensorFlow | 使用Tensorflow带你实现MNIST手写字体识别

    github:https://github.com/MichaelBeechan CSDN:https://blog.csdn.net/u011344545 涉及代码:https://github.c ...

  8. pytorch实现手写字体识别(Mnist数据集)

    1.加载数据集 一个快速体验学习的小tip在google的云jupyter上做实验,速度快的飞起. import torch from torch.nn import Linear, ReLU imp ...

  9. (二)Tensorflow搭建卷积神经网络实现MNIST手写字体识别及预测

    1 搭建卷积神经网络 1.0 网络结构 图1.0 卷积网络结构 1.2 网络分析 序号 网络层 描述 1 卷积层 一张原始图像(28, 28, 1),batch=1,经过卷积处理,得到图像特征(28, ...

最新文章

  1. bzoj2724: [Violet 6]蒲公英(分块)
  2. 如何在WORD中设置标题1与标题2编号样式不一样
  3. 10 过滤器和监听器
  4. android添加时间,添加加载时间记录函数
  5. Java13的API_JAVA基础--JAVA API常见对象(其他API)13
  6. Applied Functional Analysis(Applications to Mathematical Physics ) E.Zeidler
  7. collections模块的Counter类
  8. Linux 命令整理
  9. 学计算机的能看出批图吗,P图P的好,女朋友满街跑,P图先学计算机,清华和这些学校少不了...
  10. 【论文】赛尔原创 | EMNLP 2019基于知识库检索器的实体一致性端到端任务型对话系统...
  11. 板绘如何厚涂?绘画时应该怎样厚涂?教你利用SAI结合数位板进行厚涂绘画!
  12. 千方百剂创建账套服务器文件,千方百剂数据库设置教程.docx
  13. Sublime Text3配置LaTeX环境及使用Sumatra PDF作为阅读器——亲测可用
  14. C语言的程序设计流程、特点及要求
  15. 使用栈(非调用)判断该字符串是否中心对称,如 abccba 即为 中心对称 字符串
  16. itunes登录时显示服务器失败怎么办,苹果手机itunes验证失败怎么办
  17. 什么是pisa测试_PISA测试是什么?
  18. dubbo学习笔记(一)——dubbo的作用及简单应用
  19. 使用cmd命令远程重启服务器
  20. 摄像头更改“友好名称“方法

热门文章

  1. iptv原版固件_官方固件不给力?咱自己DIY!手把手教你修改固件!
  2. 跑路的互联网金融公司那么多 众筹光伏电站这件事靠谱吗?
  3. Linux学习(2)----一些操作
  4. 角度的均值与标准差(circular data/ directional statistics)
  5. 超出社保公积金免税上限,纳税方案
  6. 如何看待Java饱和难找工作的现象?
  7. layui 审核按钮 及 代码实现
  8. linux mmap /dev/zero,/dev/null 和 /dev/zero误删除重建方法
  9. Ubuntu通过fim或xdg-open在终端中打开图片
  10. IEC 60794-2-10:2023 室内光纤电缆 - 单工和双工电缆系列规范