权重初始化

1.什么是权重初始化

权重初始化(weight initialization)又称参数初始化,在深度学习模型训练过程的本质是对weight(即参数 W)进行更新,但是在最开始训练的时候是无法更新的,这需要每个参数有相应的初始值。在进行权重初始化后,神经网络就可以对权重参数w不停地迭代更新,以达到较好的性能。

2.为什么需要好的权重初始化

网络训练的过程中, 容易出现梯度消失(梯度特别的接近0)和梯度爆炸(梯度特别的大)的情况,导致大部分反向传播得到的梯度不起作用或者起反作用. 研究人员希望能够有一种好的权重初始化方法: 让网络前向传播或者反向传播的时候, 卷积的输出和前传的梯度比较稳定. 合理的方差既保证了数值一定的不同, 又保证了数值一定的稳定.(通过卷积权重的合理初始化, 让计算过程中的数值分布稳定)

3.权重初始化方法

预训练权重

使用预训练权重

在计算机视觉领域中,迁移学习通常是通过使用预训练模型来表示的。预训练模型是在大型基准数据集上训练的模型,用于解决相似的问题。由于训练这种模型的计算成本较高,因此,导入已发布的成果并使用相应的模型是比较常见的做法。例如,在目标检测任务中,首先要利用主干神经网络进行特征提取,这里使用的backbone一般就是VGG、ResNet等神经网络,因此在训练一个目标检测模型时,可以使用这些神经网络的预训练权重来将backbone的参数初始化,这样在一开始就能提取到比较有效的特征。

可能大家会有疑问,预训练权重是针对他们数据集训练得到的,如果是训练自己的数据集还能用吗?预训练权重对于不同的数据集是通用的,因为特征是通用的。一般来讲,从0开始训练效果会很差,因为权值太过随机,特征提取效果不明显。对于目标检测模型来说,一般不从0开始训练,至少会使用主干部分的权值,虽然有些论文提到了可以不用预训练,但这主要是因为他们的数据集比较大而且他们的调参能力很强。如果从0开始训练,网络在前几个epoch的Loss可能会非常大,并且多次训练得到的训练结果可能相差很大,因为权重初始化太过随机。

冻结训练

冻结训练其实也是迁移学习的思想,在目标检测任务中用得十分广泛。因为目标检测模型里,主干特征提取部分所提取到的特征是通用的,把backbone冻结起来训练可以加快训练效率,也可以防止权值被破坏。在冻结阶段,模型的主干被冻结了,特征提取网络不发生改变,占用的显存较小,仅对网络进行微调。在解冻阶段,模型的主干不被冻结了,特征提取网络会发生改变,占用的显存较大,网络所有的参数都会发生改变。举个例子,如果在解冻阶段设置batch_size为4,那么在冻结阶段有可能可以把batch_size设置到8。

断点恢复

在上面冻结训练和解冻训练的代码里设置了不同的batch_size,前者是8后者是4,有可能冻结训练的时候显存是够用的,结果解冻后显存不足了,这个时候需要重新把解冻训练阶段的batch_size调得更小一点。但是网络才训练了冻结阶段的50个epoch,backbone参数还是用的预训练权重呢,网络效果肯定不够好。难道要前功尽弃重新开始训练?这时候就要使用断点恢复技术了。其实断点恢复的思想很简单,就是把网络初始设置的model_path改为出错前保存好的权值文件,然后调整一下起始epoch和终止epoch即可,比如在前面提到的这种情况里,在第51个epoch报了错,那么可以把model_path修改为第50个epoch训练结束后保存的权值文件,然后把起始epoch调整成50就可以了。

断点恢复的应用范围非常非常广。最常见的情况就是代码跑到一半因为某些原因中断了(比如电脑突然死机重启这种不可抗力因素),又不想从头重新跑,那么就可以利用断点恢复训练的方法,这样可以节省不少时间。再比如,一个非常常见的情况,假如一开始设置了100个epoch,结果模型训练结束时,Loss还呈现下降的趋势,也就是模型还没有收敛,这种现象有可能就是epoch设置小了,所以可以把第100个epoch训练得到的权值文件当做初始权值文件再训练几个epoch看看,避免重新设置epoch从头训练。

当然,想要执行断点恢复首先需要把每个epoch得到的权值文件保存起来,这样才能修改model_path重新加载。断点恢复和常规的模型保存加载的区别其实就是epoch也要修改一下而已。保存权重可以用以下方法:

torch.save(model.state_dict(), "你要保存到的路径")

预训练和微调

假如我们现在要搭建一个网络模型来完成一个图像分类的任务,首先我们需要把网络的参数进行初始化,然后在训练网络的过程中不断对参数进行调整,直到网络的损失越来越小。在训练过程中,一开始初始化的参数会不断变化,如果结果已经满意了,那我们就可以把训练好的模型参数保存下来,以便训练好的模型可以在下次执行类似任务的时候获得比较好的效果。这个过程就是预训练(Pre-Training)。

假如在完成上面的模型训练后,我们又接到另一个类似的图像分类任务,这时我们就可以直接使用之前保存下来的模型参数作为这一次任务的初始化参数,然后在训练过程中依据结果不断进行修改,这个过程就是微调(Fine-Tuning)。

我们使用的神经网络越深,就需要越多的样本来进行训练,否则就很容易出现过拟合现象。比如我们想训练一个识别猫的模型,但是自己标注数据精力有限只标了100张,这时就可以考虑ImageNet数据集,可以在ImageNet上训练一个模型,然后使用该模型作为类似任务的初始化或特征提取器,这样既节省了时间和计算资源,又能很快地达到较好的效果。当然,采用预训练+微调也不是绝对有效的,上面识别猫的例子可以这样做是因为ImageNet里有猫的图像,所以可以认为是一个类似的数据集,如果是识别癌细胞的话,效果可能就不是那么好了。关于预训练和微调是有很多策略的,经验也很重要。

权重初始化与预训练权重相关推荐

  1. pytorch载入部分预训练权重

    文章目录 前言 方法一 方法二 前言 使用迁移学习的方法训练网络往往需要载入部分已训练好的网络权重,接下来介绍两种载入预训练权重的方法,第一种比较简单,第二种方法稍微复杂但是更加灵活. 方法一 先按原 ...

  2. 迁移学习、载入预训练权重和冻结权重

    迁移学习就是载入别人预训练好的权重,拿别人的训练好的参数作为我们自己模型的初始化参数,再在这个基础上继续优化.比起从头开始一点一点随机初始化,让模型胡乱地找梯度最优的方向,肯定是迁移学习快啦. 目录 ...

  3. PyTorch载入预训练权重方法和冻结权重方法

    载入预训练权重 1. 直接载入预训练权重 简单粗暴法: pretrain_weights_path = "./resnet50.pth" net.load_state_dict(t ...

  4. Pytorch迁移学习加载部分预训练权重

    迁移学习在图像分类领域非常常见,利用在超大数据集上训练得到的网络权重,迁移到自己的数据上进行训练可以节约大量的训练时间,降低欠拟合/过拟合的风险. 如果用原生网络进行迁移学习非常简单,其核心是 mod ...

  5. PyTorch 加载预训练权重

    前言  使用PyTorch官方提供的权重或者其他第三方提供的权重对相同模型的参数进行初始化,在数据量较少的前提下,可以帮助模型更快地收敛到最优点,达到更好的效果,即迁移学习.  在大部分的迁移学习场景 ...

  6. 预训练权重到底是个啥

    预训练权重,顾名思义,就是预先训练好的权重,这类权重是在大型数据集上进行训练的,训练出来的权重是普遍通用的,因此不必担心是否符合自己的实际情况,我们个人往往很难训练出预训练权重的效果.并且如果不使用预 ...

  7. 深度学习加载预训练权重好处

    深度学习加载预训练权重好处: 在模型开始训练前,使模型参数得到一个好的初始化,对于后面的训练学习有非常大的帮助.

  8. c++ opencv2 libtorch 读取预训练权重并进行预测 linux

    c++ opencv2 libtorch 读取预训练权重并进行预测 原文:https://oldpan.me/archives/pytorch-c-libtorch-inference 本篇使用的平台 ...

  9. torch编程-加载预训练权重-模型冻结-解耦-梯度不反传

    1)加载预训练权重 net = torchvision.models.resnet50(pretrained=False) # 构建模型 pretrained_model = torch.load(p ...

最新文章

  1. mybatis使用foreach进行批量保存
  2. franze kafka 游戏_The Franz Kafka Videogame
  3. 【深度学习】深度神经网络框架的探索(从Regression说起)
  4. DCMTK:类DcmSequence和DcmPixelSequence的测试程序
  5. 怎么设置表头字体大小_Excel斜线表头和三线表头是如何制作的?
  6. STM32 - 定时器的设定 - 基础 01 - Timer Base - Prescaler description - Upcounting mode
  7. python学习手册-Python 重点知识整理(基于Python学习手册第四版)
  8. sp_help 查看表结构 alter column修改字段长度
  9. python实现arxiv论文数据解析处理
  10. DS18B20使用说明
  11. SHFileOperation
  12. linux 关机 日志记录,linux查看开关机记录
  13. foobar2000播放的一些使用技巧
  14. 计算机系统最重要的是什么,操作系统最重要的两个作用是什么
  15. WinForm分页控件
  16. unbuntu20.0.4 显卡驱动安装,nividia-smi无效
  17. 关于瑞星杀毒软件对浏览器速度的影响
  18. QT之远程控制对方电脑
  19. mysql 对 GENERATED 字段更新时候报错
  20. Synctoy2.1使用定时任务0X1

热门文章

  1. Python爱好者,Python语言特点总结,拿走不谢!
  2. batchsize实验
  3. 常用的基于内容的推荐算法实现原理
  4. C#常用的几个ORM框架及简单对比
  5. CSS总结div中的内容垂直居中的五种方法
  6. 安卓平板电脑使用Termux编程环境配置
  7. Maxtang 大唐J6412四网口迷你主机安装NAS黑群晖教程
  8. Portapack应用开发教程(十) 猎狐功能和RSSI数值显示
  9. 如何用firefox开发者工具查看元素
  10. 教你手写DMA传输数据(看完这篇你就会手动写啦,保姆级讲解)---- 2020.3.31