import torch
import torch.nn as nn
import torchvision.datasets as normal_datasets
import torchvision.transforms as transforms
from torch.autograd import Variablenum_epochs = 1
batch_size = 100
learning_rate = 0.001# 将数据处理成Variable, 如果有GPU, 可以转成cuda形式
def get_variable(x):x = Variable(x)return x.cuda() if torch.cuda.is_available() else x# 从torchvision.datasets中加载一些常用数据集
train_dataset = normal_datasets.MNIST(root='./mnist/',                 # 数据集保存路径train=True,                      # 是否作为训练集transform=transforms.ToTensor(), # 数据如何处理, 可以自己自定义download=True)                   # 路径下没有的话, 可以下载# 见数据加载器和batch
test_dataset = normal_datasets.MNIST(root='./mnist/',train=False,transform=transforms.ToTensor())train_loader = torch.utils.data.DataLoader(dataset=train_dataset,batch_size=batch_size,shuffle=True)test_loader = torch.utils.data.DataLoader(dataset=test_dataset,batch_size=batch_size,shuffle=False)# 两层卷积
class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()# 使用序列工具快速构建self.conv1 = nn.Sequential(nn.Conv2d(1, 16, kernel_size=5, padding=2),nn.BatchNorm2d(16),nn.ReLU(),nn.MaxPool2d(2))self.conv2 = nn.Sequential(nn.Conv2d(16, 32, kernel_size=5, padding=2),nn.BatchNorm2d(32),nn.ReLU(),nn.MaxPool2d(2))self.fc = nn.Linear(7 * 7 * 32, 10)def forward(self, x):out = self.conv1(x)out = self.conv2(out)out = out.view(out.size(0), -1)  # reshapeout = self.fc(out)return outcnn = CNN()
if torch.cuda.is_available():cnn = cnn.cuda()# 选择损失函数和优化方法
loss_func = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(cnn.parameters(), lr=learning_rate)for epoch in range(num_epochs):for i, (images, labels) in enumerate(train_loader):images = get_variable(images)labels = get_variable(labels)outputs = cnn(images)loss = loss_func(outputs, labels)optimizer.zero_grad()loss.backward()optimizer.step()if (i + 1) % 100 == 0:print('Epoch [%d/%d], Iter [%d/%d] Loss: %.4f'% (epoch + 1, num_epochs, i + 1, len(train_dataset) // batch_size, loss.data[0]))# 测试模型
cnn.eval()  # 改成测试形态, 应用场景如: dropout
correct = 0
total = 0
for images, labels in test_loader:images = get_variable(images)labels = get_variable(labels)outputs = cnn(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels.data).sum()print(' 测试 准确率: %d %%' % (100 * correct / total))# Save the Trained Model
torch.save(cnn.state_dict(), 'cnn.pkl')

pytorch 入门(二) cnn 手写数字识别相关推荐

  1. python手写多个字母识别_一个带界面的CNN手写数字识别,使用Python(tensorflow, kivy)实现...

    CNN_Handwritten_Digit_Recognizer (CNN手写数字识别) A CNN handwritten digit recognizer with graphical UI, i ...

  2. 卷积神经网络CNN 手写数字识别

    1. 知识点准备 在了解 CNN 网络神经之前有两个概念要理解,第一是二维图像上卷积的概念,第二是 pooling 的概念. a. 卷积 关于卷积的概念和细节可以参考这里,卷积运算有两个非常重要特性, ...

  3. CNN 手写数字识别

    1. 知识点准备 在了解 CNN 网络神经之前有两个概念要理解,第一是二维图像上卷积的概念,第二是 pooling 的概念. a. 卷积 关于卷积的概念和细节可以参考这里,卷积运算有两个非常重要特性, ...

  4. 机器学习入门(1)---以手写数字识别为例

    (部分项目代码源自<python深度学习>,吴茂贵等著,机械工业出版社.代码头有标注:部分测试代码来自pytorch官方文档,代码头有标注:部分概念图来源于github,图片下方有标注) ...

  5. 【图像识别】基于卷积神经网络CNN手写数字识别matlab代码

    1 简介 针对传统手写数字的随机性,无规律性等问题,为了提高手写数字识别的检测准确性,本文在研究手写数字区域特点的基础上,提出了一种新的手写数字识别检测方法.首先,对采集的手写数字图像进行预处理,由于 ...

  6. tensorflow入门之MINIST手写数字识别

    最近在学tensorflow,看了很多资料以及相关视频,有没有大佬推荐一下比较好的教程之类的,谢谢.最后还是到了官方网站去,还好有官方文档中文版,今天就结合官方文档以及之前看的教程写一篇关于MINIS ...

  7. Pytorch CNN 手写数字识别 0-9

    使用的软件是pycharm 环境是在anaconda下创的虚拟环境pytorch 整个过程大体为,在画板手写数字,用python代码实现手写数字的批量生成,定义超参数,创建数据集包括训练集和数据集,创 ...

  8. 入门学习MNIST手写数字识别

    一.MNIST数据集 1.MNIST数据集简介 MNIST数据集是一个公开的数据集,相当于深度学习的hello world,用来检验一个模型/库/框架是否有效的一个评价指标. MNIST数据集是由0〜 ...

  9. Pytorch 学习 (一)Minst手写数字识别(含特定函数解析)

    目录 本人目前在跟随csdn博主 "K同学啊"进行365天深度学习训练营进行学习,这是打卡内容 也作为本人学习的记录. 一.准备部分 三.训练模型 四.正式训练 五.输出 MNIS ...

最新文章

  1. 树莓派python3_树莓派4没有python3怎么办
  2. Java提高篇——Java实现多重继承
  3. 管线命令 cut grep
  4. Silverlight for Windows Phone 7开发体验
  5. powerdesigner 生成数据库脚本
  6. python 控制系统音量_pygame学习笔记(4):声音控制
  7. 分布式事务在Sharding-Sphere中的实现
  8. uniapph5授权成功后返回上一页_记一次授权系统的安全测试
  9. 运用事理图谱搞事情:新闻预警、事件监测、文本可视化、出行规划与历时事件流生成
  10. CentOS操作系统(LAMP)安装教程
  11. 算法笔记:二叉树的序列化和反序列化(剑指 Offer 37)
  12. spring data jpa 查询部分字段列名无效问题
  13. 用python设计图案_用 Python 打造属于自己的GUI图形化界面
  14. 架构之美第一章-如何看到一滴水的美丽
  15. matlab 符号函数是什么意思,matlab符号函数定义
  16. 计算机广告制作专业范围,计算机广告制作专业
  17. 几种常考的面试题类型
  18. android网速代码,Android获取网速和下载速度
  19. svn服务器搭建ip指定,mac 局域网svn服务器搭建
  20. 网络安全没有“银弹”

热门文章

  1. STM32硬件错误(HardFault_Handler)位置判断
  2. 信道容量与Shannon公式
  3. Ghost配置1——删除社交Link
  4. 微软重新开源 MS-DOS 1.25/2.0:已诞生 36 年
  5. MySQL系列:innodb源代码分析之线程并发同步机制
  6. ubuntu 安装google浏览器
  7. hibernate基本映射文件
  8. 发现qq的mac输入法2.8,在终端全屏下输入不显示待选文字或单词
  9. IEEE802.11协议栈
  10. Python使用pyserial进行串口通信