TensorFlow进阶--实现学习率随迭代次数下降
我们直到学习率对于机器学习来说,大的学习率虽然往往能够使得损失函数快速下降,但是导致不收敛或者振荡现象的发生,而小的学习率虽然收敛,但是学习速率太慢,损失函数下降缓慢,需要等待长时间的训练,同时也会容易陷入局部最优。因此,一种解决方法是令学习率随迭代次数的增加而下降。
下面是python示例。该例子可以参考TensorFlow进阶--实现反向传播博文
这里的关键在于
tf.train.exponential_decay(initial_learning_rate,global_step=global_step,decay_steps=10,decay_rate=0.9)
通过设置:
初始学习率initial_learning_rate
当前迭代次数global_step
每decay_steps更新一次学习率
每次更新乘以0.9,达到指数衰减的效果
代码如下
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
#创建计算图会话
sess = tf.Session()#生成数据并创建在占位符和变量Ax_vals = np.concatenate((np.random.normal(-1,1,50),np.random.normal(3,1,50)))
y_vals = np.concatenate((np.repeat(0.,50),np.repeat(1.,50)))x_data = tf.placeholder(tf.float32,shape=[1])
y_target = tf.placeholder(tf.float32,shape=[1])A = tf.Variable(tf.random_normal(mean=10,shape=[1]))#增加乘法操作
my_output = tf.add(x_data,A)#由于非归一化logits的交叉熵的损失函数期望批量数据增加一个批量数据的维度
my_output_expanded = tf.expand_dims(my_output,0)
y_target_expanded = tf.expand_dims(y_target,0)#增加非归一化logits的交叉熵的损失函数loss = tf.nn.sigmoid_cross_entropy_with_logits( logits=my_output_expanded , labels=y_target_expanded )#声明变量的优化器global_step = tf.Variable(0, trainable=False)initial_learning_rate = 0.5 #初始学习率learning_rate = tf.train.exponential_decay(initial_learning_rate,global_step=global_step,decay_steps=10,decay_rate=0.9)my_opt = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)train_step = my_opt.minimize(loss)#在运行之前,需要初始化变量
init = tf.initialize_all_variables()
sess.run(init)num = 1000
step = np.zeros(num)
LOSS = np.zeros_like(step)
Learning_rate_vec = []
# 训练算法
for i in range(num):rand_index = np.random.choice(100)rand_x = [x_vals[rand_index]]rand_y = [y_vals[rand_index]]sess.run(train_step,feed_dict={x_data:rand_x,y_target:rand_y,global_step:i})#打印step[i]= iLOSS[i] = sess.run(loss,feed_dict={x_data:rand_x,y_target:rand_y})Learning_rate_vec.append(sess.run(learning_rate,feed_dict={global_step:i}))if (i+1)%100 ==0:print('step =' + str(i+1) +' A = '+ str(sess.run(A)))print('loss =' + str(LOSS[i]) )fig = plt.figure()
ax = fig.add_subplot(111)
ax.plot(step,LOSS,label='loss')
ax.set_xlabel('step')
ax.set_ylabel('loss')
fig.suptitle('sigmoid_cross_entropy_with_logits')
handles,labels = ax.get_legend_handles_labels()
ax.legend(handles,labels=labels)fig2 = plt.figure()
ax2 = fig2.add_subplot(111)
ax2.plot(Learning_rate_vec,'k-')
ax2.set_xlabel('step')
ax2.set_ylabel('Learning_rate')
fig2.suptitle('Learning_rate')plt.show()# logdir = './log'
# write = tf.summary.FileWriter(logdir=logdir,graph=sess.graph)
TensorFlow进阶--实现学习率随迭代次数下降相关推荐
- 每天五分钟机器学习:随着算法迭代次数动态调整学习率
本文重点 我们使用的学习率往往是不变的,本节课程我们将令学习率随着迭代次数的增加而减小,这会对算法的学习有很大的好处. 好处 当我们运行随机梯度下降时,算法会从某个点开始,然后曲折的逼近最小值,但是不 ...
- TensorFlow之二—学习率 (learning rate)
文章目录 一.分段常数衰减 tf.train.piecewise_constan() 二.指数衰减 tf.train.exponential_decay() 三.自然指数衰减 tf.train.nat ...
- 学习率对神经网络迭代次数和准确率的影响以及近似数学表达式
本文构造了一个带有1个卷积核的网络二分类minst的0和2,通过调整学习率r观察学习率对神经网络的迭代次数和准确率到底有什么影响. 实验用minst数据集将28*28的图片缩小到9*9,网络用一个3* ...
- 学习率对神经网络迭代次数的影响
调整学习率看看对网络的迭代次数是否有影响 学习率先后实验了5, 1, 0.5, 0.3, 0.2, 0.1, 0.01, 0.005, 0.003,0.001 权重的初始化标准是 Random ran ...
- TensorFlow中设置学习率的方式
目录 1. 指数衰减 2. 分段常数衰减 3. 自然指数衰减 4. 多项式衰减 5. 倒数衰减 6. 余弦衰减 6.1 标准余弦衰减 6.2 重启余弦衰减 6.3 线性余弦噪声 6.4 噪声余弦衰减 ...
- 梯度下降的线性回归用python_运用TensorFlow进行简单实现线性回归、梯度下降示例...
线性回归属于监督学习,因此方法和监督学习应该是一样的,先给定一个训练集,根据这个训练集学习出一个线性函数,然后测试这个函数训练的好不好(即此函数是否足够拟合训练集数据),挑选出最好的函数(cost f ...
- 用矩阵内积的办法构造迭代次数受控的神经网络1:0.6:0.1=4:3:2
每个神经网络对应每个收敛标准δ都有一个特征的迭代次数n,因此可以用迭代次数曲线n(δ)来评价网络性能. 一个二分类网络,分类两个对象A和B,B中有K张图片,B的第i张图片被取样的概率为pi,B中第i张 ...
- 用矩阵点积的办法构造神经网络的迭代次数1:0.6:0.1=1:1:1
每个神经网络对应每个收敛标准δ都有一个特征的迭代次数n,因此可以用迭代次数曲线n(δ)来评价网络性能. 一个二分类网络分类两个对象A和B,B中有K张图片,B的第i张图片被取样的概率为pi,B中第i张图 ...
- 用神经网络迭代次数曲线模拟原子光谱
大量实验表明每个神经网络对应每个收敛标准δ都有一个特征的迭代次数n,因此可以用迭代次数曲线n(δ)来评价网络性能. 一个二分类网络分类两组对象A和B,B中有K张图片,B的第i张图片被取样的概率为pi, ...
最新文章
- 无序数组及其子序列的相关问题研究
- 开源oa_开源OA:可以轻松支持云文档管理
- 七天学会NodeJS
- mysql外键教程_MySQL外键使用详解
- 前端开发学习笔记(二)
- 判断一颗二叉树是否是平衡二叉树
- 初级数据分析师需要哪些必备技能?
- phpcms_v9推送到其他栏目后再在其他栏目删除导致数据库出错
- end to end testing
- oracle segment undo_71_UNDO扩展学习
- 实时的毛发绘制 szlongman
- ktv app html,ktv.html
- python变量作用域图解_图解python全局变量与局部变量相关知识
- oracle spool
- 查找 -- 7.1 Sear for a Range -- 图解
- y480 linux无线网卡驱动,联想y480无线网卡驱动下载
- 技术分享 | Hulu视频广告系统中的算法应用
- np问题 量子计算机,P vs NP与经典与量子计算机可解决的问题相同吗?
- v.douyin.com/xxx抖音网址官方生成制作抖音缩短口令网址php接口方法
- git报错:error.GitError: manifests rev-list (‘^HEAD‘, ‘14686468c69c63f1995ab2a0a9ad90b2e1d5e01c‘, ‘--‘)