之前一直和小伙伴探讨batch normalization层的实现机理,作用在这里不谈,知乎上有一篇paper在讲这个,链接

这里只探究其具体运算过程,我们假设在网络中间经过某些卷积操作之后的输出的feature map的尺寸为4×3×2×2

4为batch的大小,3为channel的数目,2×2为feature map的长宽

整个BN层的运算过程如下图

上图中,batch size一共是4, 对于每一个batch的feature map的size是3×2×2

对于所有batch中的同一个channel的元素进行求均值与方差,比如上图,对于所有的batch,都拿出来最后一个channel,一共有4×4=16个元素,

然后求区这16个元素的均值与方差(上图只求了mean,没有求方差。。。),

求取完了均值与方差之后,对于这16个元素中的每个元素进行减去求取得到的均值与方差,然后乘以gamma加上beta,公式如下

所以对于一个batch normalization层而言,求取的均值与方差是对于所有batch中的同一个channel进行求取,batch normalization中的batch体现在这个地方

batch normalization层能够学习到的参数,对于一个特定的channel而言实际上是两个参数,gamma与beta,对于total的channel而言实际上是channel数目的两倍。

用pytorch验证上述想法是否准确,用上述方法求取均值,以及用batch normalization层输出的均值,看看是否一样

上代码

 1 # -*-coding:utf-8-*-
 2 from torch import nn
 3 import torch
 4
 5 m = nn.BatchNorm2d(3)  # bn设置的参数实际上是channel的参数
 6 input = torch.randn(4, 3, 2, 2)
 7 output = m(input)
 8 # print(output)
 9 a = (input[0, 0, :, :]+input[1, 0, :, :]+input[2, 0, :, :]+input[3, 0, :, :]).sum()/16
10 b = (input[0, 1, :, :]+input[1, 1, :, :]+input[2, 1, :, :]+input[3, 1, :, :]).sum()/16
11 c = (input[0, 2, :, :]+input[1, 2, :, :]+input[2, 2, :, :]+input[3, 2, :, :]).sum()/16
12 print('The mean value of the first channel is %f' % a.data)
13 print('The mean value of the first channel is %f' % b.data)
14 print('The mean value of the first channel is %f' % c.data)
15 print('The output mean value of the BN layer is %f, %f, %f' % (m.running_mean.data[0],m.running_mean.data[0],m.running_mean.data[0]))
16 print(m)

m = nn.BatchNorm2d(3)

声明新的batch normalization层,用

input = torch.randn(4, 3, 2, 2)

模拟feature map的尺寸

输出值

咦,怎么不一样,貌似差了一个小数点,可能与BN层的momentum变量有关系,在生命batch normalization层的时候将momentum设置为1试一试

m.momentum=1

输出结果

没毛病

至于方差以及输出值,大抵也是这样进行计算的吧,留个坑

转载于:https://www.cnblogs.com/yongjieShi/p/9332655.html

Pytorch中的Batch Normalization操作相关推荐

  1. 【学习笔记】Pytorch深度学习—Batch Normalization

    [学习笔记]Pytorch深度学习-Batch Normalization Batch Normalization概念 `Batch Normalization ` `Batch Normalizat ...

  2. 深度神经网络中的Batch Normalization介绍及实现

    之前在经典网络DenseNet介绍_fengbingchun的博客-CSDN博客_densenet中介绍DenseNet时,网络中会有BN层,即Batch Normalization,在每个Dense ...

  3. Tensorflow BatchNormalization详解:4_使用tf.nn.batch_normalization函数实现Batch Normalization操作...

    使用tf.nn.batch_normalization函数实现Batch Normalization操作 觉得有用的话,欢迎一起讨论相互学习~Follow Me 参考文献 吴恩达deeplearnin ...

  4. Pytorch中scatter与gather操作

    文章目录 数据发散scatter 带聚集的发散scatter_add_ onnx中scatterND 数据聚集gather 数据发散scatter 函数原型pytorch官方文档scatter_: s ...

  5. 深入理解pytorch中计算图的inplace操作

    a=1 print(id(a)) a=2 print(id(a)) 并不是在1的空间删除填上2,而是新开辟了空间. a=[1] print(id(a[0])) a[0]=1 print(id(a[0] ...

  6. 神经网络中使用Batch Normalization 解决梯度问题

    BN本质上解决的是反向传播过程中的梯度问题. 详细点说,反向传播时经过该层的梯度是要乘以该层的参数的,即前向有: 那么反向传播时便有: 那么考虑从l层传到k层的情况,有: 上面这个 便是问题所在.因为 ...

  7. Batch Normalization原理及pytorch的nn.BatchNorm2d函数

    下面通过举个例子来说明Batch Normalization的原理,我们假设在网络中间经过某些卷积操作之后的输出的feature map的尺寸为4×3×2×2,4为batch的大小,3为channel ...

  8. 五个角度解释深度学习中 Batch Normalization为什么效果好?

    https://www.toutiao.com/a6699953853724361220/ 深度学习模型中使用Batch Normalization通常会让模型得到更好表现,其中原因到底有哪些呢?本篇 ...

  9. Batch Normalization函数详解及反向传播中的梯度求导

    摘要 本文给出 Batch Normalization 函数的定义, 并求解其在反向传播中的梯度 相关 配套代码, 请参考文章 : Python和PyTorch对比实现批标准化Batch Normal ...

最新文章

  1. 第二课 壳的介绍以及脱壳常用思路
  2. 取生产订单状态的逻辑
  3. 量化交易(图文版其二)
  4. python类定义中__init__(),在__init__中定义一个成员以在python中的类体中定义它的区别?...
  5. mongodb mysql数据类型_mongodb中数据类型的坑
  6. ajax 用户验证js,js ajax验证用户名
  7. ui-router 路由重定向
  8. 黑马程序员传智播客 正则表达式学习笔记 匹配单个字符多个字符
  9. unity物体四种移动方法总结
  10. java汉字拼音简码_java生成首字母拼音简码的总结
  11. 什么是等级保护, 等保2.0详解
  12. 2018 06 01 第一次博客 自然语言处理
  13. 新浪微博API错误代码大全
  14. 树莓派Pico开发板MicroPython嵌入pioasm汇编混合编程技术实践
  15. 万年历源代码 c语言基础,C语言万年历的源程序
  16. 荣耀绽放 | 白玉兰酒店荣膺金光奖“中国发展潜力酒店品牌”奖项
  17. 【CSS】相对长度单位 绝对长度单位,vw/vh , rem等
  18. 要不要启用苹果wapi_苹果“史上最强”系统ios13来了,要不要升级?
  19. linux下ping提示dup,ping出现dup问题
  20. 石化能源行业工业互联网智能工厂解决方案

热门文章

  1. 实施文档_建设工程监理全套资料范本,Word文档附百份案例表格,超实用
  2. LeetCode算法入门- Compare Version Numbers -day14
  3. 系统如何启动数据库服务器,怎么启动sql数据库服务器
  4. 底层实现_Java AOP的底层实现原理
  5. php atlas,apache atlas是什么
  6. php mysql 排名_mysql中如何实现排名
  7. oracle面向对象的数据类型,Oracle面向对象编程OOP
  8. linux宽松模式,SELinux 宽容模式(permissive) 强制模式(enforcing) 关闭(disabled) 几种模式之间的转换...
  9. HTML+CSS+JS实现 ❤️创意几何love字母特效❤️
  10. 西安邮电大学卓越班c语言面试题,西安邮电大学C语言实验报告.docx