本来我想用tensorbaord来观察LOSS曲线变化的,但是我代码改得不对,如果有小伙伴改出来了,如果可以的话可以告诉我,我懒得改了。下面代码也是注意改成自己的路径

# 导入文件
import os
import numpy as np
import tensorflow as tf
import input_data
import model
import os
import time
import warnings
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
warnings.filterwarnings('ignore')# 变量声明
N_CLASSES = 5  # 五种花类型
IMG_W = 64  # resize图像,太大的话训练时间久
IMG_H = 64
BATCH_SIZE = 25
CAPACITY = 250
MAX_STEP =5000
learning_rate = 0.0005# 一般小于0.0001train_dir = 'D:/flower_photos/input_data2/train'  # 训练样本的读入路径
val_dir = 'D:/flower_photos/input_data2/val'  # 验证样本的读入路径
logs_train_dir = 'D:/save2/train'  # logs存储路径
logs_val_dir = 'D:/save2/val'train, train_label= input_data.get_files(train_dir)
val, val_label = input_data.get_files(val_dir)
# 训练数据及标签
train_batch, train_label_batch = input_data.get_batch(train, train_label, IMG_W, IMG_H, BATCH_SIZE, CAPACITY)
# 测试数据及标签
val_batch, val_label_batch = input_data.get_batch(val, val_label, IMG_W, IMG_H, BATCH_SIZE, CAPACITY)x = tf.placeholder(tf.float32, shape=[BATCH_SIZE, IMG_W, IMG_H, 3])
y_ = tf.placeholder(tf.int32, shape=[BATCH_SIZE])# 存放DropOut参数的容器,训练时为0.45,测试时为0
dropout_placeholdr = tf.placeholder(tf.float32)
# 是否是训练状况
train = tf.placeholder(tf.float32)logits = model.inference(x, BATCH_SIZE, N_CLASSES,dropout_placeholdr,train)
loss = model.losses(logits, y_)
acc = model.evaluation(logits, y_)
train_op = model.trainning(loss, learning_rate)with tf.Session() as sess:saver = tf.train.Saver()sess.run(tf.global_variables_initializer())coord = tf.train.Coordinator()threads = tf.train.start_queue_runners(sess=sess, coord=coord)summary_op = tf.summary.merge_all()train_writer = tf.summary.FileWriter(logs_train_dir, sess.graph)val_writer = tf.summary.FileWriter(logs_val_dir)# val_writer = tf.summary.FileWriter(logs_val_dir, sess.graph)try:for step in np.arange(MAX_STEP):if coord.should_stop():breaktra_images, tra_labels = sess.run([train_batch, train_label_batch])_, tra_loss, tra_acc = sess.run([train_op, loss, acc],feed_dict={x: tra_images, y_: tra_labels,dropout_placeholdr:0.45,train:1})if step % 100 == 0:print('Step %d, train loss = %.2f, train accuracy = %.2f%%' % (step, tra_loss, tra_acc * 100.0))summary_str = sess.run(summary_op, feed_dict={x: tra_images, y_: tra_labels,dropout_placeholdr:0.45,train:1})train_writer.add_summary(summary_str, step)if step % 100 == 0:val_images, val_labels = sess.run([val_batch, val_label_batch])val_loss, val_acc = sess.run([loss, acc],feed_dict={x: val_images, y_: val_labels,dropout_placeholdr:1.0,train:0})print('** val loss = %.2f, val accuracy = %.2f%%  **' % (val_loss, val_acc * 100.0))summary_str = sess.run(summary_op, feed_dict={x: tra_images, y_: tra_labels,dropout_placeholdr:1.0,train:0})val_writer.add_summary(summary_str, step)if step % 100 == 0 or (step + 1) == MAX_STEP:checkpoint_path = os.path.join(logs_train_dir, 'model.ckpt')saver.save(sess, checkpoint_path, global_step=step)except tf.errors.OutOfRangeError:print('Done training -- epoch limit reached')finally:coord.request_stop()coord.join(threads)

其中save文件夹中存储的就是训练好的模型,这个在后面测试的时候会用到。

图像分类之花卉图像分类(四)训练模型相关推荐

  1. 图像分类之花卉图像分类(一)数据增强

    网上有很多图像分类的代码,有很多是必须要在GPU上面才能跑的,因为我想在自己的电脑跑,所以很多都是不能用的,而且说实话很多对我这个小白来说,都很难看懂.所以我找了一个就是之间用CNN写的神经卷积模型用 ...

  2. 图像分类之花卉图像分类(五)测试数据

    这个时候测试集就要用到了,为了便于观察,我们这里先给测试集重命名,这样子,哪张图片分类错了,我们也比较好找. 重命名代码: #!/usr/bin/python # -*- coding:utf-8 - ...

  3. 【实验课程】花卉图像分类实验

    转载地址:https://bbs.huaweicloud.com/forum/thread-80033-1-1.html 作者:yangyaqin 实验介绍 随着电子技术的迅速发展,人们使用便携数码设 ...

  4. 【项目实战课】基于Pytorch的InceptionNet花卉图像分类实战

    欢迎大家来到我们的项目实战课,本期内容是<基于Pytorch的InceptionNet花卉图像分类实战>.所谓项目课,就是以简单的原理回顾+详细的项目实战的模式,针对具体的某一个主题,进行 ...

  5. 基于SIFT和颜色特征的花卉图像分类

    基于SIFT特征和颜色直方图的花卉图像分类 摘 要 课程实验提取图片的SIFT特征,通过k均值聚类的方法将所有训练图片的特征聚类为800类,以每个类出现的频率构建特征向量(又称为词袋模型),同时,通过 ...

  6. 【图像分类】 关于图像分类中类别不平衡那些事

    欢迎大家来到图像分类专栏,类别不平衡时是很常见的问题,本文介绍了类别不平衡图像分类算法的发展现状,供大家参考学习. 作者&编辑 | 郭冰洋 1 简介 小伙伴们在利用公共数据集动手搭建图像分类模 ...

  7. 基于深度学习模型的花卉图像分类代码_实战 | 基于深度学习模型VGG的图像识别(附代码)...

    本文演示了如何使用百度公司的PaddlePaddle实现基于深度学习模型VGG的图像识别. 准备工作 VGG简介 牛津大学VGG(Visual Geometry Group)组在2014年ILSVRC ...

  8. 基于深度学习模型的花卉图像分类代码_华为不止有鸿蒙!教你快速入门华为免编程深度学习神器ModelArts...

    引言: 本文介绍利用华为ModelArts进行深度学习的图像分类任务,不用一行代码. 今年8月9日,在华为史上规模最大的开发者大会上,华为正式发布全球首个基于微内核的全场景分布式OS--鸿蒙操作系统( ...

  9. python图像分类教程_TensorFlow图像分类教程

    深度学习算法与计算机硬件性能的发展,使研究人员和企业在图像识别.语音识别.推荐引擎和机器翻译等领域取得了巨大的进步.六年前,视觉模式识别领域取得了第一个超凡的成果.两年前,Google大脑团队开发了T ...

最新文章

  1. SQL Server Alwayson 主从数据库账号同步
  2. 畅想来自未来的便携扫描仪
  3. @autowired注解原理_SpringBoot注解大全,收藏一波!!!
  4. php验证百度云doc,百度云推送通知埋的大大的坑,成功测试REST API for PHP服务端...
  5. 手机两列布局,正方形
  6. acwing——每日一题——总结
  7. c语言fun函数yx,C语言解24点游戏程序
  8. 基于CSE的微服务架构实践-Spring Boot技术栈选型
  9. Python对象特殊方法及其用法演示
  10. C++ machine code与随机数 进阶习题
  11. cmd下特殊字符串的处理
  12. 下标超出数量 oracle,超出最大会话数和Ora-00020超出最大进程数错误的解决方法...
  13. 依赖于 !important 标签是个危险的现象。奔驰车如何查4S店的保养记录
  14. 2015中国大学排行榜100强新鲜出炉(校友会版)-[转]
  15. STC89C52引脚图
  16. 使用Scratch制作打弹球游戏(一)弹球游戏过关
  17. 俄勒冈之旅_俄勒冈州波特兰市严格禁止面部识别技术
  18. 【strlen函数的使用及strlen函数的三种模拟实现】· C语言详解库函数篇(一)
  19. 特征重要性计算方法及神经网络的特征重要性
  20. ROS:一种路径优化方法-拉直法

热门文章

  1. python漫画滤镜怎么实现的_OpenCV图片漫画效果的实现示例
  2. 《Linux网络管理应用 大学笔记 》- 初学者 - 用户和组的管理
  3. Java中字符串与整数的相互转换
  4. 518. 零钱兑换 II
  5. 加载动态库失败(loadLibrary返回为空)的几种解决办法
  6. linux内核看门狗关闭方法,linux内核中断之看门狗
  7. 如何用c#制作QQ农场外挂
  8. gentoo linux图形界面,Gentoo桌面系统的安装
  9. 深度学习训练中GPU占用0%
  10. 云原生爱好者周刊:好家伙,Rust 也成立云原生组织了