目录

1、为什么要标准化(理解的直接跳过到这部分)

2、LayerNorm 解释

3、举例-只对最后 1 个维度进行标准化

4、举例-对最后 D 个维度进行标准化


1、为什么要标准化(理解的直接跳过到这部分)

Batch Normalization 的作用就是把神经元在经过非线性函数映射后向取值区间极限饱和区靠拢的输入分布强行拉回到均值为 0 方差为 1 的比较标准的正态分布的区间,使得非线性变换函数的输入值落入激活函数比较敏感的区域,这样会让让梯度变大,由此避免了梯度消失的问题。而梯度变大也意味着学习收敛速度快,能大大加快训练速度。

BN 的计算过程如下:

1、计算均值

2、计算方差

3、标准化:减均值后除以标准差(注意是标准差,不是方差)

4、仿射变换

使用 netron 工具可视化 LayerNorm 层的计算流图如下:

2、LayerNorm 解释

LayerNorm 是一个类,用来实现对 tensor 的层标准化,实例化时定义如下:

LayerNorm(normalized_shape, eps = 1e-5, elementwise_affine = True, device=None, dtype=None)

以一个 shape 为 (3, 4) 的 tensor 为例。LayerNorm 里面主要会用到三个参数:

normalized_shape:要实行标准化的最后 D 个维度,可以是一个 int 整数(必须等于tensor的最后一个维度的大小,不能是中间维度的大小),使用示例 tensor 的话此时这个整数必须为 normalized_shape=4,代表标准化 tensor 的最后一维。另外也可以是一个列表,但这个列表也必须是最后的 D 个维度的列表,如示例 tensor 的话就必须是 normalized_shape=[3, 4]

eps:为了防止标准差为零时分母为零,设置的极小值,默认是1e-5,也可以自己设置。

elementwise_affine:是否需要仿射变换。仿射变换需要两个可学习参数 γ 和 β:把标准化的结果乘以缩放系数 γ 再加上偏置系数 β。仿射变换是为了保证非线性的获得。

举个例子,我们有下面一个 shape 为 (3, 4) 的数组,并把它转化为 tensor。

import torch
import torch.nn as nn
import numpy as np
a = np.array([[1, 20, 3, 4],[5, 6, 7, 8,],[9, 10, 11, 12]], dtype=np.double)
b = torch.from_numpy(a).type(torch.FloatTensor)

3、举例-只对最后 1 个维度进行标准化

现在想计算对一个维度进行标准化,即对 [1, 20, 3, 4][5, 6, 7, 8,][9, 10, 11, 12] 分别标准化,可以像下面这样操作:

layer_norm = nn.LayerNorm(4, eps=1e-6) # 最后一个维度大小为4,因此normalized_shape是4
c = layer_norm(b)
print(c)
# 结果:
tensor([[-0.7913,  1.7144, -0.5275, -0.3956],[-1.3416, -0.4472,  0.4472,  1.3416],[-1.3416, -0.4472,  0.4472,  1.3416]],grad_fn=<NativeLayerNormBackward0>)

怎么验证对不对呢?我们可以使用 np 对数组 a 手动计算下标准化看看:

mean_a = np.mean(a, axis=1)  # 计算最后一个维度的均值 = [7. 6.5 10.5]
var_a = np.var(a, axis=1)    # 计算最后一个维度的方差 = [57.5 1.25 1.25]
# 对最后一个维度做标准化 减均值后除以标准差
a[0, :] = (a[0, :] - mean_a[0]) / np.sqrt(var_a[0])
a[1, :] = (a[1, :] - mean_a[1]) / np.sqrt(var_a[1])
a[2, :] = (a[2, :] - mean_a[2]) / np.sqrt(var_a[2])
print(a)
# 输出结果:
[[-0.79125657  1.71438923 -0.52750438 -0.39562828][-1.34164079 -0.4472136   0.4472136   1.34164079][-1.34164079 -0.4472136   0.4472136   1.34164079]]

这时发现与 torch 的 LayerNorm  计算结果想通过,印证了上述的解释。

4、举例-对最后 D 个维度进行标准化

这是个二维tensor,假设我们要对最后二维进行标准化,也即对所有数据标准化,可以令 normalized_shape=[3, 4],如下:

layer_norm = nn.LayerNorm([3, 4], eps=1e-6)
c = layer_norm(b)
print(c)
# 计算结果:
tensor([[-1.4543e+00,  2.4932e+00, -1.0388e+00, -8.3105e-01],[-6.2329e-01, -4.1553e-01, -2.0776e-01,  1.1921e-07],[ 2.0776e-01,  4.1553e-01,  6.2329e-01,  8.3105e-01]],grad_fn=<NativeLayerNormBackward0>)

怎么做验证呢?也让 np 在所有数据上做标准化:

mean_a = np.mean(a)  # 计算所有数据的均值,返回标量
var_a = np.var(a)    # 计算所有数据的方差,返回标量
a = (a - mean_a) / np.sqrt(var_a)  # 对整体做标准化
print(a)
# 输出结果
[[-1.45434106  2.4931561  -1.03881504 -0.83105203][-0.62328902 -0.41552602 -0.20776301  0.        ][ 0.20776301  0.41552602  0.62328902  0.83105203]]

np 手动计算与 torch 的计算结果相同。

pytorch 层标准化 LayerNorm 的用法相关推荐

  1. Pytorch中的collate_fn函数用法

    Pytorch中的collate_fn函数用法 官方的解释:   Puts each data field into a tensor with outer dimension batch size ...

  2. 批标准化(batch normalization)与层标准化(layer normalization)比较

    批标准化(batch normalization,BN)与层标准化(layer normalization,LN)应该都是为了解决网络训练过程中的协变量漂移问题. BN与LN的归一化方法都是先减均值, ...

  3. pytorch 深入理解 tensor.scatter_ ()用法

    pytorch 深入理解 tensor.scatter_ ()用法 在 pytorch 库下理解 torch.tensor.scatter()的用法.作者在网上搜索了很多方法,最后还是觉得自己写一篇更 ...

  4. Pytorch——批标准化(层归一化)

    文章目录 1.前言 2.普通数据归一化 3.层归一化 4.Batch Normalization 添加位置 5.Batch Normalization 效果 6.BN 算法 1.前言 今天我们会来聊聊 ...

  5. Pytorch中nn.Conv2d的用法

    官网链接: nn.Conv2d     Applies a 2D convolution over an input signal composed of several input planes. ...

  6. Pytorch/Python中item()的用法

    前言 在使用Pytorch训练模型时,用到python中的item()函数,如: train_loss += loss.item() 现对item()函数用法做出总结.item()函数的作用是从包含单 ...

  7. pytorch的size和shape用法

      有别于numpy中size的用法(用来计算数组和矩阵中所有元素的个数),pytorch的size具有和shape一样计算矩阵维度大小的作用. 上代码~ import torch import nu ...

  8. pcb板子制作各层的解释和用法

    由上到下 正面丝印层,正面的印字 正面阻焊层,正面无需焊接的涂覆 正面布线层,正面的铜箔 背面布线层,背面的铜箔 背面阻焊层,背面无需阻焊的涂覆 背面丝印层,背面的印字 2009-09-02 15:4 ...

  9. Pytorch torch.topk()的简单用法

    官方文档:https://pytorch.org/docs/stable/generated/torch.topk.html?highlight=topk#torch.topk 由于numpy本身是没 ...

最新文章

  1. linux服务器上nginx日志访问量统计命令
  2. java如何调用系统保存框_java使用poi实现excel导出之后如何弹出保存提示框
  3. 计算机硬件市场调查实验报告,计算机组装与维护实训报告范例.doc
  4. 《51单片机应用开发从入门到精通》——2.6 中断控制功能的作用
  5. ASSERT: “QGLFunctions::isInitialized(d_ptr)“ - Runtime Exception
  6. 从Microsoft Teams技术栈看前端技术发展趋势
  7. 计算机应用乘法,计算机系统原理(十) 二进制整数的乘法运算和除法运算
  8. 物理学家张首晟:如果世界末日来临,我会带这几句话上诺亚方舟|研习社演讲实录...
  9. Jquery获取列表中的值和input单选、多选框控制选中与取消
  10. jq过滤替换敏感词_如何用python简单过滤敏感信息
  11. 【易语言界面开发系列教程之(EX_UI使用系列教程(14)--EX组件(组合框))】
  12. 基于SOLIDWORKS Simulation的有限元分析法实例应用
  13. 【UCOSIII操作系统】任务篇(2)相关API函数
  14. 机器学习数学基础十:相关分析
  15. 【最佳实践】行云管家数据库运维审计解决方案
  16. Stata制作限制立方样条(RCS)(2)
  17. 【C语言】初识C语言(中篇)
  18. iOS 将状态栏设置成白色
  19. 2014年国人开发的最热门的开源软件TOP 100
  20. python调用打印机打印pdf_python连接打印机实现打印文档、图片、pdf文件等功能

热门文章

  1. mysql 在linux环境下导出,window下导入报ASCII '\0' appeared in the statement
  2. 浅谈微型真空气泵、空气采样泵的选用
  3. 尬聊器(伪聊天机器人)
  4. 美国盗版党(Pirate Party)
  5. arya-sites模块的主要类
  6. 最大熵模型(maximum entropy model)
  7. 软件之聊天工具:QQ,MSN,Google talk,Skype, Lync
  8. 计算机技术中的多媒体是什么,在多媒体计算机技术中,媒体含义一般指()。A中介B介质C信息的载体D存储介质 - 试题答案网问答...
  9. oracle里的ols机制,Oracle DV和OLS以及VPD的区别(转)
  10. 3步了解APP渠道应该怎样建设评估体系(上)