Pytorch:model.train()和model.eval()用法和区别,以及model.eval()和torch.no_grad()的区别
model.train()和model.eval()的区别主要在于Batch Normalization和Dropout两层。
model.train()
官方文档
启用 Batch Normalization 和 Dropout。
如果模型中有BN层(Batch Normalization)和 Dropout,需要在训练时添加model.train()
。model.train()
是保证BN层能够用到每一批数据的均值和方差。对于Dropout,model.train()
是随机取一部分网络连接来训练更新参数。
model.eval()
官方文档
不启用 Batch Normalization 和 Dropout。
如果模型中有BN层(Batch Normalization)和Dropout,在测试时添加model.eval()
。model.eval
()是保证BN层能够用全部训练数据的均值和方差,即测试过程中要保证BN层的均值和方差不变。对于Dropout,model.eval()是利用到了所有网络连接,即不进行随机舍弃神经元。
训练完train样本后,生成的模型model要用来测试样本。在model(test)之前,需要加上model.eval()
,否则的话,有输入数据,即使不训练,它也会改变权值。这是model中含有BN层和Dropout所带来的的性质。
在做one classification的时候,训练集和测试集的样本分布是不一样的,尤其需要注意这一点。
model.train()和model.eval() 源码解析
model.train()
和model.eval()
对应的源代码,如下所示,但是仅仅关注这一部分是不够的,现在需要记住当前的self.training的值是True还是False。
def train(self, mode=True):r"""Sets the module in training mode.This has any effect only on certain modules. See documentations ofparticular modules for details of their behaviors in training/evaluationmode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,etc.Returns:Module: self"""self.training = modefor module in self.children():module.train(mode)return selfdef eval(self):return self.train(False)
下边以Dropout为例,进入其对应的源代码,下方对应的self.training就是第一步中的self.training,原因在于Dropout继承了 _DropoutNd
类,而 _DropoutNd
由继承了Module
类,Module
类中自带变量self.training
,通过这种方法,来控制train/eval
模型下是否进行Dropout
。
class Dropout(_DropoutNd):'''balabala'''@weak_script_methoddef forward(self, input):return F.dropout(input, self.p, self.training, self.inplace)
分析原因
使用PyTorch进行训练和测试时一定注意要把实例化的model指定train/eval
,eval()
时,框架会自动把BN和Dropout固定住,不会取平均,而是用训练好的值,不然的话,一旦test的batch_size过小,很容易就会被BN层导致生成图片颜色失真极大!!!!!!
# 定义一个网络
class Net(nn.Module):def __init__(self, l1=120, l2=84):super(Net, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16 * 5 * 5, l1)self.fc2 = nn.Linear(l1, l2)self.fc3 = nn.Linear(l2, 10)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = x.view(-1, 16 * 5 * 5)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return x# 实例化这个网络Model = Net()# 训练模式使用.train()Model.train(mode=True)# 测试模型使用.eval()Model.eval()
为什么PyTorch会关注我们是训练还是评估模型?最大的原因是dropout和BN层(以dropout为例)。这项技术在训练中随机去除神经元。
想象一下,如果右边被删除的神经元(叉号)是唯一促成正确结果的神经元。一旦我们移除了被删除的神经元,它就迫使其他神经元训练和学习如何在没有被删除神经元的情况下保持准确。这种dropout提高了最终测试的性能,但它对训练期间的性能产生了负面影响,因为网络是不全的。
下面我们看一个我们写代码的时候常遇见的错误写法:
在这个特定的例子中,似乎每50次迭代就会降低准确度。
如果我们检查一下代码, 我们看到确实在train函数中设置了训练模式。
def train(model, optimizer, epoch, train_loader, validation_loader):model.train() # ???????????? 错误的位置for batch_idx, (data, target) in experiment.batch_loop(iterable=train_loader):# model.train() # 正确的位置,保证每一个batch都能进入model.train()的模式data, target = Variable(data), Variable(target)# Inferenceoutput = model(data)loss_t = F.nll_loss(output, target)# The iconic grad-back-step triooptimizer.zero_grad()loss_t.backward()optimizer.step()if batch_idx % args.log_interval == 0:train_loss = loss_t.item()train_accuracy = get_correct_count(output, target) * 100.0 / len(target)experiment.add_metric(LOSS_METRIC, train_loss)experiment.add_metric(ACC_METRIC, train_accuracy)print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx, len(train_loader),100. * batch_idx / len(train_loader), train_loss))with experiment.validation():val_loss, val_accuracy = test(model, validation_loader) # ????????????experiment.add_metric(LOSS_METRIC, val_loss)experiment.add_metric(ACC_METRIC, val_accuracy)
这个问题不太容易注意到,在循环中我们调用了test函数。
def test(model, test_loader):model.eval()# ...
在test函数内部,我们将模式设置为eval。这意味着,如果我们在训练过程中调用了test函数,我们就会进eval模式,直到下一次train函数被调用。这就导致了每一个epoch中只有一个batch使用了dropout ,这就导致了我们看到的性能下降。
修复很简单,我们将model.train()
向下移动一行,让其在训练循环中。理想的模式设置是尽可能接近推理步骤,以避免忘记设置它。修正后,我们的训练过程看起来更合理,没有中间的峰值出现。
补充:model.eval()和torch.no_grad()的区别
在PyTorch中进行validation/test时,会使用model.eval()
切换到测试模式,在该模式下:
1.主要用于通知dropout层和BN层在training和validation/test模式间切换:
- 在train模式下,dropout网络层会按照设定的参数p,设置保留激活单元的概率(保留概率=p)。BN层会继续计算数据的mean和var等参数并更新。
- 在eval模式下,dropout层会让所有的激活单元都通过,而BN层会停止计算和更新mean和var,直接使用在训练阶段已经学出的mean和var值。
2.eval模式不会影响各层的gradient计算行为,即gradient计算和存储与training模式一样,只是不进行反向传播(back probagation)。
而with torch.no_grad()
则主要是用于停止autograd模块的工作,以起到加速和节省显存的作用。它的作用是将该with语句包裹起来的部分停止梯度的更新,从而节省了GPU算力和显存,但是并不会影响dropout和BN层的行为。
如果不在意显存大小和计算时间的话,仅仅使用model.eval()
已足够得到正确的validation/test的结果;而with torch.no_grad()
则是更进一步加速和节省gpu空间(因为不用计算和存储梯度),从而可以更快计算,也可以跑更大的batch来测试。
参考文献
pytorch中model.train()和model.eval()的区别
Pytorch:model.train()和model.eval()用法和区别,以及model.eval()和torch.no_grad()的区别相关推荐
- 【Pytorch】model.train()和model.eval()用法和区别,以及model.eval()和torch.no_grad()的区别
model.train() 启用 Batch Normalization 和 Dropout 如果模型中有BN层(Batch Normalization)和Dropout,需要在训练时添加model. ...
- 详解Pytorch中的requires_grad、叶子节点与非叶子节点、with torch.no_grad()、model.eval()、model.train()、BatchNorm层
requires_grad requires_grad意为是否需要计算梯度 使用backward()函数反向传播计算梯度时,并不是计算所有tensor的梯度,只有满足下面条件的tensor的梯度才会被 ...
- (深入理解)model.eval() 、model.train()以及torch.no_grad() 的区别
文章目录 简要版解释 深入版解释 简要版解释 在PyTorch中进行validation或者test的时侯,会使model.eval()切换到测试模式,在该模式下,model.training=Fas ...
- 【Pytorch】model.train() 和 model.eval() 原理与用法
文章目录 一.两种模式 二.功能 1. model.train() 2. model.eval() 为什么测试时要用 model.eval() ? 3. 总结与对比 三.Dropout 简介 参考链接 ...
- Pytorch model.train()
文章目录 1.前言 2.作用及原因 2.1.Batch Normalization 2.1.1训练时的BN层 2.1.2测试时的BN层 2.2.Dropout 3.总结 1.前言 在使用Pytorch ...
- 【pytorch】model.train和model.eval用法及区别详解
使用PyTorch进行训练和测试时一定注意要把实例化的model指定train/eval,eval()时,框架会自动把BN和DropOut固定住,不会取平均,而是用训练好的值,不然的话,一旦test的 ...
- Pytorch: model.eval(), model.train() 讲解
文章目录 1. model.eval() 2. model.train() 两者只在一定的情况下有区别:训练的模型中含有dropout 和 batch normalization 1. model.e ...
- Pytorch的model.train() model.eval() torch.no_grad() 为什么测试的时候不调用loss.backward()计算梯度还要关闭梯度
使用PyTorch进行训练和测试时一定注意要把实例化的model指定train/eval model.train() 启用 BatchNormalization 和 Dropout 告诉我们的网络,这 ...
- model.train()、model.eval()、optimizer.zero_grad()、loss.backward()、optimizer.step作用及原理详解【Pytorch入门手册】
1. model.train() model.train()的作用是启用 Batch Normalization 和 Dropout. 如果模型中有BN层(Batch Normalization)和D ...
最新文章
- Linux下无法进入windows的NTFS分区并挂载错误的问题的解决方法
- 将Excel的数据导入DataGridView中(转)
- 从构建分布式秒杀系统聊聊验证码
- asp.net远程调用WebService的两种方法
- Spark数据倾斜的完美解决
- SpringCloud企业实战专栏
- 1.阿里云RDS配置白名单,实例,外网地址,mysql数据库。
- 把view放在地图覆盖物上
- 用QBE语言实现关系代数
- 桌面版IDE将迎终结,Github发布代码空间Codespaces | 凌云时刻
- 三维姿态捕捉_三维人脸识别的方法有哪几种
- 您有新的订单提示音在线试听_iPhone修改微信提示音,支持全部机型,无需越狱...
- java中订单流水号_订单流水号的生成
- 美团工作10个月心得
- VBA实现EXCEL随机本地随机刷题
- SQL- join多表关联
- php不使用框架,导出Excel,这里有代码,全解
- 服务器名称指示(SNI)是什么东东?
- RabbitMQ fanout广播消息使用匿名队列
- Intelligent Parking Building
热门文章
- Vim工具打开、编辑、保存文件
- 同步异步阻塞非阻塞详解
- 虎牙直播数据采集,为数据分析做储备,Python爬虫120例之第24例
- genglinglong-java-day01
- rails 的 Helpers
- 降低 CPU 占用率的方法
- 放俩算法上来吧 (difficulty:easy)
- 网页端实现一键打印功能,H5,热敏打印机,普通打印机
- mysql limit 5 5 正确_关于SELECT * FROM tb_book LIMIT 5,10描述正确的是( )。 (5.0分)_学小易找答案...
- gma3600显卡linux,Intel GMA 3600 (简体中文)