手写数字代码识别(pytorch)实现
数据预览:
import pandas
df=pandas.read_csv('C:\\Users\\HP\\Desktop\\mnist_train.csv',header=None)
df.head()
MNIST的每一行数据包含785个值。第一个值是图像所表示的数字,其余的784个值是图像(尺寸为28像素× 28像素)的像素值。¶
我们可以使用info()函数查看DataFrame的概况
df.info()
<class 'pandas.core.frame.DataFrame'> RangeIndex: 60000 entries, 0 to 59999 Columns: 785 entries, 0 to 784 dtypes: int64(785) memory usage: 359.3 MB
以上结果告诉我们,该DataFrame有60 000行。这对应60 000幅训练图像。同时,我们也可以确认每行有785个值。¶
让我们将一行像素值转换成实际图像来直观地查看一下。
我们使用通用的matplotlib库来显示图像。在下面的代码中,我们导入matplotlib库的pyplot包
完整代码:
#### 导入库
import torch
import torch.nn as nn
import pandas
import matplotlib.pyplot as plt
from torch.utils.data import Dataset #是pytorch加载和导入数据的方式
'''------------构建神经网络类------------'''
class Classifier(nn.Module):#nn.Module是所有类的父类"""分类器"""def __init__(self):#初始化pytorch父类super().__init__()#定义神经网络层self.model=nn.Sequential(nn.Linear(784,200),nn.Sigmoid(),nn.Linear(200,10),nn.Sigmoid())#创建损失函数(均方误差)self.loss_function=nn.MSELoss()#创建优化器,使用简单梯度下降self.optimiser=torch.optim.SGD(self.parameters(),lr=0.01)'''可视化'''#记录训练进展的计数器和列表self.counter=0self.progress=[]def forward(self,inputs):#直接运行模型return self.model(inputs)"""训练器"""def train(self,inputs,targets):#计算网络的输出值outputs=self.forward(inputs)#计算损失值loss=self.loss_function(outputs,targets)"""下一步是使用损失来更新网络的链接权值"""#梯度归零,反向传播,并更新权重self.optimiser.zero_grad()loss.backward()self.optimiser.step()"""可视化"""#每隔10个训练样本增加一次计数器的值,并将损失值添加进列表的末尾self.counter+=1if(self.counter%10==0):self.progress.append(loss.item())#这里使用item()的作用只是为了方便展开一个单值张量,获取里面的数字pass#每10000次训练后打印计数器的值,这样可以了解训练进展的快慢if(self.counter%10000==0):print("counter=",self.counter)pass"""将损失值绘成图"""def plot_progress(self):df=pandas.DataFrame(self.progress,columns=['loss'])df.plot(ylim=(0,1.0),figsize=(16,8),alpha=0.1,marker='.',grid=True,yticks=(0,0.25,0.5))pass
'''-------------创建MnistDataset类--------------'''
class MnistDataset(Dataset):def __init__(self,csv_file):self.data_df=pandas.read_csv(csv_file,header=None)passdef __len__(self):return len(self.data_df)def __getitem__(self,index):#目标图像(标签)label=self.data_df.iloc[index,0]target=torch.zeros((10))target[label]=1.0#图像数据,取值范围是0~255,标准化为0~1image_values=torch.FloatTensor(self.data_df.iloc[index,1:].values)/255.0#返回标签,图像数据张量以及目标张量return label, image_values, targetpass"""可视化"""def plot_image(self,index):arr=self.data_df.iloc[index,1:].values.reshape(28,28)plt.title("label = " + str(self.data_df.iloc[index,0]))plt.imshow(arr,interpolation='none',cmap='Blues')pass
#检查一下到目前为止是否一切正常
mnist_dataset=MnistDataset('C:\\Users\\HP\\Desktop\\mnist_train.csv')
#mnist_dataset.plot_image(9)
"""训练分类器"""
#创建神经网络C=Classifier()
#在MNIST数据集训练神经网络
for label, image_data_tensor, target_tensor in mnist_dataset:C.train(image_data_tensor,target_tensor)pass
# 绘制分类器损失值
C.plot_progress()
# 加载MNIST测试数据
mnist_test_dataset = MnistDataset('C:\\Users\\HP\\Desktop\\mnist_test.csv')
counter= 10000 counter= 20000 counter= 30000 counter= 40000 counter= 50000 counter= 60000
# 绘制分类器损失值
C.plot_progress()
现在我们有了一个训练后的网络,可以进行图像分类了。我们将切换到包含10 000幅图像的MNIST测试数据集。这些是我们的神经网络从来没看到过的图像。让我们用一个新的Dataset对象加载数据集
# 加载MNIST测试数据
mnist_test_dataset = MnistDataset('C:\\Users\\HP\\Desktop\\mnist_test.csv')
# 挑选一幅图像
record = 19
# 绘制图像和标签
mnist_test_dataset.plot_image(record)
让我们看看训练过的神经网络是如何判断这幅图像的。下面的代码继续使用第20幅图像并提取像素值作为image_data。我们使用forward()函数将图像传递并通过神经网络
image_data = mnist_test_dataset[record][1]
# 调用训练后的神经网络
output = C.forward(image_data)
# 绘制输出张量
pandas.DataFrame(output.detach().numpy()).plot(kind='bar',
legend=False, ylim=(0,1))
手写数字代码识别(pytorch)实现相关推荐
- DL之CNN:利用卷积神经网络算法(2→2,基于Keras的API-Functional)利用MNIST(手写数字图片识别)数据集实现多分类预测
DL之CNN:利用卷积神经网络算法(2→2,基于Keras的API-Functional)利用MNIST(手写数字图片识别)数据集实现多分类预测 目录 输出结果 设计思路 核心代码 输出结果 下边两张 ...
- DL之CNN:利用卷积神经网络算法(2→2,基于Keras的API-Sequential)利用MNIST(手写数字图片识别)数据集实现多分类预测
DL之CNN:利用卷积神经网络算法(2→2,基于Keras的API-Sequential)利用MNIST(手写数字图片识别)数据集实现多分类预测 目录 输出结果 设计思路 核心代码 输出结果 1.10 ...
- DL之DNN:利用DNN【784→50→100→10】算法对MNIST手写数字图片识别数据集进行预测、模型优化
DL之DNN:利用DNN[784→50→100→10]算法对MNIST手写数字图片识别数据集进行预测.模型优化 导读 目的是建立三层神经网络,进一步理解DNN内部的运作机制 目录 输出结果 设计思路 ...
- Dataset之Handwritten Digits:Handwritten Digits(手写数字图片识别)数据集简介、安装、使用方法之详细攻略
Dataset之Handwritten Digits:Handwritten Digits(手写数字图片识别)数据集简介.安装.使用方法之详细攻略 目录 Handwritten Digits数据集的简 ...
- TF之NN:利用DNN算法(SGD+softmax+cross_entropy)对mnist手写数字图片识别训练集(TF自带函数下载)实现87.4%识别
TF之NN:利用DNN算法(SGD+softmax+cross_entropy)对mnist手写数字图片识别训练集(TF自带函数下载)实现87.4%识别 目录 输出结果 代码设计 输出结果 代码设计 ...
- TF:基于CNN(2+1)实现MNIST手写数字图片识别准确率提高到99%
TF:基于CNN(2+1)实现MNIST手写数字图片识别准确率提高到99% 导读 与Softmax回归模型相比,使用两层卷积的神经网络模型借助了卷积的威力,准确率高非常大的提升. 目录 输出结果 代码 ...
- TF:利用是Softmax回归+GD算法实现MNIST手写数字图片识别(10000张图片测试得到的准确率为92%)
TF:利用是Softmax回归+GD算法实现MNIST手写数字图片识别(10000张图片测试得到的准确率为92%) 目录 设计思路 全部代码 设计思路 全部代码 #TF:利用是Softmax回归+GD ...
- Dataset之MNIST:MNIST(手写数字图片识别+ubyte.gz文件)数据集的下载(基于python语言根据爬虫技术自动下载MNIST数据集)
Dataset之MNIST:MNIST(手写数字图片识别+ubyte.gz文件)数据集的下载(基于python语言根据爬虫技术自动下载MNIST数据集) 目录 数据集下载的所有代码 1.主文件 mni ...
- TF之LoR:基于tensorflow利用逻辑回归算LoR法实现手写数字图片识别提高准确率
TF之LoR:基于tensorflow利用逻辑回归算LoR法实现手写数字图片识别提高准确率 目录 输出结果 设计代码 输出结果 设计代码 #TF之LoR:基于tensorflow实现手写数字图片识别准 ...
最新文章
- 嗨:VSCode和IDEA都请安装上这个神奇的插件
- Sublime Text 2/3 Package Control 安装方法(Install Package)
- jQuery中的表单对象属性过滤选择器(四、八)::enabled、:disabled、:checked、:selected...
- Java中的记录类型
- Python数值计算:一 使用Pylab绘图(1)
- react父子组件通信案例
- java 转json_Java转JSON串的几种方式
- 2021-06-01 深入分析偏向锁、轻量级锁和重量级锁
- 《游戏大师Chris Crawford谈互动叙事》一1.2 两种思维方式
- Java学习之基本概念
- 爬虫获取::after_这种反爬虫手段有点意思,看我破了它!
- 南阳理工ACM111
- 中文的括号和英文的括号区别_如何在word里快捷键入六角括号
- [转载]【电子书下载神器】太给力了!你还找不到想要的电子书吗?
- makefile教程_Makefile教程
- 前端笔记1(选择器,动态增添/修改页面元素)
- Python面试题解析之前端、框架和其他
- Tim Sweeney解释为什么Unreal Engine 4全面转向C++
- 2022世界VR产业大会圆满收官,酷雷曼惊艳亮相!
- 聚类 轮廓 matlab,聚类分析
- MATLAB Simulink Example
- MathWorks 中国
热门文章
- mysql8 密码破解
- C语言中的占位符有哪些
- 计算机的数学知识的手抄报图片大全,数学知识手抄报大全
- 用计算机名共享不了,局域网内无法用计算机名访问共享的解决办法
- 转自http://www.u148.net/article/47902.html
- 【转】2018最新版QQ音乐api调用
- linux添加用户并授权ssh登录
- 基于阿里云平台的短信验证码服务API的使用
- hyperledger fabric 实战开发——水产品溯源交易平台(一)
- Java创建服务端和客户端基础(一)多人在线聊天程序实战基础