两种方式

保存模型主要分为两类:
1、保存整个模型
2、保存模型参数

1、第一种

结构模型+模型参数

保存整个网络模型,加载整个网络模型(可能比较耗时)

# 保存方式1
torch.save(vgg16, "vgg16_model1.pth")
# 对应保存的方式1
model = torch.load("vgg16_model1.pth")
print(model)

2、第二种

只保存加载模型参数(推荐)

保存模型的权重参数(速度快,占内存少)

# 保存方式2
torch.save(vgg16.state_dict(),"vgg16_model2.pth")
# 对应保存的方式2
vgg16 = torchvision.models.vgg16(pretrained=False)
model2 = torch.load("vgg16_model2.pth")
vgg16.load_state_dict(model2)

假设网络为:
model = Net(), optimizer = optim.Adam(model.parameters(), lr=args.lr),
假设在某个epoch,要保存模型参数,优化器参数以及epoch
先建立一个字典,保存三个参数:

state = { ‘net’: model.state_dict(), ‘optimizer’: optimizer.state_dict(), ‘epoch’: epoch}

torch.save(state, "./project/mymodel.pth")

当想恢复某一阶段的训练时,那就可以读取之前保存的网络模型参数等。

mymodel= torch.load("./project/mymodel.pth")
model.load_state_dict(mymodel['net'])  #  加载之前的网络模型参数
optimizer.load_state_dict(mymodel['optimizer'])  # 加载之前的优化器的参数
start_epoch = mymodel['epoch'] + 1  #  加载新的训练回合数

深度学习——09模型的保存:torch.save()、加载:torch.load()相关推荐

  1. Tensorflow【实战Google深度学习框架】TensorFlow模型的保存与恢复加载

    我们使用TensorFlow进行模型的训练,训练好的模型需要保存,预测阶段我们需要将模型进行加载还原使用,这就涉及TensorFlow模型的保存与恢复加载. 总结一下Tensorflow常用的模型保存 ...

  2. 深度学习修炼(二)——数据集的加载

    文章目录 致谢 2 数据集的加载 2.1 框架数据集的加载 2.2 自定义数据集 2.3 准备数据以进行数据加载器训练 致谢 Pytorch自带数据集介绍_godblesstao的博客-CSDN博客_ ...

  3. Pytorch:模型的保存与加载 torch.save()、torch.load()、torch.nn.Module.load_state_dict()

    Pytorch 保存和加载模型后缀:.pt 和.pth 1 torch.save() [source] 保存一个序列化(serialized)的目标到磁盘.函数使用了Python的pickle程序用于 ...

  4. Tensorflow深度学习实战之(五)--保存与恢复模型

    文章目录 一.保存模型 二.恢复模型 三.使用模型预测 一.保存模型 在训练完Tensorflow模型为了方便对新的数据进行预测需要保存该模型,Tensorflow提供 tf.train.Saver( ...

  5. 深度学习实战——模型推理优化(模型压缩与加速)

    忆如完整项目/代码详见github:https://github.com/yiru1225(转载标明出处 勿白嫖 star for projects thanks) 目录 系列文章目录 一.实验思路综 ...

  6. 【深度学习中模型评价指标汇总(混淆矩阵、recall、precision、F1、AUC面积、ROC曲线、ErrorRate)】

    深度学习中模型好坏的所有评价指标汇总(混淆矩阵.recall.precision.F1score.AUC面积.ROC曲线.ErrorRate) 导航 0.混淆矩阵 1.AUC面积 2.ROC曲线 3. ...

  7. 浅谈深度学习:如何计算模型以及中间变量的显存占用大小

    原文链接:https://oldpan.me/archives/how-to-calculate-gpu-memory 前言 亲,显存炸了,你的显卡快冒烟了! torch.FatalError: cu ...

  8. NVIDIA GPUs上深度学习推荐模型的优化

    NVIDIA GPUs上深度学习推荐模型的优化 Optimizing the Deep Learning Recommendation Model on NVIDIA GPUs 推荐系统帮助人在成倍增 ...

  9. 【深度学习】模型训练过程可视化思路(可视化工具TensorBoard)

    [深度学习]模型训练过程可视化思路(可视化工具TensorBoard) 文章目录 1 TensorBoard的工作原理 2 TensorFlow中生成log文件 3 启动TensorBoard,读取l ...

最新文章

  1. python病毒扫描器_基于Python的病毒扫描机制
  2. Python第三周 学习笔记(2)
  3. c语言程序窗口后台持续监测,用c语言实现后台运行的、每隔30s检查一次的、带有日志功能的断网重新连接程序...
  4. 线程基础知识系列(三)线程的同步
  5. 前端学习(3249):react的文件src
  6. xcode8注释快捷键失效问题
  7. html怎么给没张图片添加单击事件,如何在Canvas上的图形/图像绑定事件监听的实现...
  8. 第 9 章 代码审查制度
  9. linux卸载splunk,linux安装splunk-enterprise
  10. python modbus tk 库_python modbus_tk模块学习笔记(rtu slaver例程)
  11. 极域电子教室常见问题的解决方法
  12. 短信验证码 超时 java_短信验证码被刷怎么办?java 短信验证码防刷策略
  13. Nginx代理百度地图,实现内网访问百度地图
  14. 网络性能应用检测系统
  15. 设备故障率高的四大原因及对策分析
  16. 一些http和tomcat知识补充
  17. 浅析eBay联盟营销的上下文广告机制
  18. (小)算法题(长期更新)
  19. Stay hungry stay young
  20. MVG(0)——wt is MVG

热门文章

  1. android视频编辑sdk官网,LanSoEditor_common ---android平台的视频编辑SDK
  2. python输入一个自然数、判断是否为素数_Python编程判断一个正整数是否为素数的示例代码分享...
  3. 2021年塔式起重机司机报名考试及塔式起重机司机考试内容
  4. k8s:open /run/flannel/subnet.env: no such file or directory
  5. 计算机毕业设计 SSM协同过滤算法电影推荐系统 电影在线推荐系统 在线电影点播系统Java Vue MySQL数据库 远程调试 代码讲解
  6. 搞AI开发,你不得不会的PyCharm技术
  7. 2018浙大计算机研究生分数线,2020浙江大学研究生分数线汇总(含2018-2019历年复试)...
  8. 毕业论文的绪论应该怎么写?
  9. 什么是 iCloud 钥匙串?它有什么用以及如何使用它
  10. 峰哥读者从设计转行外包数仓,再跳槽到甲方做大数据开发