我是从keras入门深度学习的,第一个用的demo是keras实现的yolov3,代码很好懂(其实也不是很好懂,第一次也搞了很久才弄懂)

然后是做的车牌识别,用了tiny-yolo来检测车牌位置,当时训练有4w张图片,用了一天来训练,当时觉得时间可能就是这么长,也不懂GPU训练的时候GPU利用率,所以不怎么在意,后来随着项目图片片的增多,训练时间越来越大,受不了了,看了一片文章才注意到GPU利用率的问题.想到要用tensorflow原生的api去训练,比如用tf.data.dataset

就找到了这个tensorflow原生实现yolo的项目,在训练的时候发现他没加梯度衰减,训练了一段时间total loss下不去了,所以加了一个梯度衰减。想写一下文章,小白的第一篇文章哈哈哈,大神别喷我的内容太简单

YunYang1994/tensorflow-yolov3​github.com

他好像改了train.py

原来是这样的

import tensorflow as tf
from core import utils, yolov3
from core.dataset import dataset, Parser
sess = tf.Session()IMAGE_H, IMAGE_W = 416, 416
BATCH_SIZE       = 8
EPOCHS           = 2000*1000
LR               = 0.0001
SHUFFLE_SIZE     = 1000
CLASSES          = utils.read_coco_names('./data/voc.names')
ANCHORS          = utils.get_anchors('./data/voc_anchors.txt')
NUM_CLASSES      = len(CLASSES)train_tfrecord   = "../VOC/train/voc_train*.tfrecords"
test_tfrecord    = "../VOC/test/voc_test*.tfrecords"parser   = Parser(IMAGE_H, IMAGE_W, ANCHORS, NUM_CLASSES)
trainset = dataset(parser, train_tfrecord, BATCH_SIZE, shuffle=SHUFFLE_SIZE)
testset  = dataset(parser, test_tfrecord , BATCH_SIZE, shuffle=None)is_training = tf.placeholder(tf.bool)
example = tf.cond(is_training, lambda: trainset.get_next(), lambda: testset.get_next())images, *y_true = example
model = yolov3.yolov3(NUM_CLASSES, ANCHORS)with tf.variable_scope('yolov3'):y_pred = model.forward(images, is_training=is_training)loss = model.compute_loss(y_pred, y_true)optimizer = tf.train.AdamOptimizer(LR)
saver = tf.train.Saver(max_to_keep=2)tf.summary.scalar("loss/coord_loss",   loss[1])
tf.summary.scalar("loss/sizes_loss",   loss[2])
tf.summary.scalar("loss/confs_loss",   loss[3])
tf.summary.scalar("loss/class_loss",   loss[4])write_op = tf.summary.merge_all()
writer_train = tf.summary.FileWriter("./data/train")
writer_test  = tf.summary.FileWriter("./data/test")update_var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="yolov3/yolo-v3")
with tf.control_dependencies(update_var):train_op = optimizer.minimize(loss[0], var_list=update_var,global_step=global_step) # only update yolo layersess.run(tf.global_variables_initializer())
pretrained_weights = tf.global_variables(scope="yolov3/darknet-53")
load_op = utils.load_weights(var_list=pretrained_weights,weights_file="./darknet53.conv.74")
sess.run(load_op)for epoch in range(EPOCHS):run_items = sess.run([train_op, write_op] + loss, feed_dict={is_training:True})writer_train.add_summary(run_items[1], global_step=epoch)writer_train.flush() # Flushes the event file to diskif (epoch+1)%1000 == 0: saver.save(sess, save_path="./checkpoint/yolov3.ckpt", global_step=epoch)run_items = sess.run([write_op] + loss, feed_dict={is_training:False})writer_test.add_summary(run_items[0], global_step=epoch)writer_test.flush() # Flushes the event file to diskprint("EPOCH:%7d tloss_xy:%7.4f tloss_wh:%7.4f tloss_conf:%7.4f tloss_class:%7.4f"%(epoch, run_items[2], run_items[3], run_items[4], run_items[5]))

然后我发现没有梯度下降,所以就找了怎么实现

实现如下
optimizer = tf.train.AdamOptimizer(LR)
改为
global_step = tf.Variable(0, trainable=False)
learning_rate = tf.train.exponential_decay(LR,100,0.93,staircase=True,global_step=global_step)
optimizer = tf.train.AdamOptimizer(learning_rate)

learningrate 是梯度的类,LR是初始梯度,100是每一百次初始梯度乘以衰减度,这里是第三个参数0.93代表了衰减度,globalstep_step = global_step是一定要加的,不然梯度一直保持了初始梯度。

最后加个打印

tf.summary.scalar('learning_rate',learning_rate)

就可以爽快的去训练了

一文简单弄懂tensorflow_在tensorflow中设置梯度衰减相关推荐

  1. 一文简单弄懂tensorflow_【TensorFlow】一文弄懂CNN中的padding参数

    在深度学习的图像识别领域中,我们经常使用卷积神经网络CNN来对图像进行特征提取,当我们使用TensorFlow搭建自己的CNN时,一般会使用TensorFlow中的卷积函数和池化函数来对图像进行卷积和 ...

  2. 一文彻底弄懂大端与小端

    一文彻底弄懂大端与小端 1. 端模式起源 端模式(Endian)起源于<格列佛游记>, 书中根据鸡蛋敲开的方式不同将所有人分为2类,从圆头开始敲的人被归为Big Endian,从尖头开始敲 ...

  3. ​Cookie 从入门到进阶:一文彻底弄懂其原理以及应用

    大家好,我是若川.持续组织了8个月源码共读活动,感兴趣的可以点此加我微信 ruochuan12 参与,每周大家一起学习200行左右的源码,共同进步.同时极力推荐订阅我写的<学习源码整体架构系列& ...

  4. 从原理的视角,一文彻底弄懂FPGA的查找表(LUT)、CLB

    我学东西有个特点,喜欢从原理的层面彻底弄懂一个知识点,这几天想弄明白FPGA的查找表,但发现很多博文写的很模糊,看了以后仍然不是很明白.当然,可能是作者自己弄懂了,但没有站在新人的角度来详细的解释.通 ...

  5. TensorFlow中设置学习率的方式

    目录 1. 指数衰减 2. 分段常数衰减 3. 自然指数衰减 4. 多项式衰减 5. 倒数衰减 6. 余弦衰减 6.1 标准余弦衰减 6.2 重启余弦衰减 6.3 线性余弦噪声 6.4 噪声余弦衰减 ...

  6. 一文让你完全弄懂回归问题、激活函数、梯度下降和神经元模型实战《繁凡的深度学习笔记》第 2 章 回归问题与神经元模型(DL笔记整理系列)

    <繁凡的深度学习笔记>第 2 章 回归问题与神经元模型(DL笔记整理系列) 3043331995@qq.com https://fanfansann.blog.csdn.net/ http ...

  7. 干货:一文彻底弄懂递归如何解题

    前言 递归是算法中一种非常重要的思想,应用也很广,小到阶乘,再在工作中用到的比如统计文件夹大小,大到 Google 的 PageRank 算法都能看到,也是面试官很喜欢的考点 最近看了不少递归的文章, ...

  8. RabbitMq(二)一文彻底弄懂RabbitMq的四种交换机原理及springboot实战应用

    四大交换机工作原理及实战应用 交换机概念 direct 直连交换机 工作模式图解 springboot代码 Fanout扇出交换机 工作模式图解 springboot代码 Topic主题交换机 工作模 ...

  9. 一文彻底弄懂工厂模式(Factory)

    文章已收录我的仓库:Java学习笔记与免费书籍分享 模式类型 工厂模式属于创建者模式,与对象的创建有关,其中工厂方法模式用于类,而抽象工厂模式用于对象.创建型类模式将对象的部分创建工作延迟到子类,由子 ...

最新文章

  1. RMAN 备份SHELL
  2. 如何限制对象只能建立在堆上或者栈上
  3. qt创建右键菜单,显示在鼠标点击处
  4. 单点登录 cas 设置回调地址_单点登录落地实现技术有哪些,有哪些流行的登录方案搭配?...
  5. maven deploy distributionManagement
  6. 【AI面试题】Softmax的原理是什么,有什么作用
  7. 领导:“请在今晚进行网络系统升级”
  8. 二维数组 : 旋转矩阵
  9. request+BeautifulSoup:下载《笔趣看》网小说《第九特区》
  10. 【云原生监控系列第一篇】一文详解Prometheus普罗米修斯监控系统(山前前后各有风景,有风无风都很自由)
  11. linux录制声卡声音_linux下ALSA声卡 录音问题
  12. 线程与进程的区别,举个例子让你快速理解
  13. 华人数学家破译孪生素数猜想 影响或超1+2证明
  14. SQL---日期时间函数
  15. CCF A类 B类 C类 中国计算机学会推荐中文科技期刊目录【迷惑了好久】
  16. 超声波液位计测量原理及应用领域
  17. 【Android 实现VideoView开始和播放时缓冲监听动画(监听播放状态)】
  18. 当手机淘宝遇见折叠屏,让购物更随心
  19. js调用摄像头拍照并访问后端代码
  20. vivox80和vivox80pro有什么区别 哪个值得买

热门文章

  1. 虚构合同、虚开发票套取高校配套科研经费,一副教授被公诉!
  2. 本科、硕士、博士之间的差距!
  3. 一位合格的博士生需要有哪些条件和素质?
  4. 推荐系统--矩阵分解(1)
  5. mysql 第二天数据_MySQL入门第二天------数据库操作
  6. mysql left join、right join、inner join、union、union all使用以及图解
  7. 阿里云何川:开放兼容的云,计算巢帮助合作伙伴云化升级
  8. 当微服务遇上 Serverless | 微服务容器化最短路径,微服务 on Serverless 最佳实践
  9. 瓜子二手车在 Dubbo 版本升级、多机房方案方面的思考和实践
  10. 分布式系统:一致性协议