pytorch1.0网络保存、提取、加载

import torch
import torch.nn.functional as F  # 包含激励函数
import matplotlib.pyplot as plt# 假数据
x = torch.unsqueeze(torch.linspace(-1,1,100),dim=1) # x data (tensor), shape=(100, 1)
y = x.pow(2) + 0.2*torch.rand(x.size())  # noisy y data (tensor), shape=(100, 1)# The code below is deprecated in Pytorch 0.4. Now, autograd directly supports tensors
# x, y = Variable(x, requires_grad=False), Variable(y, requires_grad=False)def save():# save net1# 建网络net1 = torch.nn.Sequential(torch.nn.Linear(1, 10),torch.nn.ReLU(),torch.nn.Linear(10, 1))optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)loss_func = torch.nn.MSELoss()# 训练for t in range(100):prediction = net1(x)loss = loss_func(prediction, y)optimizer.zero_grad()loss.backward()optimizer.step()# plot resultplt.figure(1, figsize=(10, 3))plt.subplot(131)plt.title('Net1')plt.scatter(x.data.numpy(), y.data.numpy())plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)# 2 ways to save the nettorch.save(net1, 'net.pkl')  # save entire net # 保存整个网络torch.save(net1.state_dict(), 'net_params.pkl')   # save only the parameters # 只保存网络中的参数 (速度快, 占内存少)# 提取网络
def restore_net():# restore entire net1 to net2net2 = torch.load('net.pkl')prediction = net2(x)# plot resultplt.subplot(132)plt.title('Net2')plt.scatter(x.data.numpy(), y.data.numpy())plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)# 只提取网络参数
def restore_params():# 新建 net3# restore only the parameters in net1 to net3net3 = torch.nn.Sequential(torch.nn.Linear(1, 10),torch.nn.ReLU(),torch.nn.Linear(10, 1))# 将保存的参数复制到 net3# copy net1's parameters into net3net3.load_state_dict(torch.load('net_params.pkl'))prediction = net3(x)# plot resultplt.subplot(133)plt.title('Net3')plt.scatter(x.data.numpy(), y.data.numpy())plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)plt.show()# 保存 net1 (1. 整个网络, 2. 只有参数)
# save net1
save()
# 提取整个网络
# restore entire net (may slow)
restore_net()
# 提取网络参数, 复制到新网络
# restore only the net parameters
restore_params()

转载于:https://www.cnblogs.com/jeshy/p/11199820.html

pytorch1.0神经网络保存、提取、加载相关推荐

  1. TensorFlow2.0 —— 模型保存与加载

    目录 1.Keras版本模型保存与加载 2.自定义版本模型保存与加载 3.总结 1.Keras版本模型保存与加载 保存模型权重(model.save_weights) 保存HDF5文件(model.s ...

  2. tensorflow1.0模型的保存、加载、在训练

    1.checkpoint文件总览 tensorflow保存的模型文件如下所示: .meta文件保存的是图结构,meta文件是pb(protocol buffer)格式文件,包含变量.op.集合等. c ...

  3. TensorFlow2.0:模型的保存与加载

    ** 一.权重参数的保存与加载 ** network.save_weights('weights.ckpt') network.load_weights('weights.ckpt') 权重参数的保存 ...

  4. Pytorch网络模型权重初始化、保存与加载模型、加载预训练模型、按需设置学习率

    前言 在我们对神经网络模型进行训练时,往往需要对模型进行初始化或者加载预训练模型.本文将对模型的权重初始化与加载预训练模型做一个学习记录,以便后续查询使用. 权重初始化 常见的初始化方法 PyTorc ...

  5. tensorflow 1.x Saver(保存与加载模型) 预测

    20201231 tensorflow 1.X 模型保存 https://blog.csdn.net/qq_35290785/article/details/89646248 保存模型 saver=t ...

  6. pytorch模型的保存与加载

    我们先创建一个模型,使用的是pytorch笔记--简易回归问题_刘文巾的博客-CSDN博客 的主体框架,唯一不同的是,我这里用的是torch.nn.Sequential来定义模型框架,而不是那篇博客里 ...

  7. pytorch数据加载、模型保存及加载

    主要涉及的Pytorch官方示例下图红框部分的一些翻译及备注. 1.数据加载及处理   该部分主要是用于进行数据集加载及数据预处理说明,使用的数据集为:人脸+标注坐标.demo程序需要pandas(读 ...

  8. tensor和模型 保存与加载 PyTorch

    PyTorch教程-7:PyTorch中保存与加载tensor和模型详解 保存和读取Tensor PyTorch中的tensor可以保存成 .pt 或者 .pth 格式的文件,使用torch.save ...

  9. (一次性搞定)ORB_SLAM2地图保存与加载

    (一次性搞定)ORB_SLAM2地图保存与加载 本文记录了ORB_SLAM2中地图保存与加载的过程. 参考博客: https://blog.csdn.net/qq_34254510/article/d ...

最新文章

  1. git 源代码自动检查_检查提交(git log,git show)《 Nest.js 应用案例:源代码管理 》...
  2. 使用BurpSuite抓取HTTPS网站
  3. windows下mongodb安装与使用整理
  4. mysql字段简索引_MySQL优化看这一篇就够了
  5. 机器学习03Logistic回归
  6. 从PeopleEditor控件中取出多用户并更新到列表
  7. Oracle tips
  8. [转]linux用户管理
  9. HDU2011 多项式求和【入门】
  10. galera cluster数据备份
  11. Web端高保真动态交互Axure元件库
  12. PDA开发从入门到精通
  13. 《电磁学》学习笔记1——电场
  14. 遗传算法--旅行商问题(TSP问题)-Matlab
  15. 【Java小项目实训】编写一个窗体程序显示的日历 万年历
  16. 金蝶KIS专业版单据序时簿看不到的问题
  17. 【考研】计算机考研复试之智力题测试
  18. 流放者柯南服务器文件,《流放者柯南》个人服务器架设教程文本及视频详解
  19. 红色警戒2修改器原理百科(二)
  20. TensorFlow中相关的维度处理函数

热门文章

  1. js页面传值php页面,不同页面,php如何js传值?
  2. mysql source导入_读取MySQL数据库中的数据【Python数据分析百例连载】
  3. php 计算每年春节日期,动态显示2019年农历春节倒计时—2019年1月21日23时45分
  4. 用screenfetch显示带有酷炫Linux标志的基本硬件信息
  5. 使用nmap查看web服务支持的http methods
  6. django03_表单(forms.ModelForm)(login前后台)
  7. [转贴]人老总是一场空
  8. 计算机网络面试知识点
  9. tensorflow 小于_TensorFlow做Sparse Machine Learning
  10. python 协程 多线程_python进阶之多线程(简单介绍协程)