前言:最近在把两个模型的代码整合到一起,发现有一个模型的代码整合后性能大不如前,但基本上是源码迁移,找了一天原因才发现是因为model.eval()和model.train()放错了位置!!!故在此介绍一下pytroch框架下model.train()、model.eval()的作用和不同点。

一、model.train、model.eval

1.model.train和model.eval放在代码什么位置

简单的说:model.train放在网络训练前,model.eval放在网络测试前。

常见的位置摆放错误(也是我犯的错误)有把model.train()放在for epoch in range(epoch):前面,同时在test或者val(测试或者评估函数)中只放置model.eval,这就导致了只有第一个epoch模型训练是使用了model.train(),之后的epoch模型训练时都采用model.eval().可能会影响训练好模型的性能。
修改方式:可以在test函数里return前面添加model.train()或者把model.train()放到for epoch in range(epoch):语句下面。

model.train()
for epoch in range(epoch):for train_batch in train_loader:...zhibiao = test(epoch, test_loader, model)def test(epoch, test_loader, model):model.eval()for test_batch in test_loader:...return zhibiao

2.model.train和model.eval有什么作用

model.train()和model.eval()的区别主要在于Batch NormalizationDropout两层。
如果模型中有BN层(Batch Normalization)和Dropout,在测试时添加model.eval()。model.eval()是保证BN层能够用全部训练数据的均值和方差,即测试过程中要保证BN层的均值和方差不变。对于Dropout,model.eval()是利用到了所有网络连接,即不进行随机舍弃神经元。

下面是model.train 和model.eval的源码,可以看到是利用self.training = mode来判断是使用train还是eval。这个参数将传递到一些常用层,比如dropout、BN层等。

def train(self: T, mode: bool = True) -> T:r"""Sets the module in training mode.This has any effect only on certain modules. See documentations ofparticular modules for details of their behaviors in training/evaluationmode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,etc.Args:mode (bool): whether to set training mode (``True``) or evaluationmode (``False``). Default: ``True``.Returns:Module: self"""self.training = modefor module in self.children():module.train(mode)return selfdef eval(self: T) -> T:r"""Sets the module in evaluation mode.This has any effect only on certain modules. See documentations ofparticular modules for details of their behaviors in training/evaluationmode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,etc.This is equivalent with :meth:`self.train(False) <torch.nn.Module.train>`.Returns:Module: self"""return self.train(False)

拿dropout层的源码举例,可以看到传递了self.training这个参数。

class Dropout(_DropoutNd):r"""During training, randomly zeroes some of the elements of the inputtensor with probability :attr:`p` using samples from a Bernoullidistribution. Each channel will be zeroed out independently on every forwardcall.This has proven to be an effective technique for regularization andpreventing the co-adaptation of neurons as described in the paper`Improving neural networks by preventing co-adaptation of featuredetectors`_ .Furthermore, the outputs are scaled by a factor of :math:`\frac{1}{1-p}` duringtraining. This means that during evaluation the module simply computes anidentity function.Args:p: probability of an element to be zeroed. Default: 0.5inplace: If set to ``True``, will do this operation in-place. Default: ``False``Shape:- Input: :math:`(*)`. Input can be of any shape- Output: :math:`(*)`. Output is of the same shape as inputExamples::>>> m = nn.Dropout(p=0.2)>>> input = torch.randn(20, 16)>>> output = m(input).. _Improving neural networks by preventing co-adaptation of featuredetectors: https://arxiv.org/abs/1207.0580"""def forward(self, input: Tensor) -> Tensor:return F.dropout(input, self.p, self.training, self.inplace)

3.为什么主要区别在于BN层和dropout层

在BN层中,主要涉及到四个需要更新的参数,分别是running_mean,running_var,weight,bias。这里的weight,bias是Pytorch官方实现中的叫法,有点误导人,其实weight就是gamma,bias就是beta。当然它这样的叫法也符合实际的应用场景。其实gamma,beta就是对规范化后的值进行一个加权求和操作running_mean,running_var是当前所求得的所有batch_size下的均值和方差,每经过一个mini_batch我们都会更新running_mean,running_var.为什么要更新它?因为测试的时候,往往是一个一个的图像feed至网络的,如果你在这里对其进行计算均值方差显然是不合理的,所以model.eval()这个语句就是控制BN层中的running_mean,running_std不更新。采用训练结束后的running_mean,running_std来规范化该张图像。

dropout层在训练过程中会随机舍弃一些神经元用来提高性能,但测试过程中如果还是测试的模型还是和训练时一样随机舍弃了一些神经元(不是原模型)这就和测试的本意相违背。因为测试的模型应该是我们最终得到的模型,而这个模型应该是一个完整的模型。

4.BN层和dropout层的作用
既然都讲到这了,不了解一些BN层和dropout层的作用就说不过去了。
BN层的原理和作用建议读一下这篇博客:神经网络中BN层的原理与作用

dropout是指在深度学习网络的训练过程中,对于神经网络单元,按照一定的概率将其暂时从网络中丢弃。注意是暂时,对于随机梯度下降来说,由于是随机丢弃,故而每一个mini-batch都在训练不同的网络。

大规模的神经网络有两个缺点:费时、容易过拟合

Dropout的出现很好的可以解决这个问题,每次做完dropout,相当于从原始的网络中找到一个更瘦的网络。因而,对于一个有N个节点的神经网络,有了dropout后,就可以看做是2^n个模型的集合了,但此时要训练的参数数目却是不变的,这就解决了费时的问题。

将dropout比作是有性繁殖,将基因随机进行拆分,可以将优秀的基因传下来,并且降低基因之间的联合适应性,使得复杂的大段大段基因联合适应性变成比较小的一个一个小段基因的联合适应性。

dropout也能达到同样的效果,它强迫一个神经单元,和随机挑选出来的其他神经单元共同工作,达到好的效果。消除减弱了神经元节点间的联合适应性,增强了泛化能力。

参考链接

pytorch中model.train()和model.eval()的区别
BN层(Pytorch)
神经网络中BN层的原理与作用————这篇博客写的贼棒
深度学习中Dropout的作用和原理

pytroch:model.train()、model.eval()的使用相关推荐

  1. Pytorch的model.train() model.eval() torch.no_grad() 为什么测试的时候不调用loss.backward()计算梯度还要关闭梯度

    使用PyTorch进行训练和测试时一定注意要把实例化的model指定train/eval model.train() 启用 BatchNormalization 和 Dropout 告诉我们的网络,这 ...

  2. model.train()、model.eval()、optimizer.zero_grad()、loss.backward()、optimizer.step作用及原理详解【Pytorch入门手册】

    1. model.train() model.train()的作用是启用 Batch Normalization 和 Dropout. 如果模型中有BN层(Batch Normalization)和D ...

  3. (深入理解)model.eval() 、model.train()以及torch.no_grad() 的区别

    文章目录 简要版解释 深入版解释 简要版解释 在PyTorch中进行validation或者test的时侯,会使model.eval()切换到测试模式,在该模式下,model.training=Fas ...

  4. 【pytorch】model.train和model.eval用法及区别详解

    使用PyTorch进行训练和测试时一定注意要把实例化的model指定train/eval,eval()时,框架会自动把BN和DropOut固定住,不会取平均,而是用训练好的值,不然的话,一旦test的 ...

  5. Pytorch: model.eval(), model.train() 讲解

    文章目录 1. model.eval() 2. model.train() 两者只在一定的情况下有区别:训练的模型中含有dropout 和 batch normalization 1. model.e ...

  6. 【Pytorch】model.train()和model.eval()用法和区别,以及model.eval()和torch.no_grad()的区别

    model.train() 启用 Batch Normalization 和 Dropout 如果模型中有BN层(Batch Normalization)和Dropout,需要在训练时添加model. ...

  7. model.train()与model.eval()的用法、Dropout原理、relu,sigmiod,tanh激活函数、nn.Linear浅析

    转载:原文地址-传送门 1.model.train()与model.eval()的用法 看别人的面经时,浏览到一题,问的就是这个.自己刚接触pytorch时套用别人的框架,会在训练开始之前写上mode ...

  8. 【Pytorch】model.train() 和 model.eval() 原理与用法

    文章目录 一.两种模式 二.功能 1. model.train() 2. model.eval() 为什么测试时要用 model.eval() ? 3. 总结与对比 三.Dropout 简介 参考链接 ...

  9. 详解Pytorch中的requires_grad、叶子节点与非叶子节点、with torch.no_grad()、model.eval()、model.train()、BatchNorm层

    requires_grad requires_grad意为是否需要计算梯度 使用backward()函数反向传播计算梯度时,并不是计算所有tensor的梯度,只有满足下面条件的tensor的梯度才会被 ...

最新文章

  1. ios的并发队列控制库
  2. 3个写进简历的京东AINLP项目实战
  3. VTK:可视化算法之HyperStreamline
  4. VTK:图表之ConstructTree
  5. AI 云原生浅谈:好未来 AI 中台实践
  6. RC4加密解密java算法
  7. typescript 学习
  8. Vision Transformer中的自监督学习
  9. 社交媒体电影视频网HTML5模板
  10. 通过命令在navicat中创建数据库及表结构
  11. java redis设置过期时间_Redis的内存回收原理,及内存过期淘汰策略详解
  12. 分布式文件系统HDFS原理篇
  13. 使用 TimeGAN 建模和生成时间序列数据
  14. 双球坐标系_【天文】教你认识三大天球坐标系!(上)
  15. 大陆地区OpenStack项目Core现状(截至2016年1月28日,转载自陈沙克日志)
  16. 鱼眼相机图像畸变校正
  17. IDEA使用Git远程推送出现push to origin/master was rejected错误解决方案
  18. 123457123457#0#---------com.ppGame.SeaPuzzleGame73--前拼后广--宝宝海洋拼图pp
  19. 在路由器 RT-AC68U 安装迅雷远程过程
  20. 华硕i7计算机配置,华硕主板显卡i7家用电脑配置清单推荐

热门文章

  1. Java中字符串转整型和整型转字符串
  2. 如何把ADS的圆图导入到origin中
  3. 机器学习、数据科学与金融行业 系列四:智能投顾、量化投资与机器学习
  4. 浅谈文档协作在工程设计中的应用——共享excel计算书
  5. 俄罗斯前总统叶利钦逝世
  6. Xcode 中的黄色文件夹/蓝色文件夹
  7. Python信贷风控模型:Adaboost,XGBoost,SGD, SVC,随机森林, KNN预测信贷违约支付
  8. 原生Javascript实现五子棋
  9. AOI检测光学成像标准
  10. java计算机毕业设计高校共享机房管理系统的设计与实现源码+系统+lw文档+mysql数据库+部署