1. 设置断点续训的目的

在遇到停电宕机,设备内存不足导致实验还没有跑完的情况下,如果没有使用断点续训,就需要从头开始训练,耗时费力。
断点续训主要保存的是网络模型的参数以及优化器optimizer的状态(因为很多情况下optimizer的状态会改变,比如学习率的变化)

2. 设置断点续训的方法

  1. 参数设置
    resume: 是否进行续训
    initepoch: 进行续训时的初始epoch
  2. checkpoint载入过程(这部分操作放在epoch循环前边)
resume = True      # 设置是否需要从上次的状态继续训练if resume:if os.path.isfile("results/{}_model.pth".format(save_name_pre)):print("Resume from checkpoint...")checkpoint = torch.load("results/{}_model.pth".format(save_name_pre))model.load_state_dict(checkpoint['model_state_dict'])optimizer.load_state_dict(checkpoint['optimizer_state_dict'])initepoch = checkpoint['epoch'] + 1print("====>loaded checkpoint (epoch{})".format(checkpoint['epoch']))else:print("====>no checkpoint found.")initepoch = 1   # 如果没进行训练过,初始训练epoch值为1
  1. 每一轮,checkpoint的存储过程,保存模型参数,优化器参数,轮数(这部分操作放在epoch循环里边)
# 保存断点if test_acc_1 > best_acc:best_acc = test_acc_1checkpoint = {"model_state_dict": model.state_dict(),"optimizer_state_dict": optimizer.state_dict(),"epoch": epoch}path_checkpoint = "results/{}_model.pth".format(save_name_pre)torch.save(checkpoint, path_checkpoint)

pytorch学习(一)pytorch中的断点续训相关推荐

  1. 断点续训 Pytorch 和 Tensorflow 框架 VGG16 模型 猫狗大战 鸢尾花分类

    神经网络训练模型的过程中,如果程序突然中断,竹篮打水一场空? >>>断点续训来解决! 目录 (1)Pytorch框架的断点续训(猫狗大战) (2)Tensorflow框架的断点续训( ...

  2. Pytorch学习 - Task5 PyTorch卷积层原理和使用

    Pytorch学习 - Task5 PyTorch卷积层原理和使用 1. 卷积层 (1)介绍 (torch.nn下的) 1) class torch.nn.Conv1d() 一维卷积层 2) clas ...

  3. Pytorch学习 - Task6 PyTorch常见的损失函数和优化器使用

    Pytorch学习 - Task6 PyTorch常见的损失函数和优化器使用 官方参考链接 1. 损失函数 (1)BCELoss 二分类 计算公式 小例子: (2) BCEWithLogitsLoss ...

  4. PyTorch学习记录——PyTorch进阶训练技巧

    PyTorch学习记录--PyTorch进阶训练技巧 1.自定义损失函数 1.1 以函数的方式定义损失函数 1.2 以类的方式定义损失函数 1.3 比较与思考 2.动态调整学习率 2.1 官方提供的s ...

  5. tensorflow1运用模型断点续训、恢复图和进行预测

    前言 本文是代码根据吴恩达深度学习第四课程第一周第二节作业图像分类识别修改而成,会简单介绍一下项目流程,然后介绍tensorflow1保存模型的两种方法,以及如何用模型预测. 项目流程简单介绍 这里直 ...

  6. 【神经网络扩展】:断点续训和参数提取

    课程来源:人工智能实践:Tensorflow笔记2 文章目录 前言 断点续训主要步骤 参数提取主要步骤 总结 前言 本讲目标:断点续训,存取最优模型:保存可训练参数至文本 断点续训主要步骤 读取模型: ...

  7. kera TensorBoard的可视化和断点续训同时处理

    一.实现可视化的步骤 ① 从keras.callbacks中导入Tensorboard类 from keras.callbacks import TensorBoard ② 在model.fit中添加 ...

  8. PyTorch学习记录——PyTorch生态

    Pytorch的强大并不仅局限于自身的易用性,更在于开源社区围绕PyTorch所产生的一系列工具包(一般是Python package)和程序,这些优秀的工具包极大地方便了PyTorch在特定领域的使 ...

  9. 速成pytorch学习——4天中阶API示范

    使用Pytorch的中阶API实现线性回归模型和和DNN二分类模型. Pytorch的中阶API主要包括各种模型层,损失函数,优化器,数据管道等等. 一,线性回归模型 1,准备数据 import nu ...

最新文章

  1. python输出到语音播放_用Python写一个语音播放软件
  2. Qt最新版5.12在Windows环境静态编译安装和部署的完整过程(VS2017)
  3. python时钟程序的设计总结_Python实现时钟显示效果思路详解
  4. 通过代码动态创建IIS站点
  5. mysql频繁查询出错_Mysql数据库频繁查询错误解决方案
  6. 请求的链式处理——职责链模式
  7. [SCOI2015]小凸玩矩阵 (匈牙利+二分)
  8. Uniapp组件之间传参
  9. Tensorflow练习题
  10. 03python面向对象编程5
  11. 从源码解析LinkedList集合
  12. axure rp9是什么软件?如何在Mac中安装使用?
  13. OLAP -- ODS 项目总结 -- BI 中的关键
  14. 免费下载文档:给你介绍几个实用的免费下载网址
  15. 金融量化之华泰多因子估值类显著性和IC值计算
  16. spine 导出纹理_Spine 的纹理打包器(texture packer)详解
  17. 苹果账号调查事件始末,Apple审核流程或有变
  18. AMD显卡更新UEFI GOP
  19. <C++>初识类的继承,用三行情诗打开继承的大门
  20. 微信小程序订阅消息wx.requestSubscribeMessage使用要点和requestSubscribeMessage:can only be invoked by userTAPgestur

热门文章

  1. #小写金额转大写金额
  2. [附源码]计算机毕业设计JAVA药品销售管理系统
  3. 震网三代cve_2017_8464漏洞复现
  4. 风速仪原理是什么你知道么?
  5. 总结十三:外在认知比实际行动更重要
  6. 【前端多宫格卡片自适应,再也不怕多宫格布局啦】
  7. cimcoeditv5怎样模拟刀路_CimcoEdit5使用教程 Cimco Edit5怎么启动使用
  8. windows删除文件夹提示找不到该项目
  9. ADUM1201在隔离RS232中的应用
  10. “黑天鹅”与“灰犀牛”不能混为一谈,揭开数据保护的“隐秘角落”