import os
import numpy as np
import tensorflow as tf
from scipy import io
from tensorflow.examples.tutorials.mnist import input_data# 1、设置超参数
learning_rate = 0.001
epochs = 10
batch_size = 128
test_valid_size = 512  # 用于验证或者测试的样本数量。
n_classes = 10
keep_probab = 0.75def conv2d_block(input_tensor, filter_w, filter_b, stride=1):"""实现 卷积 +  偏置项相加 + 激活:param input_tensor::param filter_w::param filter_b::param stride::return:"""conv = tf.nn.conv2d(input=input_tensor, filter=filter_w, strides=[1, stride, stride, 1], padding='SAME')conv = tf.nn.bias_add(conv, filter_b)conv = tf.nn.relu6(conv)return convdef maxpool(input_tensor, k=2):"""池化:param input_tensor::param k::return:"""ksize = [1, k, k, 1]strides = [1, k, k, 1]max_out = tf.nn.max_pool(value=input_tensor, ksize=ksize, strides=strides, padding='SAME')return max_outdef model(input_tensor, keep_prob, pre_trained_weights=None):""":param input_tensor:   输入图片的占位符:param weights::param biases::param keep_prob:     保留概率的占位符:return:""""""'w_conv1:0', 'w_conv2:0', 'w_fc1:0', 'w_logits:0', 'b_conv1:0', 'b_conv2:0', 'b_fc1:0', 'b_logits:0']"""if pre_trained_weights:W = pre_trained_weightsweights = {'conv1': tf.get_variable('w_conv1', dtype=tf.float32,initializer=W['w_conv1:0'], trainable=False),'conv2': tf.get_variable('w_conv2', dtype=tf.float32,initializer=W['w_conv2:0'], trainable=False),'fc1': tf.get_variable('w_fc1', dtype=tf.float32,initializer=W['w_fc1:0'], trainable=True),'logits': tf.get_variable('w_logits', dtype=tf.float32,initializer=W['w_logits:0'], trainable=True),}biases = {'conv1': tf.get_variable('b_conv1', dtype=tf.float32,initializer=np.reshape(W['b_conv1:0'], -1), trainable=False),'conv2': tf.get_variable('b_conv2', dtype=tf.float32,initializer=np.reshape(W['b_conv2:0'], -1), trainable=False),'fc1': tf.get_variable('b_fc1', shape=[1024], dtype=tf.float32,initializer=tf.zeros_initializer()),'logits': tf.get_variable('b_logits', shape=[n_classes], dtype=tf.float32,initializer=tf.zeros_initializer()),}else:weights = {'conv1': tf.get_variable('w_conv1', shape=[5, 5, 1, 32], dtype=tf.float32,initializer=tf.truncated_normal_initializer(stddev=0.1)),'conv2': tf.get_variable('w_conv2', shape=[5, 5, 32, 64], dtype=tf.float32,initializer=tf.truncated_normal_initializer(stddev=0.1)),'fc1': tf.get_variable('w_fc1', shape=[7 * 7 * 64, 1024], dtype=tf.float32,initializer=tf.truncated_normal_initializer(stddev=0.1)),'logits': tf.get_variable('w_logits', shape=[1024, n_classes], dtype=tf.float32,initializer=tf.truncated_normal_initializer(stddev=0.1)),}biases = {'conv1': tf.get_variable('b_conv1', shape=[32], dtype=tf.float32,initializer=tf.zeros_initializer()),'conv2': tf.get_variable('b_conv2', shape=[64], dtype=tf.float32,initializer=tf.zeros_initializer()),'fc1': tf.get_variable('b_fc1', shape=[1024], dtype=tf.float32,initializer=tf.zeros_initializer()),'logits': tf.get_variable('b_logits', shape=[n_classes], dtype=tf.float32,initializer=tf.zeros_initializer()),}# 1、卷积1  [N, 28, 28, 1]  ---> [N, 28, 28, 32]conv1 = conv2d_block(input_tensor=input_tensor, filter_w=weights['conv1'], filter_b=biases['conv1'])# 2、池化1 [N, 28, 28, 32]   --->[N, 14, 14, 32]pool1 = maxpool(conv1, k=2)# 3、卷积2  [N, 14, 14, 32]  ---> [N, 14, 14,64]conv2 = conv2d_block(input_tensor=pool1, filter_w=weights['conv2'], filter_b=biases['conv2'])conv2 = tf.nn.dropout(conv2, keep_prob=keep_prob)# 4、池化1 [N, 14, 14,64]   --->[N, 7, 7, 64]pool2 = maxpool(conv2, k=2)# 5、拉平层(flatten)    [N, 7, 7, 64]  ---> [N, 7*7*64]x_shape = pool2.get_shape()flatten_shape = x_shape[1] * x_shape[2] * x_shape[3]flatted = tf.reshape(pool2, shape=[-1, flatten_shape])# 6、FC1  全连接层fc1 = tf.nn.relu6(tf.matmul(flatted, weights['fc1']) + biases['fc1'])fc1 = tf.nn.dropout(fc1, keep_prob=keep_prob)# 7、logits层logits = tf.add(tf.matmul(fc1, weights['logits']), biases['logits'])with tf.variable_scope('prediction'):prediction = tf.argmax(logits, axis=1)return logits, predictiondef create_dir_path(path):if not os.path.exists(path):os.makedirs(path)print('create file path:{}'.format(path))def store_weights(sess, save_path):# todo 1、获取所有需要持久化的变量# vars_list = tf.global_variables()vars_list = tf.trainable_variables()# 2、执行得到变量的值vars_values = sess.run(vars_list)# todo 3、将变量转换为字典对象mdict = {}for values, var in zip(vars_values, vars_list):# 获取变量的名字name = var.name# 赋值mdict[name] = values# todo 4、保存为matlab数据格式io.savemat(save_path, mdict)print('Saved Vars to files:{}'.format(save_path))def train():# 创建持久化文件夹checkpoint_dir = './model/mnist/matlab/ai20'create_dir_path(checkpoint_dir)graph = tf.Graph()with graph.as_default():# 1、占位符x = tf.placeholder(tf.float32, [None, 28, 28, 1], name='x')y = tf.placeholder(tf.float32, [None, 10], name='y')keep_prob = tf.placeholder_with_default(0.75, shape=None, name='keep_prob')# 2、创建模型图weights_path = './model/mnist/matlab/ai20'files = os.listdir(weights_path)if files:weight_file = os.path.join(weights_path, files[0])if os.path.isfile(weight_file):mdict = io.loadmat(weight_file)logits, prediction = model(x, keep_prob, pre_trained_weights=mdict)print('Load old model continue to train!')else:logits, prediction = model(x, keep_prob)print('No old model, train from scratch!')# 3、损失loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits, labels=y))# 优化器optimizer = tf.train.AdamOptimizer(learning_rate)train_opt = optimizer.minimize(loss)# 计算准确率correct_pred = tf.equal(tf.argmax(y, axis=1), prediction)accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))with tf.Session(graph=graph) as sess:sess.run(tf.global_variables_initializer())mnist = input_data.read_data_sets('../datas/mnist', one_hot=True, reshape=False)# print(mnist.train.num_examples)step = 1while True:# 执行训练batch_x, batch_y = mnist.train.next_batch(batch_size=batch_size)feed = {x: batch_x, y: batch_y}_, train_loss, train_acc = sess.run([train_opt, loss, accuracy], feed)print('Step:{} - Train Loss:{:.5f} - Train acc:{:.5f}'.format(step, train_loss, train_acc))# 持久化# if step % 100 == 0:#     files = 'model_{:.3f}.mat'.format(train_acc)#     save_file = os.path.join(checkpoint_dir, files)#     store_weights(sess, save_path=save_file)step += 1# 退出机制if train_acc >0.99:breakif __name__ == '__main__':train()

08_04基于手写数据集_mat保存模型参数相关推荐

  1. DL之CNN:利用自定义DeepConvNet【7+1】算法对mnist数据集训练实现手写数字识别、模型评估(99.4%)

    DL之CNN:利用自定义DeepConvNet[7+1]算法对mnist数据集训练实现手写数字识别.模型评估(99.4%) 目录 输出结果 设计思路 核心代码 输出结果 设计思路 核心代码 netwo ...

  2. 基于TensorFlow1.4.0的FNN全连接网络识别MNIST手写数据集

    MNIST手写数据集是所有新手入门必经的数据集,数据集比较简单,训练集为50000张手写图片,测试集为张手写图片10000,大小都为28*28,不用自己下载,直接从TensorFlow导入即可 后续随 ...

  3. 深度学习笔记--pytorch从梯度下降到反向传播BP到线性回归实现,以及API调用和手写数据集的实现

    梯度下降和反向传播 目标 知道什么是梯度下降 知道什么是反向传播 1. 梯度是什么? 梯度:是一个向量,导数+变化最快的方向(学习的前进方向) 回顾机器学习 收集数据 x x x ,构建机器学习模型 ...

  4. 太赞了!NumPy 手写所有主流 ML 模型,由普林斯顿博士后 David Bourgin打造的史上最强机器学习基石项目!...

    关注上方"深度学习技术前沿",选择"星标公众号", 资源干货,第一时间送达! 用 NumPy 手写所有主流 ML 模型,普林斯顿博士后 David Bourgi ...

  5. NumPy 手写所有主流 ML 模型,由普林斯顿博士后 David Bourgin打造的史上最强机器学习基石项目!...

    关注上方"深度学习技术前沿",选择"星标公众号", 资源干货,第一时间送达! 用 NumPy 手写所有主流 ML 模型,普林斯顿博士后 David Bourgi ...

  6. 2.2 Mnist手写数据集

    2.2 Mnist手写数据集 全连接网络:网络层的每一个结点都与上一层的所有结点相连. 多隐层全连接神经网络: 代码如下: 1. 导入必要的模块 import numpy as np import p ...

  7. 【机器学习】PCA主成分项目实战:MNIST手写数据集分类

    PCA主成分项目实战:MNIST手写数据集分类 PCA处理手写数字集 1 模块加载与数据导入 2 模型创建与应用 手动反爬虫:原博地址 https://blog.csdn.net/lys_828/ar ...

  8. 【Keras+计算机视觉+Tensorflow】DCGAN对抗生成网络在MNIST手写数据集上实战(附源码和数据集 超详细)

    需要源码和数据集请点赞关注收藏后评论区留言私信~~~ 一.生成对抗网络的概念 生成对抗网络(GANs,Generative Adversarial Nets),由Ian Goodfellow在2014 ...

  9. 4-CNN-demo-03_CNN网络解决手写数据集

    文章目录 1.代码 2.结果 3.可视化 1.代码 import tensorflow as tf import os from tensorflow.examples.tutorials.mnist ...

最新文章

  1. K-BERT | 基于知识图谱的语言表示模型
  2. 清华大学:人工智能之知识图谱(附PPT)
  3. python configuration is still_通过Python配置关闭Release优化
  4. BQ24296充电管理芯片使用过程中的注意事项
  5. 详解get与post请求方式、content-type与responseType、@Requestbody与@Requestparam的使用场景
  6. python模块搜索路径_Python模块搜索路径
  7. 【收藏】机器学习数据集列表:你需要收藏!
  8. a标签的onclick事件_JavaScript提高:ASP.NET使用easyUI TABS标签显示问题
  9. java中JTextArea类_Swing JTextArea类
  10. 黄聪:一个拼图工具的制作思路
  11. 重构手法之重新组织数据【1】
  12. 什么是句柄?为什么会有句柄?HANDLE
  13. php调试利器之phpdbg
  14. matlab中 晶闸管整流桥导通角_逆变角如何设置,晶闸管2011-6-6
  15. SSM项目实战之十八:基础数据的修改和删除
  16. C语言华氏摄氏度转换
  17. java spring mvc json转对象,SpringMVC中使用@RequestBody,@ResponseBody注解实现Java对象和XML/JSON数据自动转换(上)......
  18. 小记 events.EventEmitter.call
  19. 【C++】packaged_task的用法实例
  20. Android 页面布局xd,页面布局(XD):小尺寸设备上的页面布局《 从设计到代码:布局设计 》...

热门文章

  1. 企业微信是如何助力企业引流获客,扩充客户池?
  2. 曲面积分的投影法_重积分3.二重积分的对称性
  3. Convolutional Networks for Image Semantic Segmentation
  4. 人在旅途——》云南8天出行日程清单
  5. 面试时如何做自我介绍
  6. Skia最新“编译”,绘制中文字符串,加载PNG、BMP图片等资料的整理。
  7. 游戏配音最重要的两点
  8. 解决Oracle数据库1521端口telnet不通问题
  9. nrm v1.2.5版本使用时会出现的问题
  10. 【报告分享】2021年5G应用场景研究-CTR(附下载)