感悟

在这段时间中,同时测试了几个神经网络的模型(LeNet、AlexNet、VGG16...)

感受到了调节超参数的重要性,单单对于LeNet来说,得出,当:

batch_size = 32

lr = 0.0001

max_step = 6000~10000

时函数收敛比较快,训练步数介于6000~10000时,训练出来的效果比较好,之前训练了一晚上100000步,第二天过来看预测结果,满心欢喜看到train_acc达到了100%,但用来识别精度大概在75%,模型出现了严重的过拟合,训练8000步的时候,预测结果很理想,100张网图,基本没有错。

而当lr设太小的话,模型不收敛或收敛的太慢,lr太大的话模型会出现震荡。

下面直接给出训练代码,里面会有详细注释。

代码

import os
import numpy as np
import tensorflow as tf
import test
import model
import timeN_CLASSES = 2
IMG_W = 208
IMG_H = 208
BATCH_SIZE = 16
CAPACITY = 2000 #队列中元素个数
MAX_STEP = 8000
learning_rate = 0.0001 #小于0.001print("I'm OK")
train_dir = 'E:\\Pycharm\\tf-01\\Bigwork\\train\\'  # 训练图片文件夹
logs_train_dir = 'E:\\Pycharm\\tf-01\\Bigwork\\savenet02\\'  # 保存训练结果文件夹train, train_label = test.get_files(train_dir)train_batch, train_label_batch = test.get_batch(train,train_label,IMG_W,IMG_H,BATCH_SIZE,CAPACITY)#训练操作定义
sess = tf.Session()train_logits = model.inference(train_batch, BATCH_SIZE, N_CLASSES)
train_loss = model.losses(train_logits, train_label_batch)
train_op = model.trainning(train_loss, learning_rate)
train_acc = model.evaluation(train_logits, train_label_batch)#train_label_batch = tf.one_hot(train_label_batch,2,1,0)
#测试操作定义summary_op = tf.summary.merge_all()#产生一个writer来写log文件
train_writer = tf.summary.FileWriter(logs_train_dir,sess.graph)
saver = tf.train.Saver()sess.run(tf.global_variables_initializer())
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess,coord = coord)#加入队列,很重要tra_loss = .0
tra_acc = .0
# val_loss = .0
# val_acc = .0try:start = time.clock()#计算每一个step所花的时间for step in np.arange(MAX_STEP):if coord.should_stop():break_,tra_loss_,tra_acc_ = sess.run([train_op,train_loss,train_acc])# val_loss_, val_acc_ = sess.run([test_loss, test_acc])#下面这一段为我为了打印神经网络最后一层变化写的,可以不要'''train,label = sess.run([train_logits,train_label_batch])#print(train)L = []for i in train:max_ = np.argmax(i)L.append(max_)print(L)print(label)'''tra_loss = tra_loss+tra_loss_tra_acc = tra_acc+tra_acc_# val_loss = val_loss+val_loss_# val_acc = val_acc+val_acc_if (step+1) % 50 == 0 and step!=0:end = time.clock()print('Step %d, train loss = %.2f, train accuracy = %.2f%%' % (step+1, tra_loss/50, tra_acc * 100.0/50))#print('Step %d, val loss = %.2f, val accuracy = %.2f%%' % (step, val_loss/50,val_acc*100.0/50))print(str(end-start))tra_loss = .0tra_acc = .0summary_str = sess.run(summary_op)train_writer.add_summary(summary_str, step)start = time.clock()# 每隔2000步,保存一次训练好的模型if step%2000==0 or step == MAX_STEP-1:checkpoint_path = os.path.join(logs_train_dir, 'model.ckpt')saver.save(sess, checkpoint_path, global_step=step)except tf.errors.OutOfRangeError:print('Done training -- epoch limit reached')finally:coord.request_stop()coord.join(threads)
sess.close()

开始训练

下图我是每间隔10个step打印一次结果,可以看到,训练到2000多步的时候,精度已经可以达到了84%。

下一篇:预测

TensorFlow详解猫狗识别 (三)--训练过程相关推荐

  1. TensorFlow详解猫狗识别(一)--读取自己的数据集

    数据集下载 链接: https://pan.baidu.com/s/1SlNAPf3NbgPyf93XluM7Fg 密码: hpn4 数据集分别有12500张cat,12500张dog 读取数据集 数 ...

  2. Tensorflow实现kaggle猫狗识别(循序渐进进行网络设计)

    这篇是tensorflow版本,pytorch版本会在下一篇博客给出 友情提示:尽量上GPU,博主CPU上跑一个VGG16花了1.5h... Tensorflow实现kaggle猫狗识别 数据集获取 ...

  3. 毕设:基于CNN卷积神经网络的猫狗识别、狗品种识别(Tensorflow、Keras、Kaggle竞赛)

    基于卷积神经网络的图像识别算法及其应用研究 毕业快一年了,拿出来分享给大家,我和网上唯一的区别就是,我能够同时实现两个方案(猫狗识别和狗品种识别),我当时也是网上各种查,花了2,3个月的时间,一个萝卜 ...

  4. 详解pytorch实现猫狗识别98%附代码

    详解pytorch实现猫狗识别98%附代码 前言 一.为什么选用pytorch这个框架? 二.实现效果 三.神经网络从头到尾 1.来源:仿照人为处理图片的流程,模拟人们的神经元处理信息的方式 2.总览 ...

  5. 【卷积神经网络】CNN详解以及猫狗识别实例

    文章目录 一.卷积神经网络(CNN)介绍 1.1 整体结构 1.2 说明 1.3 特点 1.4 应用领域 二.配置实验环境 三.猫狗识别实例 3.1 准备数据集 3.2 图片分类 3.3 网络模型搭建 ...

  6. 深度学习之CNN卷积神经网络详解以及猫狗识别实战

    文章目录 CNN 解决了什么问题? 需要处理的数据量太大 图像简单数字化无法保留图像特征 CNN核心思想 局部感知 参数共享 卷积神经网络-CNN 的基本原理 卷积--提取特征 池化层(下采样)--数 ...

  7. 猫狗大战——基于TensorFlow的猫狗识别(2)

    微信公众号:龙跃十二 我是小玉,一个平平无奇的小天才! 上篇文章我们说了关于猫狗大战这个项目的一些准备工作,接下来,我们看看具体的代码详解. 猫狗大战--基于TensorFlow的猫狗识别(1) 文件 ...

  8. 使用Tensorflow 2进行猫狗分类识别

    使用Tensorflow 2进行猫狗分类识别 本文参照了大佬Andrew Ng的所讲解的Tensorflow 2视频所写,本文将其中只适用于Linux的功能以及只适用于Google Colab的功能改 ...

  9. tensorflow 猫狗识别 数据增强

    文章目录 卷积网络参数 网络配置 模型训练和效果展示 完整代码 数据增强 角度旋转 平移变换 缩放 channel_shift 翻转 rescale变化 图像填充 任务介绍: 有猫和狗的图片,需要对这 ...

最新文章

  1. PMP-【第2章 项目运行环境与项目经理】-2020-12-29(35页-48页)
  2. mysql proxy读写分离实现_使用mysql-proxy实现读写分离
  3. LetCode: 5. 最长回文子串
  4. matlab数据序列的几种滤波器
  5. Oracle数据类型与.NET中的对应关系
  6. 存储过程 SQL server(01)
  7. springcloud(六):配置中心git示例
  8. 【今日CV 视觉论文速览】 Part2 25 Jan 2019
  9. Docker 深入理解概念
  10. C++设计模式之Template Method(模板方法模式)
  11. iOS Swift 2 2 监听耳机的 插拔的事件
  12. Qt Toast 一个淡出提示效果
  13. 安全狗又拿下一场重保胜战 第22届投洽会顺利谢幕
  14. 凯撒加密的python语言程序_python语言编程实现凯撒密码、凯撒加解密算法、
  15. 学习笔记15-L298N
  16. idea 创建一个springboot 项目(hello world)
  17. 什么是蜘蛛统计 有什么作用?
  18. 计算机音乐制作专业美国研究生,美国音乐制作专业研究生六大首选音乐学院
  19. 绝对值不等式(贪心)
  20. 数据库系列-什么是 JDBC?它的作用是什么?

热门文章

  1. 「学习笔记」Vue 官方视频教程 2.0版
  2. 2014-12-20 如果不加班我们就一起去联想玩敏捷吧!!!!
  3. 设置环境变量配置的简单方法.env
  4. Windows CmdPHP窗口消失 但程序依旧执行 查找进程ID(PId)并强制结束进程--wmic process where name--taskkill
  5. 销售宝:软件销售技巧与话术,行业前景分析!
  6. 100天精通Python丨黑科技篇 —— 20、Python 修图(滤镜、灰度、裁剪、视觉处理、图像分割、特征提取)
  7. linux+特殊字符文件夹,linux创建带特殊符号的文件夹
  8. 2亿日活,日均千万级视频上传,快手推荐系统如何应对技术挑战?
  9. vue圆环进度条_Vue/React圆环进度条
  10. webpack5学习笔记-3 打包优化的操作