基于mnist数据的神经网络构建


文章目录

  • 基于mnist数据的神经网络构建
  • 前言
  • 一、网络构建
    • 1.引入库
    • 2.定义超参数
    • 3.下载数据集并对数据集进行预处理
    • 4.数据可视化
    • 5.构建模型
    • 6.实例化模型
    • 7.训练网络
    • 8.可视化训练和测试的损失值和精确度

前言

利用神经网络完成对手写数字进行识别的实例,并且简单的构建一个分类器,相应的步骤:
1:利用pytorch内置函数mnist下载数据集
2:利用torchvision 对数据进行预处理,调用torch.utils建立一个数据 迭代器
3:可是化数据
4:利用nn工具构建神经网络
5:实例化模型,并定义算是函数及其优化器
6:训练模型
7:可视化结果


提示:以下是本篇文章正文内容

一、网络构建

1.引入库

代码如下(示例):

# 导入必要的包
import numpy as np
import torch
# 导入数据集
from torchvision.datasets import mnist
# 导入预处理模块
import torchvision.transforms as transforms
"""
transforms中的常见操作:
torchvision.transforms.CenterCrop(size):进行中心切割,得到给定的size
torchvision.transforms.RandomCrop(size, padding=0):切割中心点的位置随机选取
torchvision.transforms.RandomHorizontalFlip:随机水平翻转;可以给定的概率为0.5。即:一半的概率翻转,一半的概率不翻转
torchvision.transforms.RandomSizedCrop(size, interpolation=2):随机剪切,然后再resize成给定的size大小
torchvision.transforms.Pad(padding, fill=0):所有边用给定的pad value填充。 padding:要填充多少像素 fill:用什么值填充
transforms.Normalize:对数据进行归一化,前一个0.5是平均值,第二个是方差,如果是多通道的就需要设置多个数值
"""
from torch.utils.data import DataLoader
# 导入nn模块
import torch.nn.functional as F
import torch.optim as optim
from torch import nn
from tqdm import tqdm

2.定义超参数

代码如下:

# 定义超参数
trian_batch_size=32
test_batch_size=64
num_epoches = 20
learning_rate=0.01 # 学习率
momentum =0.5 # 动量设置

3.下载数据集并对数据集进行预处理

# transforms.Compose 该函数的作用是用来整合数据处理的操作的,也就是把所有的预处理操作全部放在一个transform里面
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.5],[0.5])])# 下载数据集
train_dataset = mnist.MNIST(r"E:\jupter-file\",train=True,transform=transform,download=True)
test_dataset = mnist.MNIST(r"E:\jupter-file\",train=False,transform=transform,download=True)
# dataloader本质是一个可迭代对象,使用iter()访问,不能使用next()访问
train_loader = DataLoader(train_dataset,batch_size=trian_batch_size,shuffle=True)
test_loader = DataLoader(test_dataset,batch_size=test_batch_size,shuffle=False)

4.数据可视化

# ——————————————可视化——————————————————
import matplotlib.pyplot as plt
%matplotlib inline
examples = enumerate(test_loader)
batch_idx ,(example_data,example_label) =next(examples)
fig = plt.figure()
for i in range(6):plt.subplot(2,3,i+1)plt.tight_layout()plt.imshow(example_data[i][0],cmap="gray",interpolation='none')plt.title("ground truth :{}".format(example_label[i]))plt.axis("off")plt.xticks([])plt.yticks([])
# plt.show()

5.构建模型

class Net(nn.Module):"""使用sequential构建网络,Sequential()函数的功能是将网络的层组合在一起"""def __init__(self,in_dim,n_hidden_1,n_hidden_2,out_dim):""":param in_dim:输入的通道数:param n_hidden_1: 第一层神经元数量:param n_hidden_2: 第二层神经元的数量:param out_dim: 最后的输出通道数"""super(Net,self).__init__()self.layer1 = nn.Sequential(nn.Linear(in_dim,n_hidden_1),nn.BatchNorm1d(n_hidden_1))self.layer2 = nn.Sequential(nn.Linear(n_hidden_1,n_hidden_2),nn.BatchNorm1d(n_hidden_2))self.layer3 = nn.Sequential(nn.Linear(n_hidden_2,out_dim))def forward(self,x):x = F.relu(self.layer1(x))x = F.relu(self.layer2(x))x = self.layer3(x)return  x

6.实例化模型

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# device = "cpu"
# 实例化网络
model = Net(28*28,300,100,10)
model.to(device)
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(),lr=learning_rate,momentum=momentum)

7.训练网络


# 训练模型
losses = []
acces = []
eval_losses = []
eval_acces= []
for epoch in range(num_epoches):train_loss = 0train_acc = 0# 随着轮数的变换,改变学习率if epoch%5==0:optimizer.param_groups[0]['lr']*= 0.1 for img,label in tqdm(train_loader):img = img.to(device)label = label.to(device)img = img.view(img.size(0),-1)# 前向传播out = model(img)loss = criterion(out,label)# 反向传播optimizer.zero_grad()loss.backward()optimizer.step()# 记录误差train_loss +=loss.item()# 计算分类的准确率_,pred =out.max(1)num_correct = (pred == label).sum().item()acc = num_correct/img.shape[0]train_acc +=acclosses.append(train_loss / len(train_loader))acces.append(train_acc / len(train_loader))# 在测试集上进行测试eval_loss = 0eval_acc = 0# 将模型改为预测模式model.eval()for img,label in tqdm(test_loader):img = img.to(device)label = label.to(device)img = img.view(img.size(0),-1)out = model(img)loss = criterion(out,label)eval_loss +=loss.item()_,pred = out.max(1)num_correct = (pred == label).sum().item()acc = num_correct /img.shape[0]eval_acc +=acceval_losses.append(eval_loss / len(test_loader))eval_acces.append(eval_acc / len(test_loader))print("epoch:{},trian Loss:{:.4f},train acc:{:.4f},Test loss:{:.4f},""test ACC:{:.4f}".format(epoch,train_loss/len(train_loader),train_acc/len(train_loader),eval_loss/len(test_loader),eval_acc/len(test_loader)))

epoch:8,trian Loss:0.0991,train acc:0.9728,Test loss:0.1105,test ACC:0.9666
100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:39<00:00, 46.89it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 157/157 [00:03<00:00, 43.01it/s]
epoch:9,trian Loss:0.0977,train acc:0.9732,Test loss:0.1105,test ACC:0.9667

8.可视化训练和测试的损失值和精确度

plt.title("loss_acc")
plt.plot(np.arange(len(losses)),losses)
plt.plot(np.arange(len(acces)),acces)
plt.plot(np.arange(len(eval_losses)),eval_losses)
plt.plot(np.arange(len(eval_acces)),eval_acces)
plt.legend(["Train Loss","Train acc","Eval Loss","Eval acc"])

torch学习第一天相关推荐

  1. 深度学习第一讲之深度学习基础

    技术交流qq群: 659201069 深度学习第一讲之深度学习基础 转载请注明出处! 本篇博文从what.why.when.who.where.how五个方面来分析深度学习,接下来讲如何入门,我门将通 ...

  2. Android学习第一书

    大家好,我是一名Facebook的工程师,同时也是<第一行代码--Android>的忠实读者. 虽然我最近几年是在国外读书和工作的,但是和很多人一样,我也非常喜欢郭霖的博客以及他写的< ...

  3. MongoDB学习第一篇 --- Mac下使用HomeBrew安装MongoDB

    2019独角兽企业重金招聘Python工程师标准>>> MongoDB学习第一篇 --- Mac下使用HomeBrew安装MongoDB 0.确保mac已经安装了HomeBrew ( ...

  4. jQuery框架学习第一天:开始认识jQuery

    jQuery框架学习第一天:开始认识jQuery jQuery框架学习第二天:jQuery中万能的选择器 jQuery框架学习第三天:如何管理jQuery包装集 jQuery框架学习第四天:使用jQu ...

  5. 201671010140. 2016-2017-2 《Java程序设计》java学习第一周

       java学习第一周        本周是新学期的开端,也是新的学习进程的开端,第一次接触java这门课程,首先书本的厚度就给我一种无形的压力,这注定了,这门课程不会是轻松的,同时一种全新的学习方 ...

  6. React  学习第一天-2018-07-21

    React  学习第一天 1.Dom 和虚拟Dom Dom 是浏览器中实际存在的,虚拟Dom是框架中的,是利用JS代码来模拟DOM. 虚拟Dom 是实现页面的实时更新. Dom树,一个网页的呈现过程, ...

  7. MapServer Tutorial——MapServer7.2.1教程学习——第一节用例实践:Example1.5 Adding a raster layer...

    MapServer Tutorial--MapServer7.2.1教程学习--第一节用例实践:Example1.5 Adding a  raster layer 一.前言 MapServer不仅支持 ...

  8. linux操作系统学什么,Linux学习-第一天-什么是操作系统

    Linux学习--第一天--什么是操作系统? 第一章 什么是Linux 1.1 什么是Linux 1.1.1 计算机:计算的辅助工具 计算机必须要有的组件: 输入单元:如鼠标.键盘.卡片阅读器机,等等 ...

  9. Python中的TCP的客户端UDP学习----第一篇博客

    Python中的TCP的客户端&UDP学习--第一篇博客 PS: 每日的怼人句子"我真想把我的脑子放到你的身体里,让你感受一下智慧的光芒" 先说UDP流程 发送: 创建套接 ...

最新文章

  1. 山景智能创始人黄勇:银行要从数据智能转向业务智能,今天的金融服务难以支撑未来 | MEET2021...
  2. asp.net夜话之九:验证控件(上)
  3. Mixup vs. SamplePairing:ICLR2018投稿论文的两种数据增广方式
  4. 推荐系统中协同过滤算法实现分析
  5. compose应用_带有PostgreSQLDocker Compose for Spring Boot应用程序
  6. 关于java_关于Java基础
  7. tuning-primer.sh 性能调试工具的使用
  8. 【Clickhouse】Clickhouse 多路径存储策略
  9. 信息孤岛影响_企业专访:以“信息化”冲破信息孤岛
  10. #pragma的常用方法讲解(转载)
  11. 易语言大漠进行字库制作的时候出现不能展示二值化区域
  12. 三星S7Edge刷了鉴机大师的Android8的增强版,超级流畅省电_我是亲民_新浪博客
  13. TOPSIS综合评价模型
  14. SPSS23第二版课后习题答案_全新版大学进阶英语综合教程3 Unit1unit3课后习题答案...
  15. 手机数字雨_cmd命令数字雨教程
  16. windows系统电脑实用快捷键
  17. 苹果电脑教程之退出ID账号
  18. C语言知识层次结构图
  19. 停车场无感支付中的“黑科技
  20. 【HomeAssistant接入的设备实现天猫精灵】

热门文章

  1. 计算机学院志愿公益活动,公益工坊暖心房——记计算机科学学院志愿者活动
  2. 二叉排序树的定义及基本操作(构造、查找、插入、删除)递归及非递归算法
  3. 电快速瞬变脉冲群抗扰度(EFT)测试流程
  4. 微信小程序开发常见问题FAQ之四
  5. vue中设置内联样式style 动态绑定背景图backgroundimage不生效问题,以及动态绑定img的src,图片无法显示问题(src=“[object Module]“)
  6. 夜读《匠人手记》-01
  7. Hbase 操作时出现:Server is not running yet
  8. iOS 6发布的启示 —谈互联网产业链变化
  9. 无人超市之后,马云的小卖部也来了
  10. 详解深度学习之 Embedding