运行环境

系统:win10
cpu:i7-6700HQ
gpu:gtx965m
python : 3.6
pytorch :0.3

普通神经网络

class Nueralnetwork(nn.Module):def __init__(self,in_dim,hidden1,hidden2,out_dim):super(Nueralnetwork, self).__init__()self.layer1 = nn.Linear(in_dim,hidden1)self.layer2 = nn.Linear(hidden1,hidden2)self.layer3 = nn.Linear(hidden2,out_dim)def forward(self,x):x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)return xnet = Nueralnetwork(28*28,200,100,10)optimizer = optim.Adam(net.parameters(), lr=LEARNING_RATE)
loss_f = nn.CrossEntropyLoss()

第一层200个神经元,第二层100个,运行2个epochs。

gpu运行结果如下:

cpu运行结果如下:

结论:简单的神经网络计算量比简单的逻辑回归大了不少,gpu运算话费时间已经和cpu相仿。
接下来,我们把神经网络换成卷积神经网络(convolution neural network)再来看看gpu和cpu运行的差距,以及准确率的差别。

CNN

代码如下:

import torch
from torch import nn, optim   # nn 神经网络模块 optim优化函数模块
from torch.utils.data import DataLoader
from torch.autograd import Variable
from torchvision import transforms, datasets
from visdom import Visdom  # 可视化处理模块
import time
import numpy as np
# 可视化app
viz = Visdom()# 超参数
BATCH_SIZE = 40
LR = 1e-3
EPOCH = 2
# 判断是否使用gpu
USE_GPU = Trueif USE_GPU:gpu_status = torch.cuda.is_available()
else:gpu_status = False# 数据引入
train_dataset = datasets.MNIST('./mnist', True, transforms.ToTensor(), download=False)
test_dataset = datasets.MNIST('./mnist', False, transforms.ToTensor())train_loader = DataLoader(train_dataset, BATCH_SIZE, True)
# 为加快测试,把测试数据从10000缩小到2000
test_data = torch.unsqueeze(test_dataset.test_data, 1)[:1500]
test_label = test_dataset.test_labels[:1500]
# visdom可视化部分数据
viz.images(test_data[:100], nrow=10)
# 为防止可视化视窗重叠现象,停顿0.5秒
time.sleep(0.5)
if gpu_status:test_data = test_data.cuda()
test_data = Variable(test_data, volatile=True).float()
# 创建线图可视化窗口
line = viz.line(np.arange(10))# 创建cnn神经网络
class CNN(nn.Module):def __init__(self, in_dim, n_class):super(CNN, self).__init__()self.conv = nn.Sequential(# channel 为信息高度 padding为图片留白 kernel_size 扫描模块size(5x5)nn.Conv2d(in_channels=in_dim, out_channels=16,kernel_size=5,stride=1, padding=2),nn.ReLU(),# 平面缩减 28x28 >> 14*14nn.MaxPool2d(kernel_size=2),nn.Conv2d(16, 32, 3, 1, 1),nn.ReLU(),# 14x14 >> 7x7nn.MaxPool2d(2))self.fc = nn.Sequential(nn.Linear(32*7*7, 120),nn.Linear(120, n_class))def forward(self, x):out = self.conv(x)out = out.view(out.size(0), -1)out = self.fc(out)return out
net = CNN(1,10)if gpu_status :net = net.cuda()print("#"*26, "使用gpu", "#"*26)
else:print("#" * 26, "使用cpu", "#" * 26)
# loss、optimizer 函数设置
loss_f = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=LR)
# 起始时间设置
start_time = time.time()
# 可视化所需数据点
time_p, tr_acc, ts_acc, loss_p = [], [], [], []
# 创建可视化数据视窗
text = viz.text("<h1>convolution Nueral Network</h1>")
for epoch in range(EPOCH):# 由于分批次学习,输出loss为一批平均,需要累积or平均每个batch的loss,accsum_loss, sum_acc, sum_step = 0., 0., 0.for i, (tx, ty) in enumerate(train_loader, 1):if gpu_status:tx, ty = tx.cuda(), ty.cuda()tx = Variable(tx)ty = Variable(ty)out = net(tx)loss = loss_f(out, ty)sum_loss += loss.data[0]*len(ty)pred_tr = torch.max(out,1)[1]sum_acc += sum(pred_tr==ty).data[0]sum_step += ty.size(0)# 学习反馈optimizer.zero_grad()loss.backward()optimizer.step()# 每40个batch可视化一下数据if i % 40 == 0:if gpu_status:test_data = test_data.cuda()test_out = net(test_data)# 如果用gpu运行out数据为cuda格式需要.cpu()转化为cpu数据 在进行比较pred_ts = torch.max(test_out, 1)[1].cpu().data.squeeze()acc = sum(pred_ts==test_label)/float(test_label.size(0))print("epoch: [{}/{}] | Loss: {:.4f} | TR_acc: {:.4f} | TS_acc: {:.4f} | Time: {:.1f}".format(epoch+1, EPOCH,sum_loss/(sum_step), sum_acc/(sum_step), acc, time.time()-start_time))# 可视化部分time_p.append(time.time()-start_time)tr_acc.append(sum_acc/sum_step)ts_acc.append(acc)loss_p.append(sum_loss/sum_step)viz.line(X=np.column_stack((np.array(time_p), np.array(time_p), np.array(time_p))),Y=np.column_stack((np.array(loss_p), np.array(tr_acc), np.array(ts_acc))),win=line,opts=dict(legend=["Loss", "TRAIN_acc", "TEST_acc"]))# visdom text 支持html语句viz.text("<p style='color:red'>epoch:{}</p><br><p style='color:blue'>Loss:{:.4f}</p><br>""<p style='color:BlueViolet'>TRAIN_acc:{:.4f}</p><br><p style='color:orange'>TEST_acc:{:.4f}</p><br>""<p style='color:green'>Time:{:.2f}</p>".format(epoch, sum_loss/sum_step, sum_acc/sum_step, acc,time.time()-start_time),win=text)sum_loss, sum_acc, sum_step = 0., 0., 0.

gpu运行结果:

可视化:

cpu运行结果:

可视化:

哈哈…… gpu终于翻身把歌唱了,在运行卷积神经网络过程gpu变化不是很大,但是cpu确比之前慢了7倍 。
从准确率上来看,cnn的准确率能达到98%多,而普通神经网络只能达到92%,由于图形很简单,这个差距并没有特别的大,而且我们用的cnn也是最简单的模式。

pytorch + visdom 应用神经网络、CNN 处理手写字体分类相关推荐

  1. PyTorch基础与简单应用:构建卷积神经网络实现MNIST手写数字分类

    文章目录 (一) 问题描述 (二) 设计简要描述 (三) 程序清单 (四) 结果分析 (五) 调试报告 (六) 实验小结 (七) 参考资料 (一) 问题描述 构建卷积神经网络实现MNIST手写数字分类 ...

  2. 基于PyTorch框架的多层全连接神经网络实现MNIST手写数字分类

    多层全连接神经网络实现MNIST手写数字分类 1 简单的三层全连接神经网络 2 添加激活函数 3 添加批标准化 4 训练网络 5 结论 参考资料 先用PyTorch实现最简单的三层全连接神经网络,然后 ...

  3. 神经网络学习(二)Tensorflow-简单神经网络(全连接层神经网络)实现手写字体识别

    神经网络学习(二)神经网络-手写字体识别 框架:Tensorflow 1.10.0 数据集:mnist数据集 策略:交叉熵损失 优化:梯度下降 五个模块:拿数据.搭网络.求损失.优化损失.算准确率 一 ...

  4. 深蓝学院第三章:基于卷积神经网络(CNN)的手写数字识别实践

    参看之前篇章的用全连接神经网络去做手写识别:https://blog.csdn.net/m0_37957160/article/details/114105389?spm=1001.2014.3001 ...

  5. 神经网络学习(三)比较详细 卷积神经网络原理、手写字体识别(卷积网络实现)

    之前写了一篇基于minist数据集(手写数字0-9)的全连接层神经网络,识别率(85%)并不高,这段时间学习了一些卷积神经网络的知识又实践了一把, 识别率(96%左右)确实上来了 ,下面把我的学习过程 ...

  6. 用Python搭建2层神经网络实现mnist手写数字分类

    这是一个用python搭建2层NN(一个隐藏层)识别mnist手写数据集的示例 mnist.py文件提供了mnist数据集(6万张训练图,1万张测试图)的在线下载,每张图片是 28 ∗ 28 28*2 ...

  7. 基于tensorflow的mnist数据集手写字体分类level-1

    本文属于学些tensorflow框架系列的文章,不是注重于算法- 基于之前博文中的工作,已经安装好tensorflow等等的配置工作,开始学习tensorflow框架的使用,本文参考了以下链接,致以敬 ...

  8. PyTorch基础-使用卷积神经网络CNN实现手写数据集识别-07

    import numpy as np import torch from torch import nn,optim from torch.autograd import Variable from ...

  9. 手搓卷积神经网络(CNN)进行手写数字识别(python)

    前言: 本文属于学习笔记性质.为了让自己更深入地理解卷积神经网络,我只用numpy.pandas等几个库手搓了一个识别MNIST数字的CNN.500张图单次训练,准确率70-80%. 注意: 1.代码 ...

最新文章

  1. 软件需求工程与UML建模——第九组第二周工作总结
  2. 拾遗:不用使 sizeof 获取数组大小
  3. 推荐一个ASP.NET的资源网站
  4. Object defineProperty
  5. excel导入linux乱码怎么解决方法,,请大家都来看下,Excel导入有乱码?原因出在哪里?应该怎么解决?...
  6. c语言isblank函数怎么用,ISBLANK函数详解_Excel公式教程
  7. word2vec中的数学模型
  8. python骗局-我终于在生活中用到Python了!!!——用爬虫来揭露骗局真相
  9. 64位CentOS 6.4下安装wine
  10. mysql安装运行(centos)
  11. Fish 环境下如何安装 nvm
  12. 原创 VPP使用心得(十六)静态路由添加流程
  13. 隐式函数声明警告---调用malloc函数但不包含头文件
  14. 测试用例(分析法——详细场景法)
  15. Spring Boot 监听 Activemq 中的特定 topic ,并将数据通过 RabbitMq 发布出去
  16. 两种 Type-C 耳机:模拟耳机 数字耳机
  17. 三极管概念工作原理及其应用
  18. 关于java多态性之父类引用指向子类对象
  19. linux系统安装软件报错,Linux安装软件时报错解决方法
  20. 【京东电商网站主界面仿写——HTML第一部分】

热门文章

  1. php表滑动 其它固定,table固定表头使表单横向滚动
  2. markdown编辑希腊字母
  3. mysql获取每周的周一周日的规则写法
  4. 传统机械硬盘和固态硬盘(SSD)的区别
  5. 如何从视频里提取音乐
  6. 鼠标点击效果变成小手的CSS实现
  7. [Audio] 音频基本属性及概念
  8. 网站空间服务器系统,网站空间操作系统
  9. r7 4700u和r5 3550h 选哪个好
  10. pycharm怎么用html注释,pycharm怎么注释?