我们直到学习率对于机器学习来说,大的学习率虽然往往能够使得损失函数快速下降,但是导致不收敛或者振荡现象的发生,而小的学习率虽然收敛,但是学习速率太慢,损失函数下降缓慢,需要等待长时间的训练,同时也会容易陷入局部最优。因此,一种解决方法是令学习率随迭代次数的增加而下降。

下面是python示例。该例子可以参考TensorFlow进阶--实现反向传播博文

这里的关键在于

tf.train.exponential_decay(initial_learning_rate,global_step=global_step,decay_steps=10,decay_rate=0.9)

通过设置:

初始学习率initial_learning_rate

当前迭代次数global_step

每decay_steps更新一次学习率

每次更新乘以0.9,达到指数衰减的效果

代码如下

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
#创建计算图会话
sess = tf.Session()#生成数据并创建在占位符和变量Ax_vals = np.concatenate((np.random.normal(-1,1,50),np.random.normal(3,1,50)))
y_vals = np.concatenate((np.repeat(0.,50),np.repeat(1.,50)))x_data = tf.placeholder(tf.float32,shape=[1])
y_target = tf.placeholder(tf.float32,shape=[1])A = tf.Variable(tf.random_normal(mean=10,shape=[1]))#增加乘法操作
my_output = tf.add(x_data,A)#由于非归一化logits的交叉熵的损失函数期望批量数据增加一个批量数据的维度
my_output_expanded = tf.expand_dims(my_output,0)
y_target_expanded = tf.expand_dims(y_target,0)#增加非归一化logits的交叉熵的损失函数loss = tf.nn.sigmoid_cross_entropy_with_logits( logits=my_output_expanded , labels=y_target_expanded )#声明变量的优化器global_step = tf.Variable(0, trainable=False)initial_learning_rate = 0.5 #初始学习率learning_rate = tf.train.exponential_decay(initial_learning_rate,global_step=global_step,decay_steps=10,decay_rate=0.9)my_opt = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)train_step = my_opt.minimize(loss)#在运行之前,需要初始化变量
init = tf.initialize_all_variables()
sess.run(init)num = 1000
step = np.zeros(num)
LOSS = np.zeros_like(step)
Learning_rate_vec = []
# 训练算法
for i in range(num):rand_index = np.random.choice(100)rand_x = [x_vals[rand_index]]rand_y = [y_vals[rand_index]]sess.run(train_step,feed_dict={x_data:rand_x,y_target:rand_y,global_step:i})#打印step[i]= iLOSS[i] = sess.run(loss,feed_dict={x_data:rand_x,y_target:rand_y})Learning_rate_vec.append(sess.run(learning_rate,feed_dict={global_step:i}))if (i+1)%100 ==0:print('step =' + str(i+1) +' A = '+ str(sess.run(A)))print('loss =' + str(LOSS[i]) )fig = plt.figure()
ax = fig.add_subplot(111)
ax.plot(step,LOSS,label='loss')
ax.set_xlabel('step')
ax.set_ylabel('loss')
fig.suptitle('sigmoid_cross_entropy_with_logits')
handles,labels = ax.get_legend_handles_labels()
ax.legend(handles,labels=labels)fig2 = plt.figure()
ax2 = fig2.add_subplot(111)
ax2.plot(Learning_rate_vec,'k-')
ax2.set_xlabel('step')
ax2.set_ylabel('Learning_rate')
fig2.suptitle('Learning_rate')plt.show()# logdir = './log'
# write = tf.summary.FileWriter(logdir=logdir,graph=sess.graph)

TensorFlow进阶--实现学习率随迭代次数下降相关推荐

  1. 每天五分钟机器学习:随着算法迭代次数动态调整学习率

    本文重点 我们使用的学习率往往是不变的,本节课程我们将令学习率随着迭代次数的增加而减小,这会对算法的学习有很大的好处. 好处 当我们运行随机梯度下降时,算法会从某个点开始,然后曲折的逼近最小值,但是不 ...

  2. TensorFlow之二—学习率 (learning rate)

    文章目录 一.分段常数衰减 tf.train.piecewise_constan() 二.指数衰减 tf.train.exponential_decay() 三.自然指数衰减 tf.train.nat ...

  3. 学习率对神经网络迭代次数和准确率的影响以及近似数学表达式

    本文构造了一个带有1个卷积核的网络二分类minst的0和2,通过调整学习率r观察学习率对神经网络的迭代次数和准确率到底有什么影响. 实验用minst数据集将28*28的图片缩小到9*9,网络用一个3* ...

  4. 学习率对神经网络迭代次数的影响

    调整学习率看看对网络的迭代次数是否有影响 学习率先后实验了5, 1, 0.5, 0.3, 0.2, 0.1, 0.01, 0.005, 0.003,0.001 权重的初始化标准是 Random ran ...

  5. TensorFlow中设置学习率的方式

    目录 1. 指数衰减 2. 分段常数衰减 3. 自然指数衰减 4. 多项式衰减 5. 倒数衰减 6. 余弦衰减 6.1 标准余弦衰减 6.2 重启余弦衰减 6.3 线性余弦噪声 6.4 噪声余弦衰减 ...

  6. 梯度下降的线性回归用python_运用TensorFlow进行简单实现线性回归、梯度下降示例...

    线性回归属于监督学习,因此方法和监督学习应该是一样的,先给定一个训练集,根据这个训练集学习出一个线性函数,然后测试这个函数训练的好不好(即此函数是否足够拟合训练集数据),挑选出最好的函数(cost f ...

  7. 用矩阵内积的办法构造迭代次数受控的神经网络1:0.6:0.1=4:3:2

    每个神经网络对应每个收敛标准δ都有一个特征的迭代次数n,因此可以用迭代次数曲线n(δ)来评价网络性能. 一个二分类网络,分类两个对象A和B,B中有K张图片,B的第i张图片被取样的概率为pi,B中第i张 ...

  8. 用矩阵点积的办法构造神经网络的迭代次数1:0.6:0.1=1:1:1

    每个神经网络对应每个收敛标准δ都有一个特征的迭代次数n,因此可以用迭代次数曲线n(δ)来评价网络性能. 一个二分类网络分类两个对象A和B,B中有K张图片,B的第i张图片被取样的概率为pi,B中第i张图 ...

  9. 用神经网络迭代次数曲线模拟原子光谱

    大量实验表明每个神经网络对应每个收敛标准δ都有一个特征的迭代次数n,因此可以用迭代次数曲线n(δ)来评价网络性能. 一个二分类网络分类两组对象A和B,B中有K张图片,B的第i张图片被取样的概率为pi, ...

最新文章

  1. 无序数组及其子序列的相关问题研究
  2. 开源oa_开源OA:可以轻松支持云文档管理
  3. 七天学会NodeJS
  4. mysql外键教程_MySQL外键使用详解
  5. 前端开发学习笔记(二)
  6. 判断一颗二叉树是否是平衡二叉树
  7. 初级数据分析师需要哪些必备技能?
  8. phpcms_v9推送到其他栏目后再在其他栏目删除导致数据库出错
  9. end to end testing
  10. oracle segment undo_71_UNDO扩展学习
  11. 实时的毛发绘制 szlongman
  12. ktv app html,ktv.html
  13. python变量作用域图解_图解python全局变量与局部变量相关知识
  14. oracle spool
  15. 查找 -- 7.1 Sear for a Range -- 图解
  16. y480 linux无线网卡驱动,联想y480无线网卡驱动下载
  17. 技术分享 | Hulu视频广告系统中的算法应用
  18. np问题 量子计算机,P vs NP与经典与量子计算机可解决的问题相同吗?
  19. v.douyin.com/xxx抖音网址官方生成制作抖音缩短口令网址php接口方法
  20. git报错:error.GitError: manifests rev-list (‘^HEAD‘, ‘14686468c69c63f1995ab2a0a9ad90b2e1d5e01c‘, ‘--‘)

热门文章

  1. 个人项目 小跟班——蓝牙小车控制(UI篇)
  2. 面向对象程序设计实践(C++)——二维向量
  3. Compiler - 编译器
  4. 想学linux需要的电脑配置相关
  5. 高鲁棒!高实时!慕尼黑工业大学开源RGB-L SLAM!
  6. 【蓝桥杯每日一练:木头加工】
  7. 对于anaconda安装的一个小感悟 。
  8. 自定义ironic-python-agent镜像 ipa ramdisk and kernel
  9. biopython:基因genbank格式转核酸或氨基酸fasta格式
  10. Java笔记3.1——Java基础之数组