数据集:MNIST数据集,代码中会自动下载,不用自己手动下载。数据集很小,不需要GPU设备,可以很好的体会到pytorch的魅力。
模型+训练+预测程序:

import torch
from torch import nn
from torch.nn import functional as F
from torch import optim
import torchvision
from matplotlib import pyplot as plt
from utils import plot_image, plot_curve, one_hot# step1  load dataset
batch_size = 512
train_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('mnist_data', train=True, download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,), (0.3081,))])),batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('mnist_data/', train=False, download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,), (0.3081,))])),batch_size=batch_size, shuffle=False)
x , y = next(iter(train_loader))
print(x.shape, y.shape, x.min(), x.max())
plot_image(x, y, "image_sample")class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.fc1 = nn.Linear(28*28, 256)self.fc2 = nn.Linear(256, 64)self.fc3 = nn.Linear(64, 10)def forward(self, x):# x: [b, 1, 28, 28]# h1 = relu(xw1 + b1)x = F.relu(self.fc1(x))# h2 = relu(h1w2 + b2)x = F.relu(self.fc2(x))# h3 = h2w3 + b3x = self.fc3(x)return x
net = Net()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)train_loss = []
for epoch in range(3):for batch_idx, (x, y) in enumerate(train_loader):#加载进来的图片是一个四维的tensor,x: [b, 1, 28, 28], y:[512]#但是我们网络的输入要是一个一维向量(也就是二维tensor),所以要进行展平操作x = x.view(x.size(0), 28*28)#  [b, 10]out = net(x)y_onehot = one_hot(y)# loss = mse(out, y_onehot)loss = F.mse_loss(out, y_onehot)optimizer.zero_grad()loss.backward()# w' = w - lr*gradoptimizer.step()train_loss.append(loss.item())if batch_idx % 10 == 0:print(epoch, batch_idx, loss.item())plot_curve(train_loss)# we get optimal [w1, b1, w2, b2, w3, b3]total_correct = 0
for x,y in test_loader:x = x.view(x.size(0), 28*28)out = net(x)# out: [b, 10]pred = out.argmax(dim=1)correct = pred.eq(y).sum().float().item()total_correct += correct
total_num = len(test_loader.dataset)
acc = total_correct/total_num
print("acc:", acc)x, y = next(iter(test_loader))
out = net(x.view(x.size(0), 28*28))
pred = out.argmax(dim=1)
plot_image(x, pred, "test")

主程序中调用的函数(注意命名为utils):

import  torch
from    matplotlib import pyplot as pltdef plot_curve(data):fig = plt.figure()plt.plot(range(len(data)), data, color='blue')plt.legend(['value'], loc='upper right')plt.xlabel('step')plt.ylabel('value')plt.show()def plot_image(img, label, name):fig = plt.figure()for i in range(6):plt.subplot(2, 3, i + 1)plt.tight_layout()plt.imshow(img[i][0]*0.3081+0.1307, cmap='gray', interpolation='none')plt.title("{}: {}".format(name, label[i].item()))plt.xticks([])plt.yticks([])plt.show()def one_hot(label, depth=10):out = torch.zeros(label.size(0), depth)idx = torch.LongTensor(label).view(-1, 1)out.scatter_(dim=1, index=idx, value=1)return out

打印出损失下降的曲线图:

训练3个epoch之后,在测试集上的精度就可以89%左右,可见模型的准确度还是很不错的。
输出六张测试集的图片以及预测结果:

六张图片的预测全部正确。

pytorch实现手写数字图片识别相关推荐

  1. DL之CNN:利用卷积神经网络算法(2→2,基于Keras的API-Functional)利用MNIST(手写数字图片识别)数据集实现多分类预测

    DL之CNN:利用卷积神经网络算法(2→2,基于Keras的API-Functional)利用MNIST(手写数字图片识别)数据集实现多分类预测 目录 输出结果 设计思路 核心代码 输出结果 下边两张 ...

  2. DL之CNN:利用卷积神经网络算法(2→2,基于Keras的API-Sequential)利用MNIST(手写数字图片识别)数据集实现多分类预测

    DL之CNN:利用卷积神经网络算法(2→2,基于Keras的API-Sequential)利用MNIST(手写数字图片识别)数据集实现多分类预测 目录 输出结果 设计思路 核心代码 输出结果 1.10 ...

  3. DL之DNN:利用DNN【784→50→100→10】算法对MNIST手写数字图片识别数据集进行预测、模型优化

    DL之DNN:利用DNN[784→50→100→10]算法对MNIST手写数字图片识别数据集进行预测.模型优化 导读 目的是建立三层神经网络,进一步理解DNN内部的运作机制 目录 输出结果 设计思路 ...

  4. Dataset之Handwritten Digits:Handwritten Digits(手写数字图片识别)数据集简介、安装、使用方法之详细攻略

    Dataset之Handwritten Digits:Handwritten Digits(手写数字图片识别)数据集简介.安装.使用方法之详细攻略 目录 Handwritten Digits数据集的简 ...

  5. TF之NN:利用DNN算法(SGD+softmax+cross_entropy)对mnist手写数字图片识别训练集(TF自带函数下载)实现87.4%识别

    TF之NN:利用DNN算法(SGD+softmax+cross_entropy)对mnist手写数字图片识别训练集(TF自带函数下载)实现87.4%识别 目录 输出结果 代码设计 输出结果 代码设计 ...

  6. TF:基于CNN(2+1)实现MNIST手写数字图片识别准确率提高到99%

    TF:基于CNN(2+1)实现MNIST手写数字图片识别准确率提高到99% 导读 与Softmax回归模型相比,使用两层卷积的神经网络模型借助了卷积的威力,准确率高非常大的提升. 目录 输出结果 代码 ...

  7. TF:利用是Softmax回归+GD算法实现MNIST手写数字图片识别(10000张图片测试得到的准确率为92%)

    TF:利用是Softmax回归+GD算法实现MNIST手写数字图片识别(10000张图片测试得到的准确率为92%) 目录 设计思路 全部代码 设计思路 全部代码 #TF:利用是Softmax回归+GD ...

  8. Dataset之MNIST:MNIST(手写数字图片识别+ubyte.gz文件)数据集的下载(基于python语言根据爬虫技术自动下载MNIST数据集)

    Dataset之MNIST:MNIST(手写数字图片识别+ubyte.gz文件)数据集的下载(基于python语言根据爬虫技术自动下载MNIST数据集) 目录 数据集下载的所有代码 1.主文件 mni ...

  9. TF之LoR:基于tensorflow利用逻辑回归算LoR法实现手写数字图片识别提高准确率

    TF之LoR:基于tensorflow利用逻辑回归算LoR法实现手写数字图片识别提高准确率 目录 输出结果 设计代码 输出结果 设计代码 #TF之LoR:基于tensorflow实现手写数字图片识别准 ...

最新文章

  1. 【转】Java学习---Java Web基础面试题整理
  2. 《Adobe Photoshop大师班:经典作品与完美技巧赏析》—Alexander Corvus
  3. R语言实战应用-lightgbm 算法优化:不平衡二分类问题(附代码)
  4. JAVA_OA管理系统(四)番外篇:使用Spring注解注入属性
  5. boost::fusion::fused用法的测试程序
  6. Python读写文件的路径,关于os.chdir(path)位置对程序的影响,
  7. 做Web应用程序时应该如何面向对象杂谈
  8. AngularJS控制器和AngularJS过滤器的学习(3)
  9. linux 命令快捷,Linux常见命令快捷方式(示例代码)
  10. 猫途鹰联手携程集团打造面向中国出境旅行者的顶级旅行平台
  11. laravel-admin下使用header头下载
  12. ralink网卡驱动 linux,Ralink for linux Usb无线网卡驱动编译
  13. 04最大类间方差法(OTSU大津法)
  14. Ant Design表格插入图片
  15. matlab 使用.m文件,matlab 编写M文件(函数)
  16. 企业微信给微信好友定时发送图文并茂的消息
  17. Sqlite3实现脏读
  18. 细说大话西游中的经典元素
  19. 青年大学习自动名单核对程序(使用教程)
  20. windows 挂载百度网盘/阿里云盘等(网盘变本地硬盘) alist + raidrive

热门文章

  1. 虚拟机安装centos7
  2. RSA/ECDSA host key has changed 错误
  3. python 斗地主发牌_Python_斗地主发牌程序
  4. DM MPP部署问题
  5. 关键字与保留词,ES2020版
  6. wordpress 邮件_停止在WordPress中使用一次性电子邮件地址
  7. html制作简单框架网页 实现自己的音乐驿站 操作步骤及源文件下载 (播放功能限mp3文件)
  8. 冒烟测试和回归测试的区别
  9. (STM32笔记5)ws2812驱动开发
  10. ISO 9126软件质量模型的6大特性和27个子特性,测试人员建议深入了解