TensorFlow详解猫狗识别 (三)--训练过程
感悟
在这段时间中,同时测试了几个神经网络的模型(LeNet、AlexNet、VGG16...)
感受到了调节超参数的重要性,单单对于LeNet来说,得出,当:
batch_size = 32
lr = 0.0001
max_step = 6000~10000
时函数收敛比较快,训练步数介于6000~10000时,训练出来的效果比较好,之前训练了一晚上100000步,第二天过来看预测结果,满心欢喜看到train_acc达到了100%,但用来识别精度大概在75%,模型出现了严重的过拟合,训练8000步的时候,预测结果很理想,100张网图,基本没有错。
而当lr设太小的话,模型不收敛或收敛的太慢,lr太大的话模型会出现震荡。
下面直接给出训练代码,里面会有详细注释。
代码
import os
import numpy as np
import tensorflow as tf
import test
import model
import timeN_CLASSES = 2
IMG_W = 208
IMG_H = 208
BATCH_SIZE = 16
CAPACITY = 2000 #队列中元素个数
MAX_STEP = 8000
learning_rate = 0.0001 #小于0.001print("I'm OK")
train_dir = 'E:\\Pycharm\\tf-01\\Bigwork\\train\\' # 训练图片文件夹
logs_train_dir = 'E:\\Pycharm\\tf-01\\Bigwork\\savenet02\\' # 保存训练结果文件夹train, train_label = test.get_files(train_dir)train_batch, train_label_batch = test.get_batch(train,train_label,IMG_W,IMG_H,BATCH_SIZE,CAPACITY)#训练操作定义
sess = tf.Session()train_logits = model.inference(train_batch, BATCH_SIZE, N_CLASSES)
train_loss = model.losses(train_logits, train_label_batch)
train_op = model.trainning(train_loss, learning_rate)
train_acc = model.evaluation(train_logits, train_label_batch)#train_label_batch = tf.one_hot(train_label_batch,2,1,0)
#测试操作定义summary_op = tf.summary.merge_all()#产生一个writer来写log文件
train_writer = tf.summary.FileWriter(logs_train_dir,sess.graph)
saver = tf.train.Saver()sess.run(tf.global_variables_initializer())
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess,coord = coord)#加入队列,很重要tra_loss = .0
tra_acc = .0
# val_loss = .0
# val_acc = .0try:start = time.clock()#计算每一个step所花的时间for step in np.arange(MAX_STEP):if coord.should_stop():break_,tra_loss_,tra_acc_ = sess.run([train_op,train_loss,train_acc])# val_loss_, val_acc_ = sess.run([test_loss, test_acc])#下面这一段为我为了打印神经网络最后一层变化写的,可以不要'''train,label = sess.run([train_logits,train_label_batch])#print(train)L = []for i in train:max_ = np.argmax(i)L.append(max_)print(L)print(label)'''tra_loss = tra_loss+tra_loss_tra_acc = tra_acc+tra_acc_# val_loss = val_loss+val_loss_# val_acc = val_acc+val_acc_if (step+1) % 50 == 0 and step!=0:end = time.clock()print('Step %d, train loss = %.2f, train accuracy = %.2f%%' % (step+1, tra_loss/50, tra_acc * 100.0/50))#print('Step %d, val loss = %.2f, val accuracy = %.2f%%' % (step, val_loss/50,val_acc*100.0/50))print(str(end-start))tra_loss = .0tra_acc = .0summary_str = sess.run(summary_op)train_writer.add_summary(summary_str, step)start = time.clock()# 每隔2000步,保存一次训练好的模型if step%2000==0 or step == MAX_STEP-1:checkpoint_path = os.path.join(logs_train_dir, 'model.ckpt')saver.save(sess, checkpoint_path, global_step=step)except tf.errors.OutOfRangeError:print('Done training -- epoch limit reached')finally:coord.request_stop()coord.join(threads)
sess.close()
开始训练
下图我是每间隔10个step打印一次结果,可以看到,训练到2000多步的时候,精度已经可以达到了84%。
下一篇:预测
TensorFlow详解猫狗识别 (三)--训练过程相关推荐
- TensorFlow详解猫狗识别(一)--读取自己的数据集
数据集下载 链接: https://pan.baidu.com/s/1SlNAPf3NbgPyf93XluM7Fg 密码: hpn4 数据集分别有12500张cat,12500张dog 读取数据集 数 ...
- Tensorflow实现kaggle猫狗识别(循序渐进进行网络设计)
这篇是tensorflow版本,pytorch版本会在下一篇博客给出 友情提示:尽量上GPU,博主CPU上跑一个VGG16花了1.5h... Tensorflow实现kaggle猫狗识别 数据集获取 ...
- 毕设:基于CNN卷积神经网络的猫狗识别、狗品种识别(Tensorflow、Keras、Kaggle竞赛)
基于卷积神经网络的图像识别算法及其应用研究 毕业快一年了,拿出来分享给大家,我和网上唯一的区别就是,我能够同时实现两个方案(猫狗识别和狗品种识别),我当时也是网上各种查,花了2,3个月的时间,一个萝卜 ...
- 详解pytorch实现猫狗识别98%附代码
详解pytorch实现猫狗识别98%附代码 前言 一.为什么选用pytorch这个框架? 二.实现效果 三.神经网络从头到尾 1.来源:仿照人为处理图片的流程,模拟人们的神经元处理信息的方式 2.总览 ...
- 【卷积神经网络】CNN详解以及猫狗识别实例
文章目录 一.卷积神经网络(CNN)介绍 1.1 整体结构 1.2 说明 1.3 特点 1.4 应用领域 二.配置实验环境 三.猫狗识别实例 3.1 准备数据集 3.2 图片分类 3.3 网络模型搭建 ...
- 深度学习之CNN卷积神经网络详解以及猫狗识别实战
文章目录 CNN 解决了什么问题? 需要处理的数据量太大 图像简单数字化无法保留图像特征 CNN核心思想 局部感知 参数共享 卷积神经网络-CNN 的基本原理 卷积--提取特征 池化层(下采样)--数 ...
- 猫狗大战——基于TensorFlow的猫狗识别(2)
微信公众号:龙跃十二 我是小玉,一个平平无奇的小天才! 上篇文章我们说了关于猫狗大战这个项目的一些准备工作,接下来,我们看看具体的代码详解. 猫狗大战--基于TensorFlow的猫狗识别(1) 文件 ...
- 使用Tensorflow 2进行猫狗分类识别
使用Tensorflow 2进行猫狗分类识别 本文参照了大佬Andrew Ng的所讲解的Tensorflow 2视频所写,本文将其中只适用于Linux的功能以及只适用于Google Colab的功能改 ...
- tensorflow 猫狗识别 数据增强
文章目录 卷积网络参数 网络配置 模型训练和效果展示 完整代码 数据增强 角度旋转 平移变换 缩放 channel_shift 翻转 rescale变化 图像填充 任务介绍: 有猫和狗的图片,需要对这 ...
最新文章
- PMP-【第2章 项目运行环境与项目经理】-2020-12-29(35页-48页)
- mysql proxy读写分离实现_使用mysql-proxy实现读写分离
- LetCode: 5. 最长回文子串
- matlab数据序列的几种滤波器
- Oracle数据类型与.NET中的对应关系
- 存储过程 SQL server(01)
- springcloud(六):配置中心git示例
- 【今日CV 视觉论文速览】 Part2 25 Jan 2019
- Docker 深入理解概念
- C++设计模式之Template Method(模板方法模式)
- iOS Swift 2 2 监听耳机的 插拔的事件
- Qt Toast 一个淡出提示效果
- 安全狗又拿下一场重保胜战 第22届投洽会顺利谢幕
- 凯撒加密的python语言程序_python语言编程实现凯撒密码、凯撒加解密算法、
- 学习笔记15-L298N
- idea 创建一个springboot 项目(hello world)
- 什么是蜘蛛统计 有什么作用?
- 计算机音乐制作专业美国研究生,美国音乐制作专业研究生六大首选音乐学院
- 绝对值不等式(贪心)
- 数据库系列-什么是 JDBC?它的作用是什么?
热门文章
- 「学习笔记」Vue 官方视频教程 2.0版
- 2014-12-20 如果不加班我们就一起去联想玩敏捷吧!!!!
- 设置环境变量配置的简单方法.env
- Windows CmdPHP窗口消失 但程序依旧执行 查找进程ID(PId)并强制结束进程--wmic process where name--taskkill
- 销售宝:软件销售技巧与话术,行业前景分析!
- 100天精通Python丨黑科技篇 —— 20、Python 修图(滤镜、灰度、裁剪、视觉处理、图像分割、特征提取)
- linux+特殊字符文件夹,linux创建带特殊符号的文件夹
- 2亿日活,日均千万级视频上传,快手推荐系统如何应对技术挑战?
- vue圆环进度条_Vue/React圆环进度条
- webpack5学习笔记-3 打包优化的操作