08_04基于手写数据集_mat保存模型参数
import os
import numpy as np
import tensorflow as tf
from scipy import io
from tensorflow.examples.tutorials.mnist import input_data# 1、设置超参数
learning_rate = 0.001
epochs = 10
batch_size = 128
test_valid_size = 512 # 用于验证或者测试的样本数量。
n_classes = 10
keep_probab = 0.75def conv2d_block(input_tensor, filter_w, filter_b, stride=1):"""实现 卷积 + 偏置项相加 + 激活:param input_tensor::param filter_w::param filter_b::param stride::return:"""conv = tf.nn.conv2d(input=input_tensor, filter=filter_w, strides=[1, stride, stride, 1], padding='SAME')conv = tf.nn.bias_add(conv, filter_b)conv = tf.nn.relu6(conv)return convdef maxpool(input_tensor, k=2):"""池化:param input_tensor::param k::return:"""ksize = [1, k, k, 1]strides = [1, k, k, 1]max_out = tf.nn.max_pool(value=input_tensor, ksize=ksize, strides=strides, padding='SAME')return max_outdef model(input_tensor, keep_prob, pre_trained_weights=None):""":param input_tensor: 输入图片的占位符:param weights::param biases::param keep_prob: 保留概率的占位符:return:""""""'w_conv1:0', 'w_conv2:0', 'w_fc1:0', 'w_logits:0', 'b_conv1:0', 'b_conv2:0', 'b_fc1:0', 'b_logits:0']"""if pre_trained_weights:W = pre_trained_weightsweights = {'conv1': tf.get_variable('w_conv1', dtype=tf.float32,initializer=W['w_conv1:0'], trainable=False),'conv2': tf.get_variable('w_conv2', dtype=tf.float32,initializer=W['w_conv2:0'], trainable=False),'fc1': tf.get_variable('w_fc1', dtype=tf.float32,initializer=W['w_fc1:0'], trainable=True),'logits': tf.get_variable('w_logits', dtype=tf.float32,initializer=W['w_logits:0'], trainable=True),}biases = {'conv1': tf.get_variable('b_conv1', dtype=tf.float32,initializer=np.reshape(W['b_conv1:0'], -1), trainable=False),'conv2': tf.get_variable('b_conv2', dtype=tf.float32,initializer=np.reshape(W['b_conv2:0'], -1), trainable=False),'fc1': tf.get_variable('b_fc1', shape=[1024], dtype=tf.float32,initializer=tf.zeros_initializer()),'logits': tf.get_variable('b_logits', shape=[n_classes], dtype=tf.float32,initializer=tf.zeros_initializer()),}else:weights = {'conv1': tf.get_variable('w_conv1', shape=[5, 5, 1, 32], dtype=tf.float32,initializer=tf.truncated_normal_initializer(stddev=0.1)),'conv2': tf.get_variable('w_conv2', shape=[5, 5, 32, 64], dtype=tf.float32,initializer=tf.truncated_normal_initializer(stddev=0.1)),'fc1': tf.get_variable('w_fc1', shape=[7 * 7 * 64, 1024], dtype=tf.float32,initializer=tf.truncated_normal_initializer(stddev=0.1)),'logits': tf.get_variable('w_logits', shape=[1024, n_classes], dtype=tf.float32,initializer=tf.truncated_normal_initializer(stddev=0.1)),}biases = {'conv1': tf.get_variable('b_conv1', shape=[32], dtype=tf.float32,initializer=tf.zeros_initializer()),'conv2': tf.get_variable('b_conv2', shape=[64], dtype=tf.float32,initializer=tf.zeros_initializer()),'fc1': tf.get_variable('b_fc1', shape=[1024], dtype=tf.float32,initializer=tf.zeros_initializer()),'logits': tf.get_variable('b_logits', shape=[n_classes], dtype=tf.float32,initializer=tf.zeros_initializer()),}# 1、卷积1 [N, 28, 28, 1] ---> [N, 28, 28, 32]conv1 = conv2d_block(input_tensor=input_tensor, filter_w=weights['conv1'], filter_b=biases['conv1'])# 2、池化1 [N, 28, 28, 32] --->[N, 14, 14, 32]pool1 = maxpool(conv1, k=2)# 3、卷积2 [N, 14, 14, 32] ---> [N, 14, 14,64]conv2 = conv2d_block(input_tensor=pool1, filter_w=weights['conv2'], filter_b=biases['conv2'])conv2 = tf.nn.dropout(conv2, keep_prob=keep_prob)# 4、池化1 [N, 14, 14,64] --->[N, 7, 7, 64]pool2 = maxpool(conv2, k=2)# 5、拉平层(flatten) [N, 7, 7, 64] ---> [N, 7*7*64]x_shape = pool2.get_shape()flatten_shape = x_shape[1] * x_shape[2] * x_shape[3]flatted = tf.reshape(pool2, shape=[-1, flatten_shape])# 6、FC1 全连接层fc1 = tf.nn.relu6(tf.matmul(flatted, weights['fc1']) + biases['fc1'])fc1 = tf.nn.dropout(fc1, keep_prob=keep_prob)# 7、logits层logits = tf.add(tf.matmul(fc1, weights['logits']), biases['logits'])with tf.variable_scope('prediction'):prediction = tf.argmax(logits, axis=1)return logits, predictiondef create_dir_path(path):if not os.path.exists(path):os.makedirs(path)print('create file path:{}'.format(path))def store_weights(sess, save_path):# todo 1、获取所有需要持久化的变量# vars_list = tf.global_variables()vars_list = tf.trainable_variables()# 2、执行得到变量的值vars_values = sess.run(vars_list)# todo 3、将变量转换为字典对象mdict = {}for values, var in zip(vars_values, vars_list):# 获取变量的名字name = var.name# 赋值mdict[name] = values# todo 4、保存为matlab数据格式io.savemat(save_path, mdict)print('Saved Vars to files:{}'.format(save_path))def train():# 创建持久化文件夹checkpoint_dir = './model/mnist/matlab/ai20'create_dir_path(checkpoint_dir)graph = tf.Graph()with graph.as_default():# 1、占位符x = tf.placeholder(tf.float32, [None, 28, 28, 1], name='x')y = tf.placeholder(tf.float32, [None, 10], name='y')keep_prob = tf.placeholder_with_default(0.75, shape=None, name='keep_prob')# 2、创建模型图weights_path = './model/mnist/matlab/ai20'files = os.listdir(weights_path)if files:weight_file = os.path.join(weights_path, files[0])if os.path.isfile(weight_file):mdict = io.loadmat(weight_file)logits, prediction = model(x, keep_prob, pre_trained_weights=mdict)print('Load old model continue to train!')else:logits, prediction = model(x, keep_prob)print('No old model, train from scratch!')# 3、损失loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits, labels=y))# 优化器optimizer = tf.train.AdamOptimizer(learning_rate)train_opt = optimizer.minimize(loss)# 计算准确率correct_pred = tf.equal(tf.argmax(y, axis=1), prediction)accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))with tf.Session(graph=graph) as sess:sess.run(tf.global_variables_initializer())mnist = input_data.read_data_sets('../datas/mnist', one_hot=True, reshape=False)# print(mnist.train.num_examples)step = 1while True:# 执行训练batch_x, batch_y = mnist.train.next_batch(batch_size=batch_size)feed = {x: batch_x, y: batch_y}_, train_loss, train_acc = sess.run([train_opt, loss, accuracy], feed)print('Step:{} - Train Loss:{:.5f} - Train acc:{:.5f}'.format(step, train_loss, train_acc))# 持久化# if step % 100 == 0:# files = 'model_{:.3f}.mat'.format(train_acc)# save_file = os.path.join(checkpoint_dir, files)# store_weights(sess, save_path=save_file)step += 1# 退出机制if train_acc >0.99:breakif __name__ == '__main__':train()
08_04基于手写数据集_mat保存模型参数相关推荐
- DL之CNN:利用自定义DeepConvNet【7+1】算法对mnist数据集训练实现手写数字识别、模型评估(99.4%)
DL之CNN:利用自定义DeepConvNet[7+1]算法对mnist数据集训练实现手写数字识别.模型评估(99.4%) 目录 输出结果 设计思路 核心代码 输出结果 设计思路 核心代码 netwo ...
- 基于TensorFlow1.4.0的FNN全连接网络识别MNIST手写数据集
MNIST手写数据集是所有新手入门必经的数据集,数据集比较简单,训练集为50000张手写图片,测试集为张手写图片10000,大小都为28*28,不用自己下载,直接从TensorFlow导入即可 后续随 ...
- 深度学习笔记--pytorch从梯度下降到反向传播BP到线性回归实现,以及API调用和手写数据集的实现
梯度下降和反向传播 目标 知道什么是梯度下降 知道什么是反向传播 1. 梯度是什么? 梯度:是一个向量,导数+变化最快的方向(学习的前进方向) 回顾机器学习 收集数据 x x x ,构建机器学习模型 ...
- 太赞了!NumPy 手写所有主流 ML 模型,由普林斯顿博士后 David Bourgin打造的史上最强机器学习基石项目!...
关注上方"深度学习技术前沿",选择"星标公众号", 资源干货,第一时间送达! 用 NumPy 手写所有主流 ML 模型,普林斯顿博士后 David Bourgi ...
- NumPy 手写所有主流 ML 模型,由普林斯顿博士后 David Bourgin打造的史上最强机器学习基石项目!...
关注上方"深度学习技术前沿",选择"星标公众号", 资源干货,第一时间送达! 用 NumPy 手写所有主流 ML 模型,普林斯顿博士后 David Bourgi ...
- 2.2 Mnist手写数据集
2.2 Mnist手写数据集 全连接网络:网络层的每一个结点都与上一层的所有结点相连. 多隐层全连接神经网络: 代码如下: 1. 导入必要的模块 import numpy as np import p ...
- 【机器学习】PCA主成分项目实战:MNIST手写数据集分类
PCA主成分项目实战:MNIST手写数据集分类 PCA处理手写数字集 1 模块加载与数据导入 2 模型创建与应用 手动反爬虫:原博地址 https://blog.csdn.net/lys_828/ar ...
- 【Keras+计算机视觉+Tensorflow】DCGAN对抗生成网络在MNIST手写数据集上实战(附源码和数据集 超详细)
需要源码和数据集请点赞关注收藏后评论区留言私信~~~ 一.生成对抗网络的概念 生成对抗网络(GANs,Generative Adversarial Nets),由Ian Goodfellow在2014 ...
- 4-CNN-demo-03_CNN网络解决手写数据集
文章目录 1.代码 2.结果 3.可视化 1.代码 import tensorflow as tf import os from tensorflow.examples.tutorials.mnist ...
最新文章
- K-BERT | 基于知识图谱的语言表示模型
- 清华大学:人工智能之知识图谱(附PPT)
- python configuration is still_通过Python配置关闭Release优化
- BQ24296充电管理芯片使用过程中的注意事项
- 详解get与post请求方式、content-type与responseType、@Requestbody与@Requestparam的使用场景
- python模块搜索路径_Python模块搜索路径
- 【收藏】机器学习数据集列表:你需要收藏!
- a标签的onclick事件_JavaScript提高:ASP.NET使用easyUI TABS标签显示问题
- java中JTextArea类_Swing JTextArea类
- 黄聪:一个拼图工具的制作思路
- 重构手法之重新组织数据【1】
- 什么是句柄?为什么会有句柄?HANDLE
- php调试利器之phpdbg
- matlab中 晶闸管整流桥导通角_逆变角如何设置,晶闸管2011-6-6
- SSM项目实战之十八:基础数据的修改和删除
- C语言华氏摄氏度转换
- java spring mvc json转对象,SpringMVC中使用@RequestBody,@ResponseBody注解实现Java对象和XML/JSON数据自动转换(上)......
- 小记 events.EventEmitter.call
- 【C++】packaged_task的用法实例
- Android 页面布局xd,页面布局(XD):小尺寸设备上的页面布局《 从设计到代码:布局设计 》...
热门文章
- 企业微信是如何助力企业引流获客,扩充客户池?
- 曲面积分的投影法_重积分3.二重积分的对称性
- Convolutional Networks for Image Semantic Segmentation
- 人在旅途——》云南8天出行日程清单
- 面试时如何做自我介绍
- Skia最新“编译”,绘制中文字符串,加载PNG、BMP图片等资料的整理。
- 游戏配音最重要的两点
- 解决Oracle数据库1521端口telnet不通问题
- nrm v1.2.5版本使用时会出现的问题
- 【报告分享】2021年5G应用场景研究-CTR(附下载)