数据预览:

import pandas
df=pandas.read_csv('C:\\Users\\HP\\Desktop\\mnist_train.csv',header=None)
df.head()

MNIST的每一行数据包含785个值。第一个值是图像所表示的数字,其余的784个值是图像(尺寸为28像素× 28像素)的像素值。¶

我们可以使用info()函数查看DataFrame的概况

df.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 60000 entries, 0 to 59999
Columns: 785 entries, 0 to 784
dtypes: int64(785)
memory usage: 359.3 MB

以上结果告诉我们,该DataFrame有60 000行。这对应60 000幅训练图像。同时,我们也可以确认每行有785个值。¶

让我们将一行像素值转换成实际图像来直观地查看一下。

我们使用通用的matplotlib库来显示图像。在下面的代码中,我们导入matplotlib库的pyplot包

完整代码: 

#### 导入库
import torch
import torch.nn as nn
import pandas
import matplotlib.pyplot as plt
from torch.utils.data import Dataset #是pytorch加载和导入数据的方式
'''------------构建神经网络类------------'''
class Classifier(nn.Module):#nn.Module是所有类的父类"""分类器"""def __init__(self):#初始化pytorch父类super().__init__()#定义神经网络层self.model=nn.Sequential(nn.Linear(784,200),nn.Sigmoid(),nn.Linear(200,10),nn.Sigmoid())#创建损失函数(均方误差)self.loss_function=nn.MSELoss()#创建优化器,使用简单梯度下降self.optimiser=torch.optim.SGD(self.parameters(),lr=0.01)'''可视化'''#记录训练进展的计数器和列表self.counter=0self.progress=[]def forward(self,inputs):#直接运行模型return self.model(inputs)"""训练器"""def train(self,inputs,targets):#计算网络的输出值outputs=self.forward(inputs)#计算损失值loss=self.loss_function(outputs,targets)"""下一步是使用损失来更新网络的链接权值"""#梯度归零,反向传播,并更新权重self.optimiser.zero_grad()loss.backward()self.optimiser.step()"""可视化"""#每隔10个训练样本增加一次计数器的值,并将损失值添加进列表的末尾self.counter+=1if(self.counter%10==0):self.progress.append(loss.item())#这里使用item()的作用只是为了方便展开一个单值张量,获取里面的数字pass#每10000次训练后打印计数器的值,这样可以了解训练进展的快慢if(self.counter%10000==0):print("counter=",self.counter)pass"""将损失值绘成图"""def plot_progress(self):df=pandas.DataFrame(self.progress,columns=['loss'])df.plot(ylim=(0,1.0),figsize=(16,8),alpha=0.1,marker='.',grid=True,yticks=(0,0.25,0.5))pass
'''-------------创建MnistDataset类--------------'''
class MnistDataset(Dataset):def __init__(self,csv_file):self.data_df=pandas.read_csv(csv_file,header=None)passdef __len__(self):return len(self.data_df)def __getitem__(self,index):#目标图像(标签)label=self.data_df.iloc[index,0]target=torch.zeros((10))target[label]=1.0#图像数据,取值范围是0~255,标准化为0~1image_values=torch.FloatTensor(self.data_df.iloc[index,1:].values)/255.0#返回标签,图像数据张量以及目标张量return label, image_values, targetpass"""可视化"""def plot_image(self,index):arr=self.data_df.iloc[index,1:].values.reshape(28,28)plt.title("label = " + str(self.data_df.iloc[index,0]))plt.imshow(arr,interpolation='none',cmap='Blues')pass
#检查一下到目前为止是否一切正常
mnist_dataset=MnistDataset('C:\\Users\\HP\\Desktop\\mnist_train.csv')
#mnist_dataset.plot_image(9)
"""训练分类器"""
#创建神经网络C=Classifier()
#在MNIST数据集训练神经网络
for label, image_data_tensor, target_tensor in mnist_dataset:C.train(image_data_tensor,target_tensor)pass
# 绘制分类器损失值
C.plot_progress()
# 加载MNIST测试数据
mnist_test_dataset = MnistDataset('C:\\Users\\HP\\Desktop\\mnist_test.csv')
counter= 10000
counter= 20000
counter= 30000
counter= 40000
counter= 50000
counter= 60000
# 绘制分类器损失值
C.plot_progress()

现在我们有了一个训练后的网络,可以进行图像分类了。我们将切换到包含10 000幅图像的MNIST测试数据集。这些是我们的神经网络从来没看到过的图像。让我们用一个新的Dataset对象加载数据集 

# 加载MNIST测试数据
mnist_test_dataset = MnistDataset('C:\\Users\\HP\\Desktop\\mnist_test.csv')
# 挑选一幅图像
record = 19
# 绘制图像和标签
mnist_test_dataset.plot_image(record)

让我们看看训练过的神经网络是如何判断这幅图像的。下面的代码继续使用第20幅图像并提取像素值作为image_data。我们使用forward()函数将图像传递并通过神经网络 

image_data = mnist_test_dataset[record][1]
# 调用训练后的神经网络
output = C.forward(image_data)
# 绘制输出张量
pandas.DataFrame(output.detach().numpy()).plot(kind='bar',
legend=False, ylim=(0,1))

手写数字代码识别(pytorch)实现相关推荐

  1. DL之CNN:利用卷积神经网络算法(2→2,基于Keras的API-Functional)利用MNIST(手写数字图片识别)数据集实现多分类预测

    DL之CNN:利用卷积神经网络算法(2→2,基于Keras的API-Functional)利用MNIST(手写数字图片识别)数据集实现多分类预测 目录 输出结果 设计思路 核心代码 输出结果 下边两张 ...

  2. DL之CNN:利用卷积神经网络算法(2→2,基于Keras的API-Sequential)利用MNIST(手写数字图片识别)数据集实现多分类预测

    DL之CNN:利用卷积神经网络算法(2→2,基于Keras的API-Sequential)利用MNIST(手写数字图片识别)数据集实现多分类预测 目录 输出结果 设计思路 核心代码 输出结果 1.10 ...

  3. DL之DNN:利用DNN【784→50→100→10】算法对MNIST手写数字图片识别数据集进行预测、模型优化

    DL之DNN:利用DNN[784→50→100→10]算法对MNIST手写数字图片识别数据集进行预测.模型优化 导读 目的是建立三层神经网络,进一步理解DNN内部的运作机制 目录 输出结果 设计思路 ...

  4. Dataset之Handwritten Digits:Handwritten Digits(手写数字图片识别)数据集简介、安装、使用方法之详细攻略

    Dataset之Handwritten Digits:Handwritten Digits(手写数字图片识别)数据集简介.安装.使用方法之详细攻略 目录 Handwritten Digits数据集的简 ...

  5. TF之NN:利用DNN算法(SGD+softmax+cross_entropy)对mnist手写数字图片识别训练集(TF自带函数下载)实现87.4%识别

    TF之NN:利用DNN算法(SGD+softmax+cross_entropy)对mnist手写数字图片识别训练集(TF自带函数下载)实现87.4%识别 目录 输出结果 代码设计 输出结果 代码设计 ...

  6. TF:基于CNN(2+1)实现MNIST手写数字图片识别准确率提高到99%

    TF:基于CNN(2+1)实现MNIST手写数字图片识别准确率提高到99% 导读 与Softmax回归模型相比,使用两层卷积的神经网络模型借助了卷积的威力,准确率高非常大的提升. 目录 输出结果 代码 ...

  7. TF:利用是Softmax回归+GD算法实现MNIST手写数字图片识别(10000张图片测试得到的准确率为92%)

    TF:利用是Softmax回归+GD算法实现MNIST手写数字图片识别(10000张图片测试得到的准确率为92%) 目录 设计思路 全部代码 设计思路 全部代码 #TF:利用是Softmax回归+GD ...

  8. Dataset之MNIST:MNIST(手写数字图片识别+ubyte.gz文件)数据集的下载(基于python语言根据爬虫技术自动下载MNIST数据集)

    Dataset之MNIST:MNIST(手写数字图片识别+ubyte.gz文件)数据集的下载(基于python语言根据爬虫技术自动下载MNIST数据集) 目录 数据集下载的所有代码 1.主文件 mni ...

  9. TF之LoR:基于tensorflow利用逻辑回归算LoR法实现手写数字图片识别提高准确率

    TF之LoR:基于tensorflow利用逻辑回归算LoR法实现手写数字图片识别提高准确率 目录 输出结果 设计代码 输出结果 设计代码 #TF之LoR:基于tensorflow实现手写数字图片识别准 ...

最新文章

  1. 嗨:VSCode和IDEA都请安装上这个神奇的插件
  2. Sublime Text 2/3 Package Control 安装方法(Install Package)
  3. jQuery中的表单对象属性过滤选择器(四、八)::enabled、:disabled、:checked、:selected...
  4. Java中的记录类型
  5. Python数值计算:一 使用Pylab绘图(1)
  6. react父子组件通信案例
  7. java 转json_Java转JSON串的几种方式
  8. 2021-06-01 深入分析偏向锁、轻量级锁和重量级锁
  9. 《游戏大师Chris Crawford谈互动叙事》一1.2 两种思维方式
  10. Java学习之基本概念
  11. 爬虫获取::after_这种反爬虫手段有点意思,看我破了它!
  12. 南阳理工ACM111
  13. 中文的括号和英文的括号区别_如何在word里快捷键入六角括号
  14. [转载]【电子书下载神器】太给力了!你还找不到想要的电子书吗?
  15. makefile教程_Makefile教程
  16. 前端笔记1(选择器,动态增添/修改页面元素)
  17. Python面试题解析之前端、框架和其他
  18. Tim Sweeney解释为什么Unreal Engine 4全面转向C++
  19. 2022世界VR产业大会圆满收官,酷雷曼惊艳亮相!
  20. 聚类 轮廓 matlab,聚类分析 - MATLAB Simulink Example - MathWorks 中国

热门文章

  1. mysql8 密码破解
  2. C语言中的占位符有哪些
  3. 计算机的数学知识的手抄报图片大全,数学知识手抄报大全
  4. 用计算机名共享不了,局域网内无法用计算机名访问共享的解决办法
  5. 转自http://www.u148.net/article/47902.html
  6. 【转】2018最新版QQ音乐api调用
  7. linux添加用户并授权ssh登录
  8. 基于阿里云平台的短信验证码服务API的使用
  9. hyperledger fabric 实战开发——水产品溯源交易平台(一)
  10. Java创建服务端和客户端基础(一)多人在线聊天程序实战基础