pytorch1.0神经网络保存、提取、加载
pytorch1.0网络保存、提取、加载
import torch import torch.nn.functional as F # 包含激励函数 import matplotlib.pyplot as plt# 假数据 x = torch.unsqueeze(torch.linspace(-1,1,100),dim=1) # x data (tensor), shape=(100, 1) y = x.pow(2) + 0.2*torch.rand(x.size()) # noisy y data (tensor), shape=(100, 1)# The code below is deprecated in Pytorch 0.4. Now, autograd directly supports tensors # x, y = Variable(x, requires_grad=False), Variable(y, requires_grad=False)def save():# save net1# 建网络net1 = torch.nn.Sequential(torch.nn.Linear(1, 10),torch.nn.ReLU(),torch.nn.Linear(10, 1))optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)loss_func = torch.nn.MSELoss()# 训练for t in range(100):prediction = net1(x)loss = loss_func(prediction, y)optimizer.zero_grad()loss.backward()optimizer.step()# plot resultplt.figure(1, figsize=(10, 3))plt.subplot(131)plt.title('Net1')plt.scatter(x.data.numpy(), y.data.numpy())plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)# 2 ways to save the nettorch.save(net1, 'net.pkl') # save entire net # 保存整个网络torch.save(net1.state_dict(), 'net_params.pkl') # save only the parameters # 只保存网络中的参数 (速度快, 占内存少)# 提取网络 def restore_net():# restore entire net1 to net2net2 = torch.load('net.pkl')prediction = net2(x)# plot resultplt.subplot(132)plt.title('Net2')plt.scatter(x.data.numpy(), y.data.numpy())plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)# 只提取网络参数 def restore_params():# 新建 net3# restore only the parameters in net1 to net3net3 = torch.nn.Sequential(torch.nn.Linear(1, 10),torch.nn.ReLU(),torch.nn.Linear(10, 1))# 将保存的参数复制到 net3# copy net1's parameters into net3net3.load_state_dict(torch.load('net_params.pkl'))prediction = net3(x)# plot resultplt.subplot(133)plt.title('Net3')plt.scatter(x.data.numpy(), y.data.numpy())plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)plt.show()# 保存 net1 (1. 整个网络, 2. 只有参数) # save net1 save() # 提取整个网络 # restore entire net (may slow) restore_net() # 提取网络参数, 复制到新网络 # restore only the net parameters restore_params()
转载于:https://www.cnblogs.com/jeshy/p/11199820.html
pytorch1.0神经网络保存、提取、加载相关推荐
- TensorFlow2.0 —— 模型保存与加载
目录 1.Keras版本模型保存与加载 2.自定义版本模型保存与加载 3.总结 1.Keras版本模型保存与加载 保存模型权重(model.save_weights) 保存HDF5文件(model.s ...
- tensorflow1.0模型的保存、加载、在训练
1.checkpoint文件总览 tensorflow保存的模型文件如下所示: .meta文件保存的是图结构,meta文件是pb(protocol buffer)格式文件,包含变量.op.集合等. c ...
- TensorFlow2.0:模型的保存与加载
** 一.权重参数的保存与加载 ** network.save_weights('weights.ckpt') network.load_weights('weights.ckpt') 权重参数的保存 ...
- Pytorch网络模型权重初始化、保存与加载模型、加载预训练模型、按需设置学习率
前言 在我们对神经网络模型进行训练时,往往需要对模型进行初始化或者加载预训练模型.本文将对模型的权重初始化与加载预训练模型做一个学习记录,以便后续查询使用. 权重初始化 常见的初始化方法 PyTorc ...
- tensorflow 1.x Saver(保存与加载模型) 预测
20201231 tensorflow 1.X 模型保存 https://blog.csdn.net/qq_35290785/article/details/89646248 保存模型 saver=t ...
- pytorch模型的保存与加载
我们先创建一个模型,使用的是pytorch笔记--简易回归问题_刘文巾的博客-CSDN博客 的主体框架,唯一不同的是,我这里用的是torch.nn.Sequential来定义模型框架,而不是那篇博客里 ...
- pytorch数据加载、模型保存及加载
主要涉及的Pytorch官方示例下图红框部分的一些翻译及备注. 1.数据加载及处理 该部分主要是用于进行数据集加载及数据预处理说明,使用的数据集为:人脸+标注坐标.demo程序需要pandas(读 ...
- tensor和模型 保存与加载 PyTorch
PyTorch教程-7:PyTorch中保存与加载tensor和模型详解 保存和读取Tensor PyTorch中的tensor可以保存成 .pt 或者 .pth 格式的文件,使用torch.save ...
- (一次性搞定)ORB_SLAM2地图保存与加载
(一次性搞定)ORB_SLAM2地图保存与加载 本文记录了ORB_SLAM2中地图保存与加载的过程. 参考博客: https://blog.csdn.net/qq_34254510/article/d ...
最新文章
- git 源代码自动检查_检查提交(git log,git show)《 Nest.js 应用案例:源代码管理 》...
- 使用BurpSuite抓取HTTPS网站
- windows下mongodb安装与使用整理
- mysql字段简索引_MySQL优化看这一篇就够了
- 机器学习03Logistic回归
- 从PeopleEditor控件中取出多用户并更新到列表
- Oracle tips
- [转]linux用户管理
- HDU2011 多项式求和【入门】
- galera cluster数据备份
- Web端高保真动态交互Axure元件库
- PDA开发从入门到精通
- 《电磁学》学习笔记1——电场
- 遗传算法--旅行商问题(TSP问题)-Matlab
- 【Java小项目实训】编写一个窗体程序显示的日历 万年历
- 金蝶KIS专业版单据序时簿看不到的问题
- 【考研】计算机考研复试之智力题测试
- 流放者柯南服务器文件,《流放者柯南》个人服务器架设教程文本及视频详解
- 红色警戒2修改器原理百科(二)
- TensorFlow中相关的维度处理函数
热门文章
- js页面传值php页面,不同页面,php如何js传值?
- mysql source导入_读取MySQL数据库中的数据【Python数据分析百例连载】
- php 计算每年春节日期,动态显示2019年农历春节倒计时—2019年1月21日23时45分
- 用screenfetch显示带有酷炫Linux标志的基本硬件信息
- 使用nmap查看web服务支持的http methods
- django03_表单(forms.ModelForm)(login前后台)
- [转贴]人老总是一场空
- 计算机网络面试知识点
- tensorflow 小于_TensorFlow做Sparse Machine Learning
- python 协程 多线程_python进阶之多线程(简单介绍协程)