本文接上一篇:8.训练自己的数据集(1):TFRecords编写与读取,文件是共享的,不再赘述。

1.先看看训练集与测试集大小:
import tensorflow as tf
import os
import numpy as np
from PIL import Image#查看tfrecords中数据集的大小
def total_sample(tfrecord_name):sample_nums = 0for record in tf.python_io.tf_record_iterator(tfrecord_name):sample_nums += 1return  sample_nums
train_total = total_sample('train.tfrecords')
test_total = total_sample('test.tfrecords')
print(train_total)
print(test_total)


注意,数据集大小在后面是有用的,test_total与train_total都有用。

2.定义相关参数
#定义相关参数
#图片的尺寸
WITH = 40
HEIGHT = 40
#图片总共由3类,用于one_hot标签
classes_num = 3
#每批次训练图片数量
batch_size = 300
#将所有图片训练一轮所需要的总的训练次数
total_batch = int(train_total/batch_size)
#总共循环训练轮数
train_epochs = 50
#定义初始学习率
learning_rate = 0.005
3.构建模型并训练测试
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"#读取tfrecords中数据方法
def read_tfrecords(tfrecord_name,batch_size):#将tfrecords读入流中,乱序操作并循环读取filename_queue = tf.train.string_input_producer([tfrecord_name]) reader = tf.TFRecordReader()#返回文件名和文件_, serialized_example = reader.read(filename_queue)#取出文件中包含image和label的feature对象features = tf.parse_single_example(serialized_example,features={'label': tf.FixedLenFeature([], tf.int64),'img_raw' : tf.FixedLenFeature([], tf.string),})#将字符串解析成图像对应的像素数组image = tf.decode_raw(features['img_raw'], tf.uint8)#改变像素数组的大小,彩图是3通道的image = tf.reshape(image, [WITH, HEIGHT, 3])#将像素数组归一化image = tf.cast(image,tf.float32)*(1./255)-0.5#读取标签label = tf.cast(features['label'], tf.int32)#将标签制成one_hotlabel = tf.one_hot(label,depth=classes_num,on_value=1)#按批次大小乱序读取数据x_batch, y_batch = tf.train.shuffle_batch([image,label], batch_size=batch_size, num_threads=1, capacity=30*batch_size,min_after_dequeue=15*batch_size)return x_batch,y_batch#获取训练集数据
xs_train,ys_train = read_tfrecords('train.tfrecords',batch_size)
#获取测试集数据
xs_test,ys_test = read_tfrecords('test.tfrecords',test_total)#定义图片和标签的占位符
#None 表示张量的第一维度可以接受任意长度,3表示图片通道数
x = tf.placeholder(tf.float32,shape = [None,WITH,HEIGHT,3])
#None 表示张量的第一维度可以接受任意长度,class_num表示标签类别个数
y = tf.placeholder(tf.float32,shape = [None,classes_num])
keep_prob  = tf.placeholder(tf.float32)#定义权重及偏置值变量
W1 = tf.Variable(tf.random_normal(([int(WITH/4)*int(HEIGHT/4)*256,1024])))
b1 = tf.Variable(tf.constant(0.1,shape=[1024]))
W2 = tf.Variable(tf.random_normal(([1024,classes_num])))
b2 = tf.Variable(tf.constant(0.1,shape=[classes_num]))#定义隐藏层
def hidden_layer(inputs):#要用激活函数return tf.nn.relu(tf.matmul(inputs,W1)+b1)#定义权重方法:
def get_filter(shape):return tf.Variable(tf.truncated_normal(shape,stddev=0.1))#定义偏置值方法:
def get_bias(shape):return tf.Variable(tf.constant(0.1,shape=shape))w_con1 = get_filter([5,5,3,128])
b_con1 = get_bias([128])
w_con2 = get_filter([5,5,128,256])
b_con2 = get_bias([256])#第一层卷积输出,输出大小为 batch_size * WITH * HEIGHT * 15:
h_conv1 = tf.nn.conv2d(x,filter=w_con1,strides=[1,1,1,1],padding='SAME')+b_con1
#激活函数:
h1 = tf.nn.relu(h_conv1)
#第一层池化输出,输出大小为 batch_size * (WITH/2) *(HEIGHT/2) * 15:
h_pool1 = tf.nn.max_pool(h1,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')
#第二层卷积输出,输出大小为 batch_size * (WITH/2) *(HEIGHT/2) * 30:
h_conv2 = tf.nn.conv2d(h_pool1,filter=w_con2,strides=[1,1,1,1],padding='SAME')+b_con2
#激活函数:
h2 = tf.nn.relu(h_conv2)
#第一层池化输出,输出大小为 batch_size * (WITH/4) *(HEIGHT/4) * 30:
h_pool2 = tf.nn.max_pool(h2,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')#将池化后的输出改变下尺寸,
h_reshape = tf.reshape(h_pool2,[-1,int(WITH/4)*int(HEIGHT/4)*256])#隐藏层输出,并使用dropout
h3 = hidden_layer(h_reshape)h_drop1 = tf.nn.dropout(h3,keep_prob)#预测值,这里不用激活函数,因为等下要用tensorflow定义好的softmax交叉熵函数
pred = tf.matmul(h_drop1,W2) + b2#定义交叉熵
cross_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(logits=pred,labels=y)#定义总的损失函数
loss = tf.reduce_mean(cross_entropy)#定义优化器
opt = tf.train.AdamOptimizer(learning_rate).minimize(loss)#以下是测试模型
correct_prediction = tf.equal(tf.argmax(pred,1),tf.argmax(y,1))
#准确率:
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))with tf.Session() as sess: #必写内容sess.run(tf.global_variables_initializer())coord=tf.train.Coordinator()threads= tf.train.start_queue_runners(coord=coord)#开始训练print("以下是训练模型每轮训练误差:")for epoch in range(train_epochs):#定义平均loss值avg_loss = 0.#循环所有数据for i in range(total_batch):#获取批次训练数据batch_xs,batch_ys = sess.run([xs_train,ys_train])_,c,acc = sess.run([opt,loss,accuracy],feed_dict={x:batch_xs,y:batch_ys,keep_prob:0.8})#平均lossavg_loss += c / total_batch#显示每轮的结果print('Epoch:',epoch+1,',     loss:','{:.9f}'.format(avg_loss),',     accuracy:','{:.5f}'.format(acc))print("\n训练模型结束,以下是测试模型准确率:")#获取测试数据test_xs,test_ys = sess.run([xs_test,ys_test])acc = sess.run(accuracy,feed_dict={x:test_xs,y:test_ys,keep_prob:1.0})print('Accuracy:',acc)#必写内容coord.request_stop()coord.join(threads)


总共训练了50轮,测试正确率57%,很多参数需要调整(比如增加训练次数,添加卷积池化层,改变图片尺寸大小等)。

9.训练自己的数据集(2):卷积神经网络之花卉分类相关推荐

  1. CNN卷积神经网络:花卉分类

    文章目录 简介 一.CNN卷积神经网络基础知识 二.数据集介绍 三.代码实现 读取数据 数据处理 搭建网络 训练网络 测试网络 保存网络 结果展示 总结 简介 本篇文章利用pytorch搭建CNN卷积 ...

  2. 基于卷积神经网络的高光谱分类 CNN+高光谱+印度松数据集

    基于卷积神经网络的高光谱分类 一.研究现状 只考虑到高光谱图像的光谱特征,即: 1.提取特征(小波变换.Gabor纹理分析.形态学剖面) 2.利用分类方法分类(支持向量机.决策树.随机森林.神经网络) ...

  3. cnn 预测过程代码_代码实践 | CNN卷积神经网络之文本分类

    学习目录阿力阿哩哩:深度学习 | 学习目录​zhuanlan.zhihu.com 前面我们介绍了:阿力阿哩哩:一文掌握CNN卷积神经网络​zhuanlan.zhihu.com阿力阿哩哩:代码实践|全连 ...

  4. 毕设 深度学习卷积神经网络的花卉识别

    文章目录 0 前言 1 项目背景 2 花卉识别的基本原理 3 算法实现 3.1 预处理 3.2 特征提取和选择 3.3 分类器设计和决策 3.4 卷积神经网络基本原理 4 算法实现 4.1 花卉图像数 ...

  5. 卷积神经网络在句子分类上的应用[翻译]

    最近翻译几篇paper,也算逼着自己多看看文章.对于一些概念的理解不够透彻可能导致翻译不准确,以及某些术语实在也是找不到合适的中文词,可能会有些别扭或索性没翻.大家将就着看.哪位大神看到了,如有不足还 ...

  6. 基于卷积神经网络的句子分类模型【经典卷积分类附源码链接】

    https://www.toutiao.com/a6680124799831769603/ 基于卷积神经网络的句子分类模型 题目: Convolutional Neural Networks for ...

  7. 论文阅读:Convolutional Neural Networks for Sentence Classification 卷积神经网络的句子分类

    Convolutional Neural Networks for Sentence Classification 卷积神经网络的句子分类 目录 Convolutional Neural Networ ...

  8. 毕业设计 - 题目:基于深度学习卷积神经网络的花卉识别 - 深度学习 机器视觉

    文章目录 0 前言 1 项目背景 2 花卉识别的基本原理 3 算法实现 3.1 预处理 3.2 特征提取和选择 3.3 分类器设计和决策 3.4 卷积神经网络基本原理 4 算法实现 4.1 花卉图像数 ...

  9. 毕业设计 - 基于卷积神经网络的乳腺癌分类 深度学习 医学图像

    文章目录 1 前言 2 前言 3 数据集 3.1 良性样本 3.2 病变样本 4 开发环境 5 代码实现 5.1 实现流程 5.2 部分代码实现 5.2.1 导入库 5.2.2 图像加载 5.2.3 ...

最新文章

  1. 获取预制和获取gameObject
  2. 使用Docker-容器命令介绍
  3. 四种方法实现数组交换
  4. 造轮子是什么意思_程序员为什么热衷于造轮子,升职加薪吗?
  5. teleport 组件的作用_人脸识别综述! 覆盖人脸检测,预处理和特征表示三大核心组件!...
  6. 控件ListView相关属性 1217
  7. java 定时任务spring_Spring实现定时任务调度
  8. python django 动态网页_python27+django1.9创建app的视图及实现动态页面
  9. Android杂谈--ListView之BaseAdapter的使用一(转)
  10. 黄聪:WordPress动作钩子函数add_action()、do_action()源码解析
  11. C#全局钩子和局部钩子记录
  12. 计算机二级c语言带小抄,计算机二级C语言上机题库(可缩印做小抄)..docx
  13. 移动U盘数据恢复,移动U盘数据恢复方法
  14. 深度学习双显卡配置_gpu – 我可以在笔记本电脑上使用intel高清显卡实现深度学习模型...
  15. 空间换时间时间换空间
  16. 农作物病虫害识别技术的发展综述
  17. stm32—火焰传感器的初步使用
  18. 【Android】Error obtaining UI hierarchyError while obtaining UI hierarchy XML file: com.android...
  19. 求树的最大宽度(层次遍历法)
  20. Burg法参数化功率谱估计(Python实现版)

热门文章

  1. 150 元低成本改装家里的门锁,抓好软件硬件,向物联网出发
  2. 叶胜超区块链:Aelf(ELF)---去中心化的云计算区块链网络!
  3. 武汉理工计算机网络教学平台,欢迎访问武汉理工大学计算机科学与技术学院
  4. Profinet转Modbus TCP网关连接脉冲电源通讯配置案例
  5. flac文件提取专辑封面手记
  6. 2021-11-03小程序调查问卷及搭建服务器后台案例
  7. python源码剖析代码例子_Python源码剖析笔记5-模块机制
  8. 小米电视3 android,小米3代/小米电视今日发布!-小米3代,5寸,1080p,Tegra 4,骁龙800,小米电视,47寸, ——快科技(驱动之家旗下媒体)--科技改变未来...
  9. 大厂java程序员教你面试如何介绍项目经验
  10. tiptop 编译运行_TIPTOP GP ERP二次开发FQA问题集