最近基于VGG-16缩进了网络做了一个CNN模型用于处理图像分类,实际项目,训练对象是448×32的长条试纸图片。

目录

项目源码百度云

tensorboard可视化展示

源代码



项目源码百度云

项目源码百度云链接:https://pan.baidu.com/s/1aWLeh4Kaft7NPlB0GxBZMg 
提取码:vjhu

里面项目名字没改,VGG16因为是改造的,名字也没好好取,能用就行。。

model 储存模型文件,为了方便下载,已经删除了,可以自己训练
logs 存放日志文件,已经删除,本文后续有图片展示tensorboard日志
data

数据文件格式如上右图,test1和train2为原始图片,test和train为处理后图片,统一为448×32大小,用于网络训练。每个下有lh1、lh2等为类别,每个类别下分别存放了图片

注:test/lh1,test1/lh1,train/lh1,train2/lh1下各有一张图片供参考,数据就不大量泄漏了,其余文件为空

VGG16_RAW.py VGG16模型源文件
VGG16_mini2.py 改造的小型CNN模型,训练出来模型大概7M左右
tf009_predition.py 预测文件
tf009.py 训练文件
image_pre_deal.py 图像预处理文件,将原始图片转为统一大小格式的图片
calculate_mean.py 计算图片平均值的文件,用于后续减均值处理

tensorboard可视化展示

tensorboard:(不要在意波折的细节,只要看清准确率,损失值数值就行了,采用的Mini-batch的迭代方式,这次训练有点乱)

模型:


源代码

tf009.py源码,训练主文件:(写的迁移学习不要介意,懒得改了。。)

'''
VGG16迁移学习训练主函数
tensorboard --logdir=D:\python\vgg16\logs'''import tensorflow as tf
import os
os.environ["CUDA_VISIBLE_DEVICES"]="-1"  # 由于出现显卡内存不足问题,所以。。。
import numpy as np
from time import time
import vgg16.VGG16_mini2 as modeldef get_batch(image_list,label_list,img_width,img_height,batch_size,capacity):#通过读取列表来载入批量图片及标签image = tf.cast(image_list,tf.string)label = tf.cast(label_list,tf.int32)input_queue = tf.train.slice_input_producer([image,label],shuffle=True)label = input_queue[1]image_contents = tf.read_file(input_queue[0])image = tf.image.decode_jpeg(image_contents,channels=3)image = tf.cast(image,tf.float32)image -= [42.79902,42.79902,42.79902] # 减均值# image = preprocess_for_train(image,img_height,img_width)image.set_shape((img_height,img_width,3))image_batch,label_batch = tf.train.batch([image,label],batch_size=batch_size,num_threads=64,capacity=capacity)label_batch = tf.reshape(label_batch,[batch_size])return image_batch,label_batchdef get_file(file_dir):images = []for root,sub_folders,files in os.walk(file_dir):for name in files:images.append(os.path.join(root,name))labels = []for label_name in images:letter = label_name.split("\\")[-2]if letter =="lh1":labels.append(0)elif letter =="lh2":labels.append(1)elif letter == "lh3":labels.append(2)elif letter == "lh4":labels.append(3)elif letter == "lh5":labels.append(4)elif letter == "lh6":labels.append(5)elif letter == "lh7":labels.append(6)print("check for get_file:",images[0],"label is ",labels[0])#shuffletemp = np.array([images,labels])temp = temp.transpose()np.random.shuffle(temp)image_list = list(temp[:,0])label_list = list(temp[:,1])label_list = [int(float(i)) for i in label_list]return image_list,label_list#标签格式重构
def onehot(labels):n_sample = len(labels)n_class = 7  # max(labels) + 1onehot_labels = np.zeros((n_sample,n_class))onehot_labels[np.arange(n_sample),labels] = 1return onehot_labelsif __name__ == '__main__':startTime =time()batch_size = 8record_epoch = 70000/batch_sizesmall_loop = int(7000/batch_size)capacity = 256  # 内存中存储的最大数据容量pic_height,pic_width = 32,448   # 修改图片大小参数,应当为32的倍数!不然会导致错误xs,ys = get_file('./data/train')#获取图像列表与标签列表image_batch,label_batch = get_batch(xs,ys,img_width=pic_width,img_height=pic_height,batch_size=batch_size,capacity=capacity)# 验证集xs_val,ys_val = get_file('./data/test')#获取图像列表与标签列表image_val_batch,label_val_batch = get_batch(xs_val,ys_val,img_width=pic_width,img_height=pic_height,batch_size=455,capacity=capacity)x = tf.placeholder(tf.float32,[None,pic_height,pic_width,3])y = tf.placeholder(tf.int32,[None,7])#7分类vgg = model.vgg16(x)fc8_fineuining = vgg.probs #即softmax(fc8)prediction_out = tf.argmax(fc8_fineuining,1)real_out = tf.argmax(y,1)correct_prediction = tf.equal(prediction_out,real_out)#检查预测类与实际类别是否匹配accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))#准确率loss_function = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=fc8_fineuining,labels=y))#损失函数optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001).minimize(loss_function)sess = tf.Session()sess.run(tf.global_variables_initializer())# vgg.load_weights('vgg16_weights.npz',sess)saver = tf.train.Saver()# # 断点续训# ckpt_dir = "./model/"# ckpt = tf.train.latest_checkpoint(ckpt_dir)# if ckpt != None:#     saver.restore(sess, ckpt)#     print('saver restore finish')# else:#     print("training from scratch")#启动线程coord = tf.train.Coordinator()#使用协调器管理线程threads = tf.train.start_queue_runners(coord=coord,sess=sess)# 日志记录summary_writer = tf.summary.FileWriter('./logs/', graph=sess.graph, flush_secs=15)summary_writer2 = tf.summary.FileWriter('./logs/plot2/', flush_secs=15)tf.summary.scalar(name='loss_func', tensor=loss_function)tf.summary.scalar(name='accuracy', tensor=accuracy)merged_summary_op = tf.summary.merge_all()epoch_start_time = time()# 采用Mini-batch迭代step = 0epoch = 10000for i in range(epoch):for j in range(small_loop):images,labels = sess.run([image_batch,label_batch])labels = onehot(labels)# # 可视化# plt.subplot(221)# plt.imshow(images[0, :, :, 0])# plt.show()# print(1)sess.run(optimizer,feed_dict={x:images,y:labels})merged_summary,loss,real_train_out = sess.run([merged_summary_op,loss_function,real_out],feed_dict={x:images,y:labels})summary_writer.add_summary(merged_summary, global_step=step)# print(i,j,"to see train data:",real_train_out[:10])step += 1images_val, labels_val = sess.run([image_val_batch, label_val_batch])labels_val = onehot(labels_val)merged_summary_val, loss_val,accuracy_val,prediction_val_out,real_val_out = sess.run([merged_summary_op, loss_function,accuracy,prediction_out,real_out], feed_dict={x: images_val, y: labels_val})summary_writer2.add_summary(merged_summary_val, global_step=step)# 输出每个类别正确率lh1_right, lh2_right, lh3_right, lh4_right, lh5_right, lh6_right, lh7_right = 0, 0, 0, 0, 0, 0, 0lh1_wrong, lh2_wrong, lh3_wrong, lh4_wrong, lh5_wrong, lh6_wrong, lh7_wrong = 0, 0, 0, 0, 0, 0, 0for ii in range(len(prediction_val_out)):if prediction_val_out[ii] == real_val_out[ii]:if real_val_out[ii] == 0:lh1_right+=1elif real_val_out[ii] == 1:lh2_right+=1elif real_val_out[ii] == 2:lh3_right += 1elif real_val_out[ii] == 3:lh4_right += 1elif real_val_out[ii] == 4:lh5_right += 1elif real_val_out[ii] == 5:lh6_right += 1elif real_val_out[ii] == 6:lh7_right += 1else:if real_val_out[ii] == 0:lh1_wrong+=1elif real_val_out[ii] == 1:lh2_wrong+=1elif real_val_out[ii] == 2:lh3_wrong += 1elif real_val_out[ii] == 3:lh4_wrong += 1elif real_val_out[ii] == 4:lh5_wrong += 1elif real_val_out[ii] == 5:lh6_wrong += 1elif real_val_out[ii] == 6:lh7_wrong += 1print(i,"correct rate :",((lh1_right)/(lh1_right+lh1_wrong)),((lh2_right)/(lh2_right+lh2_wrong)),((lh3_right)/(lh3_right+lh3_wrong)),((lh4_right)/(lh4_right+lh4_wrong)),((lh5_right)/(lh5_right+lh5_wrong)),((lh6_right)/(lh6_right+lh6_wrong)),((lh7_right)/(lh7_right+lh7_wrong)))# print(i,"nums:",((lh1_right+lh1_wrong)),(lh2_right+lh2_wrong),((lh3_right+lh3_wrong)),((lh4_right+lh4_wrong)),(lh5_right+lh5_wrong),((lh6_right+lh6_wrong)),((lh7_right+lh7_wrong)))print(i,"epoch's accuracy:",accuracy_val)print(i," loss is %f"%loss,"val loss is %f"%loss_val)epoch_end_time =time()print(i," epoch takes:",(epoch_end_time-epoch_start_time))epoch_start_time = epoch_end_timeif i % 1 == 0 and i != 0:saver.save(sess,os.path.join("./model/",'epoch{:06d}.ckpt'.format(i)))print("------------model saved")# print("-------------Epoch %d is finished"%i)summary_writer.close()saver.save(sess,"./model/")print("optimization finished")duration = time() - startTimeprint("train takes:","{:.2f}".format(duration))coord.request_stop()#通知线程关闭coord.join(threads)#等其他线程关闭这一函数才返回

CNN图像分类(实际项目,特殊训练集,95%准确率,数据代码百度云)相关推荐

  1. 训练集山准确率高测试集上准确率很低_推荐算法改版前的AB测试

    编辑导语:所谓推荐算法就是利用用户的一些行为,通过一些数学算法,推测出用户可能喜欢的东西:如今很多软件都有这样的操作,对于此系统的设计也会进行测试:本文作者分享了关于推荐算法改版前的AB测试,我们一起 ...

  2. 训练集山准确率高测试集上准确率很低_拒绝DNN过拟合,谷歌准确预测训练集与测试集泛化差异,还开源了数据集 | ICLR 2019...

    鱼羊 发自 凹非寺 量子位 报道 | 公众号 QbitAI 深度神经网络(DNN)如今已经无处不在,从下围棋到打星际,DNN已经渗透到图像识别.图像分割.机器翻译等各种领域,并且总是表现惊艳. 然而, ...

  3. keras训练模型,训练集的准确率很高,但是测试集准确率很低的原因

    今天在测试模型时发现一个问题,keras训练模型,训练集准确率很高,测试集准确率很低,因此记录一下希望能帮助大家也避坑: 首先keras本身不同的版本都有些不同的或大或小的bug,包括之前也困扰过我的 ...

  4. java点击车次显示详情_Web项目专项训练——火车车次信息管理系统代码分享

    题目:火车车次信息管理 一.语言和环境 A.实现语言 Java B.环境要求 JDK1.7.Eclipse.Tomcat7.*.mysql 二.功能要求 使用JSP+Servlet实现火车车次信息管理 ...

  5. Web项目专项训练——火车车次信息管理系统代码分享

    题目:火车车次信息管理 一.语言和环境 A.实现语言 Java B.环境要求 JDK1.7.Eclipse.Tomcat7.*.mysql 二.功能要求 使用JSP+Servlet实现火车车次信息管理 ...

  6. 中国工业经济论文合集-含全部数据代码(2015-2021年)

    1. 数据来源:自主制作(约8.6G) 2. 指标说明: <中国工业经济>创办于1984年,由中国社会科学院主管.中国社会科学院工业经济研究所主办,是中国产业经济.企业管理领域的权威学术期 ...

  7. Python批量处理表格有用吗_python批量读入图片、处理并批量输出(可用于深度学习训练集的制作)...

    最近工作实在是太忙了,白浪花的项目没有及时跟进,很多知识也没有自学.好了,趁着现在等领导回复微信的时间,我把上周趁着零散时间做的工作总结一下.内容依然小白,但是却很重要. 项目情况简单描述一下,最终要 ...

  8. 什么是训练集、验证集和测试集?

    在机器学习中,训练集.验证集和测试集是数据集的三个重要部分,用于训练.评估和测试机器学习模型的性能.它们的定义和作用如下: 什么是训练集.验证集和测试集? 训练集:训练集是机器学习模型用于训练和学习的 ...

  9. 一文看懂 AI 训练集、验证集、测试集(附:分割方法+交叉验证)

    2019-12-20 20:01:00 数据在人工智能技术里是非常重要的!本篇文章将详细给大家介绍3种数据集:训练集.验证集.测试集. 同时还会介绍如何更合理的讲数据划分为3种数据集.最后给大家介绍一 ...

最新文章

  1. Django博客系统(404页面展示)
  2. 实战:OpenVINO+OpenCV 文本检测与识别
  3. 【OO学习】OO第四单元作业总结及OO课程总结
  4. python爬虫论文摘要怎么写_Python爬虫基础教学(写给入门的新手)
  5. 【C#】VS2012+InstallShield2013制作软件更新包
  6. Kinect SDK v1.7 新特性、交互框架与新概念
  7. C语言实实验步骤,C语言教程学习方法攻略
  8. .NET Core SDK在Windows系统安装后出现Failed to load the hostfxr.dll等问题的解决方法
  9. 实验7 BindService模拟通信
  10. 音视频学习系列第(四)篇---视频的采集预览
  11. MySQL数据库应用与开发答案_MySQL数据库应用与开发习题解答与上机指导
  12. Delphi开发工具的使用
  13. Linux中history命令增加时间显示
  14. MySql基础知识(高频面试题)
  15. 《定风波》--苏轼之我最喜欢的一首词
  16. vue 生成二维码工具
  17. 现实中的算法面试题(已拿Offer)赚到了,赚到了
  18. 安卓银行木马新增“keylogger”功能,攻击能力倍增
  19. CSDN文章markdown图片居中以及调整大小(超级简单)
  20. JVM---Java内存屏障和JMM

热门文章

  1. java输入输出流、字符字节流
  2. 初等数论重要公式总结
  3. ffmpeg视频压缩命令
  4. 全球计算机科学研究生排名,新|美国计算机科学研究生专业世界排名靠50强名单...
  5. excel根据数据画饼状图等
  6. 达人评测 联想拯救者 Y9000K 2021怎么样
  7. DDN收购Intel Lustre系统业务,详解Lustre系统架构、配置和调优
  8. 2p C和3p C的区别
  9. 营销策划方案示范文本
  10. 步进电机控制,RPM与PPS单位关系分析