pytorch实现手写数字图片识别
数据集:MNIST数据集,代码中会自动下载,不用自己手动下载。数据集很小,不需要GPU设备,可以很好的体会到pytorch的魅力。
模型+训练+预测程序:
import torch
from torch import nn
from torch.nn import functional as F
from torch import optim
import torchvision
from matplotlib import pyplot as plt
from utils import plot_image, plot_curve, one_hot# step1 load dataset
batch_size = 512
train_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('mnist_data', train=True, download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,), (0.3081,))])),batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('mnist_data/', train=False, download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,), (0.3081,))])),batch_size=batch_size, shuffle=False)
x , y = next(iter(train_loader))
print(x.shape, y.shape, x.min(), x.max())
plot_image(x, y, "image_sample")class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.fc1 = nn.Linear(28*28, 256)self.fc2 = nn.Linear(256, 64)self.fc3 = nn.Linear(64, 10)def forward(self, x):# x: [b, 1, 28, 28]# h1 = relu(xw1 + b1)x = F.relu(self.fc1(x))# h2 = relu(h1w2 + b2)x = F.relu(self.fc2(x))# h3 = h2w3 + b3x = self.fc3(x)return x
net = Net()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)train_loss = []
for epoch in range(3):for batch_idx, (x, y) in enumerate(train_loader):#加载进来的图片是一个四维的tensor,x: [b, 1, 28, 28], y:[512]#但是我们网络的输入要是一个一维向量(也就是二维tensor),所以要进行展平操作x = x.view(x.size(0), 28*28)# [b, 10]out = net(x)y_onehot = one_hot(y)# loss = mse(out, y_onehot)loss = F.mse_loss(out, y_onehot)optimizer.zero_grad()loss.backward()# w' = w - lr*gradoptimizer.step()train_loss.append(loss.item())if batch_idx % 10 == 0:print(epoch, batch_idx, loss.item())plot_curve(train_loss)# we get optimal [w1, b1, w2, b2, w3, b3]total_correct = 0
for x,y in test_loader:x = x.view(x.size(0), 28*28)out = net(x)# out: [b, 10]pred = out.argmax(dim=1)correct = pred.eq(y).sum().float().item()total_correct += correct
total_num = len(test_loader.dataset)
acc = total_correct/total_num
print("acc:", acc)x, y = next(iter(test_loader))
out = net(x.view(x.size(0), 28*28))
pred = out.argmax(dim=1)
plot_image(x, pred, "test")
主程序中调用的函数(注意命名为utils):
import torch
from matplotlib import pyplot as pltdef plot_curve(data):fig = plt.figure()plt.plot(range(len(data)), data, color='blue')plt.legend(['value'], loc='upper right')plt.xlabel('step')plt.ylabel('value')plt.show()def plot_image(img, label, name):fig = plt.figure()for i in range(6):plt.subplot(2, 3, i + 1)plt.tight_layout()plt.imshow(img[i][0]*0.3081+0.1307, cmap='gray', interpolation='none')plt.title("{}: {}".format(name, label[i].item()))plt.xticks([])plt.yticks([])plt.show()def one_hot(label, depth=10):out = torch.zeros(label.size(0), depth)idx = torch.LongTensor(label).view(-1, 1)out.scatter_(dim=1, index=idx, value=1)return out
打印出损失下降的曲线图:
训练3个epoch之后,在测试集上的精度就可以89%左右,可见模型的准确度还是很不错的。
输出六张测试集的图片以及预测结果:
六张图片的预测全部正确。
pytorch实现手写数字图片识别相关推荐
- DL之CNN:利用卷积神经网络算法(2→2,基于Keras的API-Functional)利用MNIST(手写数字图片识别)数据集实现多分类预测
DL之CNN:利用卷积神经网络算法(2→2,基于Keras的API-Functional)利用MNIST(手写数字图片识别)数据集实现多分类预测 目录 输出结果 设计思路 核心代码 输出结果 下边两张 ...
- DL之CNN:利用卷积神经网络算法(2→2,基于Keras的API-Sequential)利用MNIST(手写数字图片识别)数据集实现多分类预测
DL之CNN:利用卷积神经网络算法(2→2,基于Keras的API-Sequential)利用MNIST(手写数字图片识别)数据集实现多分类预测 目录 输出结果 设计思路 核心代码 输出结果 1.10 ...
- DL之DNN:利用DNN【784→50→100→10】算法对MNIST手写数字图片识别数据集进行预测、模型优化
DL之DNN:利用DNN[784→50→100→10]算法对MNIST手写数字图片识别数据集进行预测.模型优化 导读 目的是建立三层神经网络,进一步理解DNN内部的运作机制 目录 输出结果 设计思路 ...
- Dataset之Handwritten Digits:Handwritten Digits(手写数字图片识别)数据集简介、安装、使用方法之详细攻略
Dataset之Handwritten Digits:Handwritten Digits(手写数字图片识别)数据集简介.安装.使用方法之详细攻略 目录 Handwritten Digits数据集的简 ...
- TF之NN:利用DNN算法(SGD+softmax+cross_entropy)对mnist手写数字图片识别训练集(TF自带函数下载)实现87.4%识别
TF之NN:利用DNN算法(SGD+softmax+cross_entropy)对mnist手写数字图片识别训练集(TF自带函数下载)实现87.4%识别 目录 输出结果 代码设计 输出结果 代码设计 ...
- TF:基于CNN(2+1)实现MNIST手写数字图片识别准确率提高到99%
TF:基于CNN(2+1)实现MNIST手写数字图片识别准确率提高到99% 导读 与Softmax回归模型相比,使用两层卷积的神经网络模型借助了卷积的威力,准确率高非常大的提升. 目录 输出结果 代码 ...
- TF:利用是Softmax回归+GD算法实现MNIST手写数字图片识别(10000张图片测试得到的准确率为92%)
TF:利用是Softmax回归+GD算法实现MNIST手写数字图片识别(10000张图片测试得到的准确率为92%) 目录 设计思路 全部代码 设计思路 全部代码 #TF:利用是Softmax回归+GD ...
- Dataset之MNIST:MNIST(手写数字图片识别+ubyte.gz文件)数据集的下载(基于python语言根据爬虫技术自动下载MNIST数据集)
Dataset之MNIST:MNIST(手写数字图片识别+ubyte.gz文件)数据集的下载(基于python语言根据爬虫技术自动下载MNIST数据集) 目录 数据集下载的所有代码 1.主文件 mni ...
- TF之LoR:基于tensorflow利用逻辑回归算LoR法实现手写数字图片识别提高准确率
TF之LoR:基于tensorflow利用逻辑回归算LoR法实现手写数字图片识别提高准确率 目录 输出结果 设计代码 输出结果 设计代码 #TF之LoR:基于tensorflow实现手写数字图片识别准 ...
最新文章
- 【转】Java学习---Java Web基础面试题整理
- 《Adobe Photoshop大师班:经典作品与完美技巧赏析》—Alexander Corvus
- R语言实战应用-lightgbm 算法优化:不平衡二分类问题(附代码)
- JAVA_OA管理系统(四)番外篇:使用Spring注解注入属性
- boost::fusion::fused用法的测试程序
- Python读写文件的路径,关于os.chdir(path)位置对程序的影响,
- 做Web应用程序时应该如何面向对象杂谈
- AngularJS控制器和AngularJS过滤器的学习(3)
- linux 命令快捷,Linux常见命令快捷方式(示例代码)
- 猫途鹰联手携程集团打造面向中国出境旅行者的顶级旅行平台
- laravel-admin下使用header头下载
- ralink网卡驱动 linux,Ralink for linux Usb无线网卡驱动编译
- 04最大类间方差法(OTSU大津法)
- Ant Design表格插入图片
- matlab 使用.m文件,matlab 编写M文件(函数)
- 企业微信给微信好友定时发送图文并茂的消息
- Sqlite3实现脏读
- 细说大话西游中的经典元素
- 青年大学习自动名单核对程序(使用教程)
- windows 挂载百度网盘/阿里云盘等(网盘变本地硬盘) alist + raidrive