来源 | deeplearning.ai

编译 | 刘静

转载自图灵TOPIA(ID:turingtopia)

初始化对训练深度神经网络的收敛性有重要影响。简单的初始化方案可以加速训练,但是它们需要小心避免常见的陷阱。

近期,deeplearning.ai就如何有效地初始化神经网络参数发表了交互式文章,图灵君将结合这篇文章与您一起探索以下问题:

1、有效初始化的重要性

2、梯度爆炸或消失的问题

3、什么是正确的初始化?

4、Xavier初始化的数学证明

一、有效初始化的重要性

要构建机器学习算法,通常需要定义一个体系结构(例如Logistic回归,支持向量机,神经网络)并训练它来学习参数。 以下是神经网络的常见训练过程:

1、初始化参数

2、选择优化算法

3、重复这些步骤:

a、正向传播输入

b、计算成本函数

c、使用反向传播计算与参数相关的成本梯度

d、根据优化算法,使用梯度更新每个参数

然后,给定一个新的数据点,您可以使用该模型来预测它的类。

初始化步骤对于模型的最终性能至关重要,它需要正确的方法。 为了说明这一点,请考虑下面的三层神经网络。 您可以尝试使用不同的方法初始化此网络,并观察它对学习的影响。

(网址:https://www.deeplearning.ai/ai-notes/initialization/)

感兴趣的同学可直接登陆、操作体验。

当初始化方法为零时,对于梯度和权重,您注意到了什么?

用零初始化所有权重会导致神经元在训练期间学习相同的特征。

实际上,任何常量初始化方案的性能表现都非常糟糕。 考虑一个具有两个隐藏单元的神经网络,并假设我们将所有偏差初始化为0,并将权重初始化为一些常数α。 如果我们在该网络中正向传播输入(x1,x2),则两个隐藏单元的输出将为relu(αx1+αx2)。 因此,两个隐藏单元将对成本具有相同的影响,这将导致相同的梯度。

因此,两个神经元将在整个训练过程中对称地进化,有效地阻止了不同的神经元学习不同的东西。

在初始化权重时,如果值太小或太大,关于成本图,您注意到了什么?

尽管打破了对称性,但是用值(i)太小或(ii)太大来初始化权重分别导致(i)学习缓慢或(ii)发散。

为高效训练选择适当的初始化值是必要的。 我们将在下一节进一步研究。

二、梯度的爆炸或消失问题

考虑这个9层神经网络。

在优化循环的每次迭代(前向,成本,后向,更新)中,我们观察到当您从输出层向输入层移动时,反向传播的梯度要么被放大,要么被最小化。 如果您考虑以下示例,此结果是有意义的。

假设所有激活函数都是线性的(标识函数)。 然后输出激活是:

其中,L=10,W[1],W[2],…,W[L−1] 都是大小为(2,2)的矩阵,因为层[1]到[L-1]有2个神经元,接收2个输入。考虑到这一点,为了便于说明,如果我们假设W[1]=W[2]=⋯=W[L−1]=W,输出预测是y^=W[L]WL−1x (其中  WL−1 将矩阵 W取为L-1的幂,而W[L] 表示Lth矩阵)。

初始化值太小,太大或不合适的结果是什么?

情形1:过大的初始化值会导致梯度爆炸

考虑这样一种情况:初始化的每个权重值都略大于单位矩阵。

这简化为y^=W[L]1.5L−1x,并且a[l] 的值随l呈指数增加。 当这些激活用于反向传播时,就会导致梯度爆炸问题。 也就是说,与参数相关的成本梯度太大。 这导致成本围绕其最小值振荡。

情形2:初始化值太小会导致梯度消失

类似地,考虑这样一种情况:初始化的每个权重值都略小于单位矩阵。

这简化为 y^=W[L]0.5L−1x,并且激活a [l]的值随l呈指数下降。 当这些激活用于反向传播时,这会导致消失的梯度问题。 相对于参数的成本梯度太小,导致在成本达到最小值之前收敛。

总而言之,使用不适当的值初始化权重将导致神经网络训练的发散或减慢。 虽然我们用简单的对称权重矩阵说明了梯度爆炸/消失问题,但观察结果可以推广到任何太小或太大的初始化值。

三、如何找到合适的初始化值

为了防止网络激活的梯度消失或爆炸,我们将坚持以下经验法则:

1、激活的平均值应为零。

2、激活的方差应该在每一层保持不变。

在这两个假设下,反向传播的梯度信号不应该在任何层中乘以太小或太大的值。 它应该移动到输入层而不会爆炸或消失。

更具体地考虑层l, 它的前向传播是:

我们希望以下内容:

确保零均值并保持每层输入方差的值不会产生爆炸/消失信号,我们稍后会解释。 该方法既适用于前向传播(用于激活),也适用于反向传播传播(用于激活成本的梯度)。 推荐的初始化是Xavier初始化(或其派生方法之一),对于每个层l:

换句话说,层l的所有权重是从正态分布中随机选取的,其中均值μ= 0且方差σ2= n [l-1] 1其中n [l-1]是层l-1中的神经元数。 偏差用零初始化。

下面的可视化说明了Xavier初始化对五层全连接神经网络的每个层激活的影响。

您可以在Glorot等人中找到这种可视化背后的理论。(2010年)。 下一节将介绍Xavier初始化的数学证明,并更准确地解释为什么它是一个有效的初始化。

四、Xavier初始化的合理性

在本节中,我们将展示Xavier初始化使每个层的方差保持不变。 我们假设层的激活是正态分布在0附近。 有时候,理解数学原理有助于理解概念,但不需要数学,就可以理解基本思想。

让我们对第(III)部分中描述的层l进行处理,并假设激活函数为tanh。 前向传播是:

目标是导出Var(a [l-1])和Var(a [l])之间的关系。 然后我们将理解如何初始化我们的权重,使得: Var(a[l−1])=Var(a[l])。

假设我们使用适当的值初始化我们的网络,并且输入被标准化。 在训练初期,我们处于tanh的线性状态。 值足够小,因此tanh(z[l])≈z[l],意思是:

此外,z[l]=W[l]a[l−1]+b[l]=向量(z1[l],z2[l],…,zn[l][l])其中 zk[l]=∑j=1n[l−1]wkj[l]aj[l−1]+bk[l]。 为简单起见,我们假设b[l]=0 (考虑到我们将选择的初始化选择,它将最终为真)。 因此,在前面的方程Var(a[l−1])=Var(a[l]) 中逐个元素地看,现在给出:

常见的数学技巧是在方差之外提取求和。 为此,我们必须做出以下三个假设:

1、权重是独立的,分布相同;

2、输入是独立的,分布相同;

3、权重和输入是相互独立的。

因此,现在我们有:

另一个常见的数学技巧是将乘积的方差转化为方差的乘积。公式如下:

使用X=wkj[l]和Y=aj[l−1]的公式,我们得到:

我们差不多完成了! 第一个假设导致E[wkj[l]]2=0,第二个假设导致E[aj[l−1]]2=0,因为权重用零均值初始化,输入被归一化。 从而:

上述等式源于我们的第一个假设,即:

同样,第二个假设导致:

同样的想法:

总结一下,我们有:

瞧! 如果我们希望方差在各层之间保持不变(Var(a[l])=Var(a[l−1])),我们需要Var(W[l])=n[l−1]1。 这证明了Xavier初始化的方差选择是正确的。

请注意,在前面的步骤中,我们没有选择特定的层ll。 因此,我们已经证明这个表达式适用于我们网络的每一层。 让LL成为我们网络的输出层。 在每一层使用此表达式,我们可以将输出层的方差链接到输入层的方差:

根据我们如何初始化权重,我们的输出和输入的方差之间的关系会有很大的不同。 请注意以下三种情况。

因此,为了避免正向传播信号的消失或爆炸,我们必须通过初始化Var(W[l])=n[l−1]1来设置n[l−1]Var(W[l])=1。

在整个证明过程中,我们一直在处理在正向传播期间计算的激活。对于反向传播的梯度也可以得到相同的结果。这样做,您将看到,为了避免梯度消失或爆炸问题,我们必须通过初始化 Var(W[l])=n[l]1来设置n[l]Var(W[l])=1。

结论

实际上,使用Xavier初始化的机器学习工程师会将权重初始化为N(0,n[l−1]1) 或N(0,n[l−1]+n[l]2)。 后一分布的方差项是n [l-1] 1和n [1] 1的调和平均值。

这是Xavier初始化的理论依据。 Xavier初始化与tanh激活一起工作。 还有许多其他初始化方法。 例如,如果您正在使用ReLU,则通常的初始化是He初始化(He et al,Delving Deep into Rectifiers),其中权重的初始化方法是将Xavier初始化的方差乘以2。虽然这种初始化的理由稍微复杂一些,但它遵循与tanh相同的思考过程。

参考链接:

https://www.deeplearning.ai/ai-notes/initialization/

(*本文为 AI科技大本营转载文章,转载请联系原作者)

 

CTA核心技术及应用峰会

5月25-27日,由中国IT社区CSDN与数字经济人才发展中心联合主办的第一届CTA核心技术及应用峰会将在杭州国际博览中心隆重召开,峰会将围绕人工智能领域,邀请技术领航者,与开发者共同探讨机器学习和知识图谱的前沿研究及应用。

更多重磅嘉宾请识别海报二维码查看。目前会议8折预售票抢购中,点击阅读原文即刻抢购。添加小助手微信15101014297,备注“CTA”,了解票务以及会务详情。

推荐阅读

  • 赌5毛钱,你解不出这道Google面试题

  • @程序员,别再自己闷头学了

  • 我用Python,3分钟快速实现,9种经典排序算法的可视化

  • 手把手教你利用爬虫爬网页(Python代码)

  • 云在物联网中的惊人优势 | 技术头条

  • 天才少年,大学创业,29 岁创立 Coinbase!| 人物志

  • 没上过大学,曾拒绝盖茨的 Offer,四代码农靠他吃饭 | 人物志

  • 狂赚320亿! 小伙建立第一个区块链国家, 国土面积7km², 自由之城诞生记

  • 小姐姐公开征婚高智商 IT 男:微信号竟要质数解密?

吴恩达团队:神经网络如何正确初始化?相关推荐

  1. 吴恩达卷积神经网络 笔记,吴恩达 深度神经网络

    如何评价吴恩达的学术地位 吴恩达(AndrewNg),斯坦福计算机系的副教授,师从机器学习的大师级人物MichaelI.Jordan. 同门师兄弟包括ZoubinGhahramani,TommiJaa ...

  2. 吴恩达卷积神经网络笔记,吴恩达人工智能公开课

    吴恩达是个谁 吴恩达(1976-,英文名:AndrewNg),华裔美国人,是斯坦福大学计算机科学系和电子工程系副教授,人工智能实验室主任.吴恩达是人工智能和机器学习领域国际上最权威的学者之一. 吴恩达 ...

  3. 吴恩达团队AI诊断心律失常研究:准确率超人类医生

    2019年,吴恩达团队在AI医疗领域实现了一项革命性的突破,他们成功地让AI诊断心律失常,其准确率高达83.7%,超过了人类心脏病医生的78.0%.这项研究成果已经发表在了知名期刊Nature Med ...

  4. 吴恩达《神经网络和深度学习》第二周编程作业—用神经网络思想实现逻辑回归

    吴恩达<神经网络和深度学习>-用神经网络思想实现逻辑回归 1 安装包 2 问题概述 3 学习算法的一般架构 4 构建算法的各个部分 4.1 激活函数 4.2 初始化参数 4.3 前向和后向 ...

  5. 吴恩达《神经网络和深度学习》第四周编程作业—深度神经网络应用--Cat or Not?

    吴恩达<神经网络和深度学习>- 深度神经网络应用--Cat or Not? 1 安装包 2 数据集 3 模型的结构 3.1 两层神经网络 3.2 L层深度神经网络 3.3 通用步骤 4 两 ...

  6. 吴恩达团队最新成果:用深度学习来改善临终关怀服务

    翻译 | AI科技大本营(ID:rgznai100) 参与 | 尚岩奇,刘畅 AI可以是杀戮的武器,也可以是救世的良方. 上周,在日内瓦举行的联合国特定常规武器公约会议上,伯克利大学教授Stuart ...

  7. 《智源社区周刊:预训练模型》第1期:吴恩达团队医疗影像预训练、快手落地万亿参数模型...

    超大规模预训练模型是当前人工智能领域研究的热点,为了帮助研究与工程人员了解这一领域的进展和资讯,智源社区整理了<智源社区周刊:预训练模型>,从研究动态.行业资讯.热点讨论等几个维度推荐最近 ...

  8. 周刊#003提要:吴恩达团队盘点2019 AI 大事件圣诞 AI 论战

    为了帮助中国人工智能科研.从业者们更好地了解全球人工智能领域的最新资讯,智源研究院编辑团队本周整理.编辑了第3期<智源社区AI周刊>,从学术(论文和新思想分享.最新学术会议等),行业和政策 ...

  9. 吴恩达《神经网络与深度学习》精炼笔记(5)-- 深层神经网络

    上节课我们主要介绍了浅层神经网络.首先介绍神经网络的基本结构,包括输入层,隐藏层和输出层.然后以简单的2 layer NN为例,详细推导了其正向传播过程和反向传播过程,使用梯度下降的方法优化神经网络参 ...

  10. 吴恩达《神经网络与深度学习》精炼笔记(4)-- 浅层神经网络

    上节课我们主要介绍了向量化.矩阵计算的方法和python编程的相关技巧.并以逻辑回归为例,将其算法流程包括梯度下降转换为向量化的形式,从而大大提高了程序运算速度.本节课我们将从浅层神经网络入手,开始真 ...

最新文章

  1. PetClinic 没有分页功能
  2. 【职场】是什么让女性在计算机史上“隐身”了?
  3. python真正实现多线程_python多线程实现
  4. Java PushbackReader ready()方法与示例
  5. C++继承详解三 ----菱形继承、虚继承
  6. 量子计算机迷宫,一个简单的例子,带你读懂量子计算机
  7. jquery select css样式,css配合jquery美化 select
  8. Android系统性能优化(66)---APK启动加速
  9. struts2文件上传类型的过滤
  10. weblogic8.1在myeclipse中启动正常,在单独的weblogic中无法正常启动的解决方案.
  11. 图像识别深度学习主流方案平台比较
  12. 【学习笔记】常见的激励函数和损失函数
  13. Spring Boot 微信点餐系统
  14. 异步通信在生活中的例子_通信技术在日常生活中的作用
  15. 学习笔记(4):零基础掌握 Python 入门到实战-深入浅出字符串(二)
  16. 关于总线、现场总线、RS-485和modbus之间的关系
  17. 问题分析:5W2H分析法
  18. ubuntu 16.04 deepin.com.wechat 微信登录提示版本过低解决方案
  19. Android命令-重点命令-pm/am/content/wm/appops
  20. 六、v8引擎执行JS文件

热门文章

  1. urlparse模块(专门用来解析URL格式)
  2. 关于 android 加载 res 图片 out of memory 问题 解决 同样适用于 sd卡图片
  3. php中magic_quotes_gpc对unserialize的影响
  4. getElementById 不能取得visible=false 的控件解决方法
  5. ISO9000机房管理办法
  6. Eigen矩阵运算的混淆问题
  7. MySQL存储引擎的介绍
  8. SQL数据库无法附加 系统表损坏修复 数据库中病毒解密恢复
  9. String、StringBuilder、StringBuffer的比较
  10. 查询数据库所有表、字段、触发器等