文章目录

  • 优化不合理
    • 解决方法1(泰勒展开)
    • 解决方法2(改变batch)
    • 解决方法3(设置momentum)
  • 模型震荡
    • 解决办法1(Adagrad方法)
    • 解决办法2(RMSProp方法)
    • 解决办法3:(Adam)
  • 优化训练损失函数

优化不合理

现象(梯度很小):
1.模型loss基本不变(梯度消失)
2.模型的loss最后收敛很高(陷入局部最优)

梯度消失原因:
1.陷入鞍点(saddle point)
2.陷入局部极值点(最大值或最小值)

解决方法1(泰勒展开)

判断梯度消失是由哪种原因引起的:
由于机器学习模型的函数特别复杂,为了简化计算,使用hessian矩阵进行近似。
假设对于模型的参数θ,设θ‘=θ+Δθ,在θ处进行泰勒展开,保留前三项,则可以得到下列式子
其中,g为L在θ处的梯度,H为Hessian矩阵,计算方法如下:
泰勒展开如下:
当陷入局部极值点时,g趋向于0,可以忽略不计,则L(θ)和L(θ’)的差异取决于在红色框框柱的式子。
此时就会出现三种情况:
对于所有的θ’
1.对于∀θ,(θ−θ′)TH(θ−θ′)>0,则为局部最小值1.对于∀θ,(θ-θ')^TH(θ-θ')>0,则为局部最小值1.对于∀θ,(θ−θ′)TH(θ−θ′)>0,则为局部最小值
2.对于∀θ,(θ−θ′)TH(θ−θ′)<0,则为局部最大值2.对于∀θ,(θ-θ')^TH(θ-θ')<0,则为局部最大值2.对于∀θ,(θ−θ′)TH(θ−θ′)<0,则为局部最大值
3.对于∀θ,(θ−θ′)TH(θ−θ′)>0或(θ−θ′)TH(θ−θ′)<0,则为鞍点3.对于∀θ,(θ-θ')^TH(θ-θ')>0 或(θ-θ')^TH(θ-θ')<0,则为鞍点3.对于∀θ,(θ−θ′)TH(θ−θ′)>0或(θ−θ′)TH(θ−θ′)<0,则为鞍点
但是由于无法穷举所有的θ’,所以需要用到一个数学结论:
eigen value:特征值(下图部分关于最大最小值的判断有误)

推导如下:
可以把(θ-θ’)看成特征向量的集合,即u是特征向量。

这样就将穷举θ转化为了,求负的特征值对应的特征向量,然后根据
θ−θ′=uθ-θ'=uθ−θ′=u
得出θ,对θ’进行更新。

解决方法2(改变batch)

设置batch,这里有两张图
大的batch收敛平稳,训练速度快,但是往往在测试集上表现差;小的batch收敛噪音大,训练速度慢,但是往往在测试集上表现好。

解决方法3(设置momentum)

momentum(动量)
考虑物理世界中,如果一个小球从高处沿着斜坡滑下,当他遇到局部最低点的时候,由于具有动量(惯性),他会继续往前冲一段路,试图越过前一个坡。
改进梯度下降方式:
momentumn+1=λ∗momentumn−η∗Gradientnmomentum^{n+1}=λ*momentum^n-η*Gradient^nmomentumn+1=λ∗momentumn−η∗Gradientn
θn+1=θn+momentumn+1θ^{n+1}=θ^n+momentum^{n+1}θn+1=θn+momentumn+1
(我感觉有点赌,动量的前提是认为翻过这个山能实现更好的效果,但是实际上不一定,可能翻过这个山反而效果反而差)

模型震荡

现象:
1.loss不变,但gradient仍很大

解决办法1(Adagrad方法)

当训练含有两个参数的模型时,如果学习率太大,则会反复震荡,如果学习率太小,则在后期训练缓慢。

所以需要根据gradient来自适应学习率,当gradient大的时候,学习率应该小,gradient小的时候,学习率应该大。
学习率应该更新如下图红框所示:

t表示epoch的次数,i表示为哪一个参数。第t个epoch的σ计算如下:
σit=1t+1∑1≤t≤ng(it)2σ^t_i=\sqrt{\frac{1}{t+1}\sum_{\mathclap{1\le t\le n}} g(^t_i)^2} σit​=t+11​1≤t≤n​∑​g(it​)2​
上述算法在Adagrad优化技术中用到。
但是上述方法的缺陷是:把每一个梯度视为同等重要,所以可能后期调整速度较慢。

解决办法2(RMSProp方法)

计算公式同上,红框中第t个epoch的σ计算方式变化为:
σit=α(σit−1)2+(1−α)(git)2σ^t_i=\sqrt{α(σ^{t-1}_i)^2+(1-α)(g^t_i)^2} σit​=α(σit−1​)2+(1−α)(git​)2​
将不同时间段产生的梯度考虑给予不同的权值,越早产生的gradient的权值越低,可以提高学习率的调整速度。

解决办法3:(Adam)

Adam=RMSProp+Momentum
相当于参数更新函数变为如下所示:
σit=α(σt−1)2+(1−α)(gt)2σ^t_i=\sqrt{α(σ^{t-1})^2+(1-α)(g^t)^2} σit​=α(σt−1)2+(1−α)(gt)2​
momentumn+1=λ∗momentumn−ησn∗Gradientnmomentum^{n+1}=λ*momentum^n-\frac {η}{σ^n}*Gradient^nmomentumn+1=λ∗momentumn−σnη​∗Gradientn
θn+1=θn+momentumn+1θ^{n+1}=θ^n+momentum^{n+1}θn+1=θn+momentumn+1
简单理解的话,相当于在梯度更新速度上加了一个权值,并且在梯度更新方向上加了一个权值

优化训练损失函数

如果在分类问题中,使用MES和交叉熵损失函数,则使用MSE时,会因为MSE在分类问题的求导时,求导项含有梯度分之一,所以在梯度很大时,MES对于分类问题train不起来。
注:pytorch中,cross函数默认会加上softmax,所以不需要在网络的最后一层加上softmax。

神经网络训练失败原因总结相关推荐

  1. 机器学习经验总结-神经网络训练失败的一些常见原因

    前言 在面对模型不收敛的时候,首先要保证训练的次数够多.在训练过程中,loss并不是一直在下降,准确率一直在提升的,会有一些震荡存在.只要总体趋势是在收敛就行.若训练次数够多(一般上千次,上万次,或者 ...

  2. parallels desktop网络初始化失败_秘籍在手,训练不愁!特斯拉AI负责人Karpathy的超全神经网络训练套路...

    大数据文摘出品 编译:周素云.宋欣仪.熊琰.ZoeY.顾晨波 训练神经网络到底有诀窍和套路吗? Andrej Karpathy认为,还的确有. 这位特斯拉的人工智能研究负责人.李飞飞的斯坦福高徒刚刚难 ...

  3. 神经网络训练loss不下降原因集合

    train loss与test loss结果分析 train loss 不断下降,test loss不断下降,说明网络仍在学习; train loss 不断下降,test loss趋于不变,说明网络过 ...

  4. 一文让你掌握神经网络训练技巧

    神经网络训练是一个非常复杂的过程,在这过程中,许多变量之间相互影响,因此我们研究者在这过程中,很难搞清楚这些变量是如何影响神经网络的.而本文给出的众多tips就是让大家,在神经网络训练过程中,更加简单 ...

  5. 特斯拉AI主管Andrej Karpathy的神经网络训练指导

    Andrej Karpathy是目前全球顶尖的计算机视觉专家,博士毕业于斯坦福,现就任于特斯拉,担任AI部门主管.在他2019年4月发布的最新博文中,他跟大家分享了他关于训练神经网络的一些心得.本篇就 ...

  6. Pytorch学习记录-torchtext和Pytorch的实例( 使用神经网络训练Seq2Seq代码)

    Pytorch学习记录-torchtext和Pytorch的实例1 0. PyTorch Seq2Seq项目介绍 1. 使用神经网络训练Seq2Seq 1.1 简介,对论文中公式的解读 1.2 数据预 ...

  7. 陈键飞:基于随机量化的高效神经网络训练理论及算法

    [专栏:前沿进展]随着预训练模型参数规模的增长,所需的算力也不断增加,从算法层面研究和处理模型规模的增长成为研究者关注的话题.近期举办的Big Model Meetup第二期活动,特邀清华大学助理教授 ...

  8. 深度学习与计算机视觉系列(8)_神经网络训练与注意点

    深度学习与计算机视觉系列(8)_神经网络训练与注意点 作者:寒小阳  时间:2016年1月.  出处:http://blog.csdn.net/han_xiaoyang/article/details ...

  9. 学界 | 数据并行化对神经网络训练有何影响?谷歌大脑进行了实证研究

    选自arXiv 作者:Christopher J. Shallue 等 机器之心编译 参与:路.张倩 近期的硬件发展实现了前所未有的数据并行化,从而加速神经网络训练.利用下一代加速器的最简单方法是增加 ...

最新文章

  1. python【力扣LeetCode算法题库】11-盛最多水的容器
  2. [蓝桥杯][2018年第九届真题]调手表(BFS)
  3. setInterval、setTimeout
  4. vb还是python强大-VB强大还是python强大
  5. CarPlay Wireless 使用fdk_aac库解码Raw AAC-LC AAC-ELD
  6. app下载 微信扫码打开 提示用户用浏览器打开
  7. icem二维非结构网格划分_【史上最全轴承结构化网格划分系列】第五弹——自动校准滚针轴承(文末附模型领取方式)...
  8. 中国省市县信息JS文件(省--市--县)
  9. 图片分享和加载失败的原因之一
  10. Vs code PIO一直loading
  11. bzoj5294: [Bjoi2018]二进制(线段树)
  12. 如何用Flutter实现跨平台移动开发
  13. 机器学习、数据挖掘、统计建模的技术担当,20款免费预测分析软件
  14. 两种电致发光器件EQE测量方法(光分布法和积分球法)
  15. html 评分星级显示,星级评分效果.html
  16. python 字体颜色改变
  17. 基于mbedtls的AES加密(C/C++)
  18. linux svn 忽略指定文件
  19. TerraBuilder 操作制作MPT
  20. Android面试攻略

热门文章

  1. 看看女程序媛们的自述
  2. 深入理解ActiveMQ支持的2类消息发送接收模型queue和topic
  3. 丛林木马(数学 思维
  4. ADS1.2 Error:(Fatal) L6002u:could not open file C:/...
  5. ati jti jwt 和_JWT jti和kid属性的说明
  6. Firebase token认证 “kid“ invalid, unable to lookup correct key
  7. 【福利时刻】阿里云盘内测码来啦,ITValuer专属暗号点这里
  8. ygomobile卡组下载网站_游戏王YGOMobile
  9. 从零开始搭建个人大数据集群——环境准备篇
  10. IDEA 控制台窗口双击最大化