回顾一下梯度下降的过程:

假设当前神经网络有以下参数 θ = { ω 1 , ω 2 , . . . , b 1 , b 2 , . . . } \theta = \{\omega_1,\omega_2,...,b_1,b_2,...\} θ={ω1​,ω2​,...,b1​,b2​,...},那么梯度下降就是计算损失函数对于每个参数的梯度,然后按照梯度更新公式来更新每一个参数。但在深度学习中参数量巨大,这样计算时间过长,因此反向传播就是来高效就计算出损失函数对于每个参数的梯度的。注意反向传播并不是一个和梯度下降不同的训练方法,它只是能够更有效率就计算出损失函数对参数的梯度,来帮助梯度下降过程。

反向传播

损失函数可以如下表示:
L ( θ ) = ∑ n = 1 N C n ( θ ) L(\theta)=\sum_{n=1}^N C^n(\theta) L(θ)=n=1∑N​Cn(θ)
其中 C n ( θ ) C^n(\theta) Cn(θ)表示第n个样本的输出值和理想值之间的距离。那么:
∂ L ( θ ) ∂ w = ∑ n = 1 N ∂ C n ( θ ) ∂ w \frac{\partial L(\theta)}{\partial w}=\sum_{n=1}^N \frac{\partial C^n(\theta)}{\partial w} ∂w∂L(θ)​=n=1∑N​∂w∂Cn(θ)​
也就是将总体损失对参数的微分转换成每一个样本的距离对参数的微分的求和

假设对于图上网络:
∂ C ∂ w = ∂ z ∂ w ∂ C ∂ z \frac{\partial C}{\partial w}=\frac{\partial z}{\partial w}\frac{\partial C}{\partial z} ∂w∂C​=∂w∂z​∂z∂C​
其中:

  • ∂ z ∂ w \frac{\partial z}{\partial w} ∂w∂z​:称为前向传播(Forward pass),较为容易计算
  • ∂ C ∂ z \frac{\partial C}{\partial z} ∂z∂C​:称为反向前进(Backward pass),较难计算

Forward pass

从上图中我们可以很简单地算出
∂ z ∂ w 1 = x 1 ∂ z ∂ w 2 = x 2 \frac{\partial z}{\partial w_1}=x_1\\ \frac{\partial z}{\partial w_2} = x_2 ∂w1​∂z​=x1​∂w2​∂z​=x2​
也就是说对于每条边或者说每个参数,它所连接的下一层的输入对于该参数的求导就等于上一层在这条边上的输入,例如下图:

所以前向传播这一步可以很简单的计算出来。

Bcakward pass

现在需要来考虑如何计算 ∂ C ∂ z \frac{\partial C}{\partial z} ∂z∂C​,假设前述z经过一个Sigmoid函数后得到a,那么a作为下一层神经网的某一个输入,因此就可以写出:
∂ C ∂ z = ∂ a ∂ z ∂ C ∂ a \frac{\partial C}{\partial z}=\frac{\partial a}{\partial z}\frac{\partial C}{\partial a} ∂z∂C​=∂z∂a​∂a∂C​
而从上图中也可以很清楚地看到**可以用微积分的知识转换成上述公式,而其中对a的求导也可以结合我们上述的知识很容易的求解。因此现在就是如何求解C对两个z的求导了。

但假设我们当前能够通过某种方法知道了C对两个z的求导,同时我们将网络进行些许转换,如下:

根据那个公式我们可将网络反向过来,这有助于待会理解反向传播。不过值得注意的是此处神经元结点对于输入加权和后是乘上 σ ‘ ( z ) \sigma`(z) σ‘(z),在z确定的时候(当输入确定时z就确定了)可以看成常数,因此跟正向神经网络的非线性变化不同

继续计算C对两个z的求导:

情况一

假设 z ‘ z` z‘和 z ‘ ‘ z`` z‘‘经过非线性变换后已经就是输出了,那么这种简单的情况可以很简单的写出上面的计算式,也就很简单的完成了我们对于参数梯度的计算工作。其中
∂ C ∂ y 取决于你的损失函数 ∂ y ∂ z 取决于最后一层的非线性变换 \frac{\partial C}{\partial y}取决于你的损失函数\\ \frac{\partial y}{\partial z}取决于最后一层的非线性变换 ∂y∂C​取决于你的损失函数∂z∂y​取决于最后一层的非线性变换
情况二

假设 z ‘ z` z‘和 z ‘ ‘ z`` z‘‘后面仍然有很多未知的线性变化,但通过前述的讲解我们可以明确只要知道了下一层的C对各个z的求导,那么就一定可以算出当前层C对各个z的求导。因此只要不断地往后推,找到某一层的z经过非线性变换后就是输出,那么就可以计算C对该层的z的求导(情况一),然后再往前推直到C对每一层的z的求导都算出来

那么在实际上的做法就是:

  1. 建立一个反向的神经网络,其结构相同权重参数相同,但是功能神经元结点的非线性变换变成了常数,就是之前的 σ ‘ ( z ) \sigma`(z) σ‘(z),这需要先计算Forwardpass之后才可以计算(其中还需要计算 ∂ z ∂ w \frac{\partial z}{\partial w} ∂w∂z​)
  2. 计算损失函数C对最后一层的每个z的求导,那么它们就是这个反向神经网络的输入参数
  3. 再根据网络的不断传播就可以计算出最终结果

这就是反向传播

【机器学习】李宏毅——何为反向传播相关推荐

  1. 【李宏毅机器学习】backpropagation 反向传播(p13) 学习笔记

    李宏毅机器学习学习笔记汇总 课程链接 文章目录 Gradient Descent Chain Rule链式法则 前向传播 反向传播 情况一:红色的neural是属于网络的output layer的 情 ...

  2. 李宏毅机器学习课程7~~~反向传播

    到底为什么基于反向传播的纯监督学习在过去表现不佳?Geoffrey Hinton总结了目前发现的四个方面问题: 带标签的数据集很小,只有现在的千分之一. 计算性能很慢,只有现在的百万分之一. 权重的初 ...

  3. 李宏毅机器学习back propogation反向传播

    回顾梯度下降: 有一堆参数 θ=w1,w2,...,b1,b2,...\theta= {w_1,w_2,...,b_1,b_2,... }θ=w1​,w2​,...,b1​,b2​,... 首先由一个 ...

  4. 吴恩达机器学习:神经网络 | 反向传播算法

    上一周我们学习了 神经网络 | 多分类问题.我们分别使用 逻辑回归 和 神经网络 来解决多分类问题,并了解到在特征数非常多的情况下,神经网络是更为有效的方法.这周的课程会给出训练 神经网络 所使用的 ...

  5. 神经网络与机器学习 笔记—改善反向传播的性能试探法

    改善反向传播的性能试探法 整理8个能提高BP效率的方法: 随机和批量方式更新 反向传播学习的随机(串行)方式(涉及一个模式接一个模式的更新)要比批量方式 计算快.特别是当新联数据集很大且高度冗余时,更 ...

  6. 机器学习笔记:反向传播

    1 共享路径 对于每一个可学习的参数w,我们都需要通过梯度下降更新它的值.而我们可以使用反向传播的思路来求解损失函数在某一个参数上的loss 如上图所示,可以用如下方式运算 可以发现 ,上图红色圈起来 ...

  7. 【机器学习】P18 反向传播(导数、微积分、链式法则、前向传播、后向传播流程、神经网络)

    反向传播 反向传播 反向传播中的数学 导数与python 链式法则 简单神经网络处理流程从而理解反向传播 神经网络与前向传播 神经网络与反向传播 反向传播 反向传播(back propagation) ...

  8. 大脑模拟AI学习策略,这项逼近反向传播的研究登上《自然-神经科学》

    来源:机器之心 编辑:陈萍.杜伟 大脑是如何近似反向传播算法的?发表在<自然 - 神经科学>的一篇论文,研究者找到了可以生活在活体大脑并进行实时工作的等价物,他们提出的大脑学习算法模型可以 ...

  9. 反向传播算法中,逐级向前计算delta公式的由来

    学习吴恩达机器学习第九章反向传播编程练习,有这样一个公式: 它的证明在博客反向传播算法(过程及公式推导)中可以找到 https://www.cnblogs.com/wlzy/p/7751297.htm ...

最新文章

  1. Glibc辅助运行库 (C RunTime Library): crt0.o,crt1.o,crti.o crtn.o,crtbegin.o crtend.o
  2. 物理学家解说2012
  3. Elasticsearch forceMerge操作
  4. element 时间日期选择器el-date-picker点击清空按钮报错 Cannot read property ‘0‘ of null
  5. Spring MVC:会话高级
  6. 记一次吐血的ping: unknown host
  7. [NOIP2013] 花匠
  8. Go语言中的异常和错误处理简介
  9. 目前我们所说的个人商用计算机属于,计算机组成原理试题库(大专生用,共20份,有答案)...
  10. 布鲁斯的秘密-序章:我是布鲁斯
  11. tomcat 6.0配置
  12. 几个小伙伴的进入信息安全行业的经历
  13. ENVI操作:监督分类
  14. FTP工具,5款常用的FTP工具
  15. python语句分号_Python中的分号
  16. android app 经纬度,经纬度定位app
  17. 游戏数据库 TcaplusDB
  18. 【概率论】范畴分布 Categorical / Multinoulli Distribution
  19. VMware环境部署vFW虚拟防火墙
  20. mm_struct(内存描述符)

热门文章

  1. Unity 动画逆播放
  2. [推荐算法]UserCF,基于用户的协同过滤算法
  3. 【Error: error:0308010C:digital envelope routines::unsupported】
  4. 小米新财报:手机承压,转型求生
  5. 一机双屏,海信A2如何引领创新大潮?
  6. CES 软件的新平台与新机遇
  7. 产品经理须知:机会成本和沉没成本
  8. 【Python】用 Python 的 scipy 包实现线性规划(LP)
  9. jQuery中获取兄弟元素的方法
  10. Why Ruby is Simple