权重衰减(weight decay)是应对过拟合方法的常用方法。

方法

权重衰减等价于$L_2$范数正则化(regularization)。正则化是通过模型损失函数添加惩罚项来使得训练后的模型参数值较小,是应对过拟合的常用方法。

$L_2$范数正则化在模型原来的损失函数基础上添加$L_2$范数称惩罚项。从而得到训练所需最小的函数。$L_2$范数惩罚项是模型权重参数每个元素的平方和与一个正的常数的乘积。以线性回归为例:

其中$w_1,w_2$为权重参数,$b$是偏置参数,样本$i$的输入为$x^{(i)}, x^{(i)}$,标签为$y^{(i)}$,样本数为$n$。将权重参数用向量$w=[w_1,w_2]$表示,带$L_2$范数惩罚项的新函数为

其中超参数$k>0$。当权重参数均为0时,惩罚项最小。当$k$较大时,惩罚项在损失函数中的比重较大,这通常会使学到的权重参数的元素接近于0。当$k$设为0时,惩罚项完全不起作用。上述式子中$L_2$范数平法$||w||^2$展开后得到$w_1^2+w_2^2$。有了$L_2$范数的惩罚项后,在小批量的随机梯度下降中,权重$w_1,w_2$的迭代方式改为:

$L_2​$范数正则化让权重$w_1​$和$w_2​$先自乘小于1的数,再减去惩罚项中的梯度。因此,$L_2​$范数正则化又称权重衰减。权重衰减通过惩罚绝对值较大的模型参数为需要学习的模型增加了限制,这可能对过拟合有效。在实际中,有事也在惩罚项中添加偏差元素的平方和。

高维线性回归

使用下列函数生成样本标签:

其中噪音项$\varepsilon$服从N(0,1),p为维度。1

2

3

4

5

6

7

8

9

10

11

12import gluonbook as gb

from mxnet import autograd, gluon, init, nd

from mxnet.gluon import data as gdata, loss as gloss

n_train, n_test, num_inputs = 20, 100, 200

true_w, true_b = nd.ones((num_inputs, 1)) * 0.01, 0.05

features = nd.random.normal(shape=(n_train + n_test, num_inputs))

labels = nd.dot(features, true_w) + true_b

labels += nd.random.normal(scale=0.01, shape=labels.shape)

train_features, test_features = features[:n_train, :], features[n_train:, :]

train_labels, test_labels = labels[:n_train], labels[n_train:]

实现1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

33

34

35

36

37

38

39

40

41

42

43

44

45

46

47

48

49

50

51

52

53

54

55

56

57

58

59

60

61

62

63

64

65

66

67

68import gluonbook as gb

from mxnet import autograd, gluon, init, nd

from mxnet.gluon import data as gdata, loss as gloss

import matplotlib.pyplot as plt

n_train, n_test, num_inputs = 20, 100, 200

true_w, true_b = nd.ones((num_inputs, 1)) * 0.01, 0.05

features = nd.random.normal(shape=(n_train + n_test, num_inputs))

labels = nd.dot(features, true_w) + true_b

labels += nd.random.normal(scale=0.01, shape=labels.shape)

train_features, test_features = features[:n_train, :], features[n_train:, :]

train_labels, test_labels = labels[:n_train], labels[n_train:]

# 初始化模型参数

def init_params():

w = nd.random.normal(scale=1, shape=(num_inputs, 1))

b = nd.zeros(shape=(1,))

w.attach_grad()

b.attach_grad()

return [w,b]

# 定义L2范数惩罚项

def l2_penalty(w):

return (w**2).sum() / 2

# 定义训练和测试

batch_size, num_epochs, lr = 1, 100, 0.003

net, loss = gb.linreg, gb.squared_loss

train_iter = gdata.DataLoader(gdata.ArrayDataset(train_features, train_labels), batch_size, shuffle=True)

def semilogy(x_vals, y_vals, x_label, y_label, x2_vals=None, y2_vals=None,

legend=None, figsize=(5.5, 2.5)):

plt.rcParams['figure.figsize'] = figsize

plt.xlabel(x_label)

plt.ylabel(y_label)

plt.semilogy(x_vals, y_vals)

if x2_vals and y2_vals:

plt.semilogy(x2_vals, y2_vals, linestyle=':')

plt.legend(legend)

plt.show()

def fit_and_plot(lambd):

w, b = init_params()

train_ls, test_ls = [], []

for _ in range(num_epochs):

for X, y in train_iter:

with autograd.record():

# 添加L2范数惩罚项

l = loss(net(X, w, b), y) + lambd * l2_penalty(w)

l.backward()

gb.sgd([w,b], lr, batch_size)

train_ls.append(loss(net(train_features, w, b), train_labels).mean().asscalar())

test_ls.append(loss(net(test_features, w, b), test_labels).mean().asscalar())

semilogy(range(1, num_epochs + 1), train_ls, 'epochs', 'loss',

range(1, num_epochs + 1), test_ls, ['train', 'test'])

print('L2 norm of w:', w.norm().asscalar())

# 不使用权重衰减

fit_and_plot(lambd=0)

# 使用权重衰减

# fit_and_plot(lambd=3)

当未使用权重衰减(lambd=0)时,训练集上的误差远小于测试集

L2 norm of w: 11.61194

使用权重衰减(lambd=3)时,训练误差虽然提高,但是测试集上的误差下降,过拟合得到一定程度上缓解,此时权重参数更接近0。

L2 norm of w: 0.046675965

权重衰减的Gluon实现1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

33

34

35

36

37

38

39

40

41

42

43

44

45

46

47

48

49

50

51

52

53

54

55

56

57

58

59

60

61

62

63

64

65

66

67

68import gluonbook as gb

from mxnet import autograd, gluon, init, nd

from mxnet.gluon import data as gdata, loss as gloss, nn

import matplotlib.pyplot as plt

n_train, n_test, num_inputs = 20, 100, 200

true_w, true_b = nd.ones((num_inputs, 1)) * 0.01, 0.05

features = nd.random.normal(shape=(n_train + n_test, num_inputs))

labels = nd.dot(features, true_w) + true_b

labels += nd.random.normal(scale=0.01, shape=labels.shape)

train_features, test_features = features[:n_train, :], features[n_train:, :]

train_labels, test_labels = labels[:n_train], labels[n_train:]

def semilogy(x_vals, y_vals, x_label, y_label, x2_vals=None, y2_vals=None,

legend=None, figsize=(4.5, 2.5)):

plt.rcParams['figure.figsize'] = figsize

plt.xlabel(x_label)

plt.ylabel(y_label)

plt.semilogy(x_vals, y_vals)

if x2_vals and y2_vals:

plt.semilogy(x2_vals, y2_vals, linestyle=':')

plt.legend(legend)

plt.show()

# 定义训练和测试

batch_size, num_epochs, lr = 1, 100, 0.003

net, loss = gb.linreg, gb.squared_loss

train_iter = gdata.DataLoader(gdata.ArrayDataset(train_features, train_labels), batch_size, shuffle=True)

def fit_and_plot(wd):

net = nn.Sequential()

net.add(nn.Dense(1))

net.initialize(init.Normal(sigma=1))

# 对权重衰减,权重名称一般是以weight结尾

train_w = gluon.Trainer(net.collect_params('.*weight'), 'sgd', {'learning_rate': lr, 'wd': wd})

train_b = gluon.Trainer(net.collect_params('.*bias'), 'sgd', {'learning_rate': lr})

train_ls, test_ls = [], []

for _ in range(num_epochs):

for X,y in train_iter:

with autograd.record():

l = loss(net(X), y)

l.backward()

# 对两个Trainer分别调用step函数,从而分别更新权重和偏置

train_b.step(batch_size)

train_w.step(batch_size)

train_ls.append(loss(net(train_features), train_labels).mean().asscalar())

test_ls.append(loss(net(test_features), test_labels).mean().asscalar())

semilogy(range(1, num_epochs + 1), train_ls, 'epochs', 'loss',

range(1, num_epochs + 1), test_ls, ['train', 'test'])

print('L2 norm of w:', net[0].weight.data().norm().asscalar())

# 不使用权重衰减

fit_and_plot(wd=0)

# 使用权重衰减

# fit_and_plot(wd=3)

深度学习的权重衰减是什么_深度学习-权重衰减相关推荐

  1. 深度置信网络预测算法matlab代码_深度学习双色球彩票中的应用研究资料

    点击蓝字关注我们 AI研习图书馆,发现不一样的世界 深度学习在双色球彩票中的应用研究 前言 人工神经网络在双色球彩票中的应用研究网上已经有比较多的研究论文和资料,之前比较火的AlphaGo中用到的深度 ...

  2. 强化学习论文分析4---异构网络_强化学习_功率控制《Deep Reinforcement Learning for Multi-Agent....》

    目录 一.文章概述 二.系统目标 三.应用场景 四.算法架构 1.微基站处----DQN 2.宏基站处---Actor-Critic 五.伪代码 六.算法流程图 七.性能表征 1.收敛时间 2.信道总 ...

  3. 机器学习学习吴恩达逻辑回归_机器学习基础:逻辑回归

    机器学习学习吴恩达逻辑回归 In the previous stories, I had given an explanation of the program for implementation ...

  4. 深度学习的权重衰减是什么_【深度学习理论】一文搞透Dropout、L1L2正则化/权重衰减...

    前言 本文主要内容--一文搞透深度学习中的正则化概念,常用正则化方法介绍,重点介绍Dropout的概念和代码实现.L1-norm/L2-norm的概念.L1/L2正则化的概念和代码实现- 要是文章看完 ...

  5. 深度学习的权重衰减是什么_权重衰减和L2正则化是一个意思吗?它们只是在某些条件下等价...

    权重衰减== L2正则化? 神经网络是很好的函数逼近器和特征提取器,但有时它们的权值过于专门化而导致过度拟合.这就是正则化概念出现的地方,我们将讨论这一概念,以及被错误地认为相同的两种主要权重正则化技 ...

  6. 深度学习 图像分类_深度学习时代您应该阅读的10篇文章了解图像分类

    深度学习 图像分类 前言 (Foreword) Computer vision is a subject to convert images and videos into machine-under ...

  7. 贝叶斯深度神经网络_深度学习为何胜过贝叶斯神经网络

    贝叶斯深度神经网络 Recently I came across an interesting Paper named, "Deep Ensembles: A Loss Landscape ...

  8. ann人工神经网络_深度学习-人工神经网络(ANN)

    ann人工神经网络 Building your first neural network in less than 30 lines of code. 用不到30行代码构建您的第一个神经网络. 1.W ...

  9. 乐器演奏_深度强化学习代理演奏的蛇

    乐器演奏 Ever since I watched the Netflix documentary AlphaGo, I have been fascinated by Reinforcement L ...

  10. 深度学习背后的数学_深度学习背后的简单数学

    深度学习背后的数学 Deep learning is one of the most important pillars in machine learning models. It is based ...

最新文章

  1. 【文本分类】混合CHI和MI的改进文本特征选择方法
  2. 标准博客 API .BLOG APIS
  3. 【深度学习】深入浅出卷积神经网络及实现!
  4. weblogic11g集群配置
  5. Leet Code OJ 203. Remove Linked List Elements [Difficulty: Easy]
  6. cortex m0启动代码详解
  7. .NET Core 使用 grpc 实现微服务
  8. Eclipse闪退解决方案
  9. php报错集合,centos7安装php5.6报错集合
  10. 信息化为五万教学点带来“优质教师”
  11. ABAQUS 转子动力学载荷
  12. idea 中 maven Process terminated
  13. java繁体转简体包,java繁体转简体
  14. IDEA 设置自动启动的浏览器
  15. 小米MIUI12.5手机降级教程,线刷
  16. esp32与0.96寸屏幕实现信息传输
  17. 电脑使用技巧(Win10修改窗口背景颜色)
  18. 成都计算机考研学校排名,成都十大考研学校排名
  19. 打印零与奇偶数 思路分析
  20. ncr管理系统_完全拆解小米智能电动车【图解】

热门文章

  1. 超六类网线与7类网线的区别,你知道吗?
  2. 蛋白胶条质谱鉴定实验
  3. 道格拉斯简化_简化组织变革:困惑的指南
  4. HAL库学习笔记-10 HAL库外设驱动框架概述
  5. python对图片进行裁剪_python利用四个坐标点对图片进行裁剪
  6. 腾讯QQ2019最新版 v9.1.0(24712) 免安装绿色版 显IP去广告完整版
  7. 2018DeeCamp笔试题目第二套B卷
  8. 物联网、大数据、云计算、人工智能之间的关系
  9. 桌面右下角出现“测试模式 Windows7 内部版本7601”怎么回事?
  10. 27岁学前端开发,3年前端开发工资待遇