1.工程目录

2.导入data和input_data.py

链接:https://pan.baidu.com/s/1EBNyNurBXWeJVyhNeVnmnA 
提取码:4nnl

3.CNN.py

import tensorflow as tf
import matplotlib.pyplot as plt
import input_datamnist = input_data.read_data_sets('data/', one_hot=True)
trainimg = mnist.train.images
trainlabel = mnist.train.labels
testimg = mnist.test.images
testlabel = mnist.test.labels
print('MNIST ready')n_input = 784
n_output = 10weights = {'wc1': tf.Variable(tf.truncated_normal([3, 3, 1, 64], stddev=0.1)),'wc2': tf.Variable(tf.truncated_normal([3, 3, 64, 128], stddev=0.1)),'wd1': tf.Variable(tf.truncated_normal([7*7*128, 1024], stddev=0.1)),'wd2': tf.Variable(tf.truncated_normal([1024, n_outpot], stddev=0.1)),
}
biases = {'bc1': tf.Variable(tf.random_normal([64], stddev=0.1)),'bc2': tf.Variable(tf.random_normal([128], stddev=0.1)),'bd1': tf.Variable(tf.random_normal([1024], stddev=0.1)),'bd2': tf.Variable(tf.random_normal([n_outpot], stddev=0.1)),
}def conv_basic(_input, _w, _b, _keepratio):_input_r = tf.reshape(_input, shape=[-1, 28, 28, 1])_conv1 = tf.nn.conv2d(_input_r, _w['wc1'], strides=[1, 1, 1, 1], padding='SAME')_conv1 = tf.nn.relu(tf.nn.bias_add(_conv1, _b['bc1']))_pool1 = tf.nn.max_pool(_conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')_pool_dr1 = tf.nn.dropout(_pool1, _keepratio)_conv2 = tf.nn.conv2d(_pool_dr1, _w['wc2'], strides=[1, 1, 1, 1], padding='SAME')_conv2 = tf.nn.relu(tf.nn.bias_add(_conv2, _b['bc2']))_pool2 = tf.nn.max_pool(_conv2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')_pool_dr2 = tf.nn.dropout(_pool2, _keepratio)_densel = tf.reshape(_pool_dr2, [-1, _w['wd1'].get_shape().as_list()[0]])_fc1 = tf.nn.relu(tf.add(tf.matmul(_densel, _w['wd1']), _b['bd1']))_fc_dr1 = tf.nn.dropout(_fc1, _keepratio)_out = tf.add(tf.matmul(_fc_dr1, _w['wd2']), _b['bd2'])out = {'input_r': _input_r, 'conv1': _conv1, 'pool1': _pool1, 'pool_dr1': _pool_dr1,'conv2': _conv2, 'pool2': _pool2, 'pool_dr2': _pool_dr2, 'densel': _densel,'fc1': _fc1, 'fc_dr1': _fc_dr1, 'out': _out}return outprint('CNN READY')x = tf.placeholder(tf.float32, [None, n_input])
y = tf.placeholder(tf.float32, [None, n_output])
keepratio = tf.placeholder(tf.float32)_pred = conv_basic(x, weights, biases, keepratio)['out']
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(_pred, y))
optm = tf.train.AdamOptimizer(learning_rate=0.01).minimize(cost)
_corr = tf.equal(tf.argmax(_pred, 1), tf.argmax(y, 1))
accr = tf.reduce_mean(tf.cast(_corr, tf.float32))
init = tf.global_variables_initializer()print('GRAPH READY')sess = tf.Session()
sess.run(init)
training_epochs = 15
batch_size = 16
display_step = 1for epoch in range(training_epochs):avg_cost = 0.total_batch = 10for i in range(total_batch):batch_xs, batch_ys = mnist.train.next_batch(batch_size)sess.run(optm, feed_dict={x: batch_xs, y: batch_ys, keepratio: 0.7})avg_cost += sess.run(cost, feed_dict={x: batch_xs, y: batch_ys, keepratio: 1.0})/total_batchif epoch % display_step == 0:print('Epoch: %03d/%03d cost: %.9f' % (epoch, training_epochs, avg_cost))train_acc = sess.run(accr, feed_dict={x: batch_xs, y: batch_ys, keepratio: 1.})print('Training accuracy: %.3f' % (train_acc))res_dict = {'weight': sess.run(weights), 'biases': sess.run(biases)}import pickle
with open('res_dict.pkl', 'wb') as f:pickle.dump(res_dict, f, pickle.HIGHEST_PROTOCOL)

4.test.py

import pickle
import numpy as npdef load_file(path, name):with open(path+''+name+'.pkl', 'rb') as f:return pickle.load(f)res_dict = load_file('', 'res_dict')
print(res_dict['weight']['wc1'])index = 0import input_data
mnist = input_data.read_data_sets('data/', one_hot=True)test_image = mnist.test.images
test_label = mnist.test.labelsimport tensorflow as tfdef conv_basic(_input, _w, _b, _keepratio):_input_r = tf.reshape(_input, shape=[-1, 28, 28, 1])_conv1 = tf.nn.conv2d(_input_r, _w['wc1'], strides=[1, 1, 1, 1], padding='SAME')_conv1 = tf.nn.relu(tf.nn.bias_add(_conv1, _b['bc1']))_pool1 = tf.nn.max_pool(_conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')_pool_dr1 = tf.nn.dropout(_pool1, _keepratio)_conv2 = tf.nn.conv2d(_pool_dr1, _w['wc2'], strides=[1, 1, 1, 1], padding='SAME')_conv2 = tf.nn.relu(tf.nn.bias_add(_conv2, _b['bc2']))_pool2 = tf.nn.max_pool(_conv2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')_pool_dr2 = tf.nn.dropout(_pool2, _keepratio)_densel = tf.reshape(_pool_dr2, [-1, _w['wd1'].shape[0]])_fc1 = tf.nn.relu(tf.add(tf.matmul(_densel, _w['wd1']), _b['bd1']))_fc_dr1 = tf.nn.dropout(_fc1, _keepratio)_out = tf.add(tf.matmul(_fc_dr1, _w['wd2']), _b['bd2'])out = {'input_r': _input_r, 'conv1': _conv1, 'pool1': _pool1, 'pool_dr1': _pool_dr1,'conv2': _conv2, 'pool2': _pool2, 'pool_dr2': _pool_dr2, 'densel': _densel,'fc1': _fc1, 'fc_dr1': _fc_dr1, 'out': _out}return outn_input = 784
n_output = 10x = tf.placeholder(tf.float32, [None, n_input])
y = tf.placeholder(tf.float32, [None, n_output])keepratio = tf.placeholder(tf.float32)_pred = conv_basic(x, res_dict['weight'], res_dict['biases'], keepratio)['out']
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(_pred, y))_corr = tf.equal(tf.argmax(_pred, 1), tf.argmax(y, 1))
accr = tf.reduce_mean(tf.cast(_corr, tf.float32))init = tf.global_variables_initializer()sess = tf.Session()
sess.run(init)
training_epochs = 1
batch_size = 1
display_step = 1for epoch in range(training_epochs):avg_cost = 0.total_batch = 10for i in range(total_batch):batch_xs, batch_ys = mnist.train.next_batch(batch_size)if epoch % display_step == 0:print('_pre:', np.argmax(sess.run(_pred, feed_dict={x: batch_xs, keepratio: 1. })))print('answer:', np.argmax(batch_ys))

转载于:https://www.cnblogs.com/CK85/p/10258961.html

python,tensorflow,CNN实现mnist数据集的训练与验证正确率相关推荐

  1. 基于MNIST手写体数字识别--含可直接使用代码【Python+Tensorflow+CNN+Keras】

    基于MNIST手写体数字识别--[Python+Tensorflow+CNN+Keras] 1.任务 2.数据集分析 2.1 数据集总体分析 2.2 单个图片样本可视化 3. 数据处理 4. 搭建神经 ...

  2. 【Pytorch分布式训练】在MNIST数据集上训练一个简单CNN网络,将其改成分布式训练

    文章目录 普通单卡训练-GPU 普通单卡训练-CPU 分布式训练-GPU 分布式训练-CPU 租GPU服务器相关 以下代码示例基于:在MNIST数据集上训练一个简单CNN网络,将其改成分布式训练. 普 ...

  3. 基于tensorflow+RNN的MNIST数据集手写数字分类

    2018年9月25日笔记 tensorflow是谷歌google的深度学习框架,tensor中文叫做张量,flow叫做流. RNN是recurrent neural network的简称,中文叫做循环 ...

  4. DL之DNN:自定义2层神经网络TwoLayerNet模型(计算梯度两种方法)利用MNIST数据集进行训练、预测

    DL之DNN:自定义2层神经网络TwoLayerNet模型(计算梯度两种方法)利用MNIST数据集进行训练.预测 导读 利用python的numpy计算库,进行自定义搭建2层神经网络TwoLayerN ...

  5. TF之CNN:CNN实现mnist数据集预测 96%采用placeholder用法+2层C及其max_pool法+隐藏层dropout法+输出层softmax法+目标函数cross_entropy法+

    TF:TF下CNN实现mnist数据集预测 96%采用placeholder用法+2层C及其max_pool法+隐藏层dropout法+输出层softmax法+目标函数cross_entropy法+A ...

  6. 在MNIST数据集上训练一个手写数字识别模型

    使用Pytorch在MNIST数据集上训练一个手写数字识别模型, 代码和参数文件 可下载 1.1 数据下载 import torchvision as tvtraining_sets = tv.dat ...

  7. DL之DNN:自定义2层神经网络TwoLayerNet模型(封装为层级结构)利用MNIST数据集进行训练、预测

    DL之DNN:自定义2层神经网络TwoLayerNet模型(封装为层级结构)利用MNIST数据集进行训练.预测 导读           计算图在神经网络算法中的作用.计算图的节点是由局部计算构成的. ...

  8. DL之DNN:自定义2层神经网络TwoLayerNet模型(封装为层级结构)利用MNIST数据集进行训练、GC对比

    DL之DNN:自定义2层神经网络TwoLayerNet模型(封装为层级结构)利用MNIST数据集进行训练.GC对比 导读           神经网络算法封装为层级结构的作用.在神经网络算法中,通过将 ...

  9. DL之DCGNN:基于TF利用DCGAN实现在MNIST数据集上训练生成新样本

    DL之DCGNN:基于TF利用DCGAN实现在MNIST数据集上训练生成新样本 目录 输出结果 设计思路 实现部分代码 说明:所有图片文件丢失 输出结果 更新-- 设计思路 更新-- 实现部分代码 更 ...

  10. 体验paddle2.0rc版本API-Model--实现Mnist数据集模型训练

    原文链接:体验paddle2.0rc版本API-Model–实现Mnist数据集模型训练:https://blog.csdn.net/weixin_44604887/article/details/1 ...

最新文章

  1. Android分级部门选择界面(一)
  2. 用python画简单的动物-使用Python的turtle画小绵羊
  3. 开发源码常用网站参考
  4. SQL中两个表的某列相减
  5. Web 安全开发规范手册 V1.0
  6. 【ASP.NET开发】.NET三层架构简单解析
  7. 关于mysql的项目_项目中常用的MySQL 优化
  8. 【模型压缩】通道剪枝《Pruning Filters For Efficient ConvNets》论文翻译
  9. (openssh、telnet、vsftpd、nfs、rsync、inotify、samba)
  10. java 测试磁盘io,详解三种Linux测试磁盘IO性能的方法总结,值得收藏
  11. ubuntu mysql自动补全_mysql自动化安装脚本(ubuntu and centos64)
  12. Java基础知识笔记整理(零基础学Java)
  13. 实验报告二:例2-19 一位全加器
  14. 做对了什么与留下了什么 小米上市的背后
  15. 【Web】CSS(No.21)Css经典案例(三)《爱宠知识》
  16. 大数据文本相似去重方案
  17. hiveql 没有left()right()函数,可用substr()替代
  18. 形式化方法-- petri net
  19. 基于百度图像识别api的游戏(coc)辅助工具分析
  20. window脚本介绍

热门文章

  1. VC++6.0环境下调试c语言代码的方法和步骤_附图
  2. java泛型好处及案例
  3. 宜人贷CTO段念:透明与面向目标是管理理念的核心
  4. Oracle数据的导入导出
  5. java 开发小记:如何使用 MyEclipse 开发自己的类库(mylib.jar)以及引用(使用)她...
  6. 每天学一点flash(4) 数组与xml配合使用
  7. Google Gears 体验(1):本机数据库
  8. ArcEngine中拓扑的使用
  9. Hadoop启动jobhistoryserver
  10. 为什么阿里强制 boolean 类型变量不能使用 is 开头