作为深度学习训练数据的trick,结合交叉验证法,可以防止模型过早拟合。

早停法是一种被广泛使用的方法,在很多案例上都比正则化的方法要好。是在训练中计算模型在验证集上的表现,当模型在验证集上的表现开始下降的时候,停止训练,这样就能避免继续训练导致过拟合的问题。其主要步骤如下:
1. 将原始的训练数据集划分成训练集和验证集
2. 只在训练集上进行训练,并每隔一个周期计算模型在验证集上的误差
3. 当模型在验证集上(权重的更新低于某个阈值;预测的错误率低于某个阈值;达到一定的迭代次数),则停止训练
4. 使用上一次迭代结果中的参数作为模型的最终参数

如下图之后的某个epoch,模型的验证误差逐渐上升,模型出现过拟合,所以需要提前停止训练,早停法主要是训练时间和泛化错误之间的权衡。不同的停止标准也是给我们带来不同的效果。

pytorch实现早停法

#Train the Model using Early Stopping
def train_model(model, batch_size, patience, n_epochs):# to track the training loss as the model trainstrain_losses = []# to track the validation loss as the model trainsvalid_losses = []# to track the average training loss per epoch as the model trainsavg_train_losses = []# to track the average validation loss per epoch as the model trainsavg_valid_losses = [] # initialize the early_stopping objectearly_stopping = EarlyStopping(patience=patience, verbose=True)for epoch in range(1, n_epochs + 1):#################### train the model ####################model.train() # prep model for trainingfor batch, (data, target) in enumerate(train_loader, 1):# clear the gradients of all optimized variablesoptimizer.zero_grad()# forward pass: compute predicted outputs by passing inputs to the modeloutput = model(data)# calculate the lossloss = criterion(output, target)# backward pass: compute gradient of the loss with respect to model parametersloss.backward()# perform a single optimization step (parameter update)optimizer.step()# record training losstrain_losses.append(loss.item())######################    # validate the model #######################model.eval() # prep model for evaluationfor data, target in valid_loader:# forward pass: compute predicted outputs by passing inputs to the modeloutput = model(data)# calculate the lossloss = criterion(output, target)# record validation lossvalid_losses.append(loss.item())# print training/validation statistics # calculate average loss over an epochtrain_loss = np.average(train_losses)valid_loss = np.average(valid_losses)avg_train_losses.append(train_loss)avg_valid_losses.append(valid_loss)epoch_len = len(str(n_epochs))print_msg = (f'[{epoch:>{epoch_len}}/{n_epochs:>{epoch_len}}] ' +f'train_loss: {train_loss:.5f} ' +f'valid_loss: {valid_loss:.5f}')print(print_msg)# clear lists to track next epochtrain_losses = []valid_losses = []# early_stopping needs the validation loss to check if it has decresed, # and if it has, it will make a checkpoint of the current modelearly_stopping(valid_loss, model)if early_stopping.early_stop:print("Early stopping")break# load the last checkpoint with the best modelmodel.load_state_dict(torch.load('checkpoint.pt'))return  model, avg_train_losses, avg_valid_losses

pytorch早停法相关推荐

  1. 【pytorch EarlyStopping】深度学习之早停法入门·相信我,一篇就够。

    这个方法更好的解决了模型过拟合问题. EarlyStopping的原理是提前结束训练轮次来达到"早停"的目的,故训练轮次需要设置的大一点以求更好的早停(比如可以设置100epoch ...

  2. R语言构建xgboost模型使用早停法训练模型(early stopping):自定义损失函数(目标函数,loss function)、评估函数(evaluation function)

    R语言构建xgboost模型使用早停法训练模型(early stopping):自定义损失函数(目标函数.loss function.object function).评估函数(evaluation ...

  3. keras构建前馈神经网络(feedforward neural network)进行分类模型构建基于早停法(Early stopping)

    keras构建前馈神经网络(feedforward neural network)进行分类模型构建基于早停法(Early stopping) 当我们训练深度学习神经网络的时候通常希望能获得最好的泛化性 ...

  4. 深度学习技巧之Early Stopping(早停法)

    深度学习技巧之Early Stopping(早停法) | 数据学习者官方网站(Datalearner) 当我们训练深度学习神经网络的时候通常希望能获得最好的泛化性能(generalization pe ...

  5. 深度学习——早停法(Early Stopping)

    学习链接:https://www.jianshu.com/p/9ab695d91459 https://www.datalearner.com/blog/1051537860479157 目的: 为了 ...

  6. Early Stopping 早停法原理与实现

    Early Stopping 训练深度学习神经网络的时候通常希望能获得最好的泛化性能,可以更好地拟合数据.但是所有的标准深度学习神经网络结构如全连接多层感知机都很容易过拟合. 当模型在训练集上表现很好 ...

  7. Early Stopping早停法

    参考: https://www.jianshu.com/p/9ab695d91459

  8. Earlystopping(早停法)

    Earlystopping 简介 当我们训练深度学习神经网络的时候通常希望能获得最好的泛化性能(generalization performance,即可以很好地拟合数据). 但是所有的标准深度学习神 ...

  9. pytorch使用早停策略

    文章目录 早停的目的与流程 早停策略 pytorch使用示例 参考网站 早停的目的与流程 目的:防止模型过拟合,由于深度学习模型可以无限迭代下去,因此希望在即将过拟合时.或训练效果微乎其微时停止训练. ...

最新文章

  1. python3.6 messagebox_Python Tkinter GUI编程入门介绍
  2. RESTful编程究竟是什么?
  3. PDF搜索、转换与处理类网站
  4. JAVA设计模式--单例模式
  5. 趣谈设计模式 | 桥接模式(Bridge):将抽象与实现分离
  6. 微信小程序之 SideBar(侧栏分类)
  7. PHP开发框架[国内框架]
  8. (王道408考研数据结构)第五章树-第三节1:二叉树遍历(先序、中序和后序)
  9. BGP——邻居状态机+报文分析(总结)
  10. python基础学习(四)if判断语句
  11. 软件也要歧视大龄程序员吗?
  12. Redis实现计数器---接口防刷
  13. MATLAB带通滤波器开始端和结尾端数据异常(解决的小技巧)
  14. Android应用程序访问linux驱动第二步:实现并测试hardware层
  15. 关闭windows自动更新N种方法
  16. 初探信息科学中“三个世界”模型
  17. 【听课笔记】复旦大学遗传学_07基因表达调控
  18. web开发网页嵌入flash
  19. python数据收集整理教案_《数据收集整理》教学设计
  20. 沃丰科技:AI赋能泛CRM,为新企服扬风鼓帆

热门文章

  1. CASE_05 基于FPGA的DDS信号发生器
  2. 电力电子技术笔记(6)——电力电子器件的驱动
  3. TeraTerm与TTL(Tera Term Language)
  4. android 音量按键,Android 音量键的监听
  5. 内蒙古计算机职业高中高考分数线,内蒙古高考本科三批和高职高专录取分数线公布...
  6. Java300基础超适合零基础童鞋学习
  7. SAP中如何查看年度结转已完成
  8. conan入门(五):conan 交叉编译引用第三方库示例
  9. 扔旧被子扔掉霉运_您应该扔掉所有高科技产品盒吗?
  10. Glide加载圆角图片不显示问题