一.model.parameters()与model.state_dict()

model.parameters()model.state_dict()都是Pytorch中用于查看网络参数的方法

一般来说,前者多见于优化器的初始化,例如:

后者多见于模型的保存,如:

当我们对网络调参或者查看网络的参数是否具有可复现性时,可能会查看网络的参数

pretrained_dict = torch.load(yolov4conv137weight)model_dict = _model.state_dict()  #查看模型的权重和biass系数pretrained_dict = {k1: v for (k, v), k1 in zip(pretrained_dict.items(), model_dict)}model_dict.update(pretrained_dict) #更新model网络模型的参数的权值和biass,这相当于是一个浅拷贝,对这个更新改变会更改模型的权重和biass

model.state_dict()其实返回的是一个OrderDict,存储了网络结构的名字和对应的参数。

例子:

#encoding:utf-8import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import numpy as mp
import matplotlib.pyplot as plt
import torch.nn.functional as F#define model
class TheModelClass(nn.Module):def __init__(self):super(TheModelClass,self).__init__()self.conv1=nn.Conv2d(3,6,5)self.pool=nn.MaxPool2d(2,2)self.conv2=nn.Conv2d(6,16,5)self.fc1=nn.Linear(16*5*5,120)self.fc2=nn.Linear(120,84)self.fc3=nn.Linear(84,10)def forward(self,x):x=self.pool(F.relu(self.conv1(x)))x=self.pool(F.relu(self.conv2(x)))x=x.view(-1,16*5*5)x=F.relu(self.fc1(x))x=F.relu(self.fc2(x))x=self.fc3(x)return xdef main():# Initialize modelmodel = TheModelClass()#Initialize optimizeroptimizer=optim.SGD(model.parameters(),lr=0.001,momentum=0.9)#print model's state_dictprint('Model.state_dict:')for param_tensor in model.state_dict():#打印 key value字典print(param_tensor,'\t',model.state_dict()[param_tensor].size())#print optimizer's state_dictprint('Optimizer,s state_dict:')for var_name in optimizer.state_dict():print(var_name,'\t',optimizer.state_dict()[var_name])if __name__=='__main__':main()

具体的输出结果如下:可以很清晰的观测到state_dict中存放的key和value的值

Model.state_dict:
conv1.weight     torch.Size([6, 3, 5, 5])
conv1.bias   torch.Size([6])
conv2.weight     torch.Size([16, 6, 5, 5])
conv2.bias   torch.Size([16])
fc1.weight   torch.Size([120, 400])
fc1.bias     torch.Size([120])
fc2.weight   torch.Size([84, 120])
fc2.bias     torch.Size([84])
fc3.weight   torch.Size([10, 84])
fc3.bias     torch.Size([10])
Optimizer,s state_dict:
state    {}
param_groups     [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [367949288, 367949432, 376459056, 381121808, 381121952, 381122024, 381121880, 381122168, 381122096, 381122312]}]

二.torch.load()和load_state_dict()

load_state_dict(state_dict, strict=True)

从 state_dict 中复制参数和缓冲区到 Module 及其子类中

state_dict:包含参数和缓冲区的 Module 状态字典

strict:默认 True,是否严格匹配 state_dict 的键值和 Module.state_dict()的键值

 model = nn.Sequential(self.down1, self.down2, self.down3, self.down4, self.down5, self.neek)pretrained_dict = torch.load(yolov4conv137weight)  #加载已经训练好的模型参数model_dict = model.state_dict()  #查看权重和偏重# 1. filter out unnecessary keys
pretrained_dict = {k1: v for (k, v), k1 in zip(pretrained_dict.items(), model_dict)}# 2. overwrite entries in the existing state dictmodel_dict.update(pretrained_dict)  #更新已有的模型的权重和偏重model.load_state_dict(model_dict)   #将更新后的参数重新加载至网络模型中

官方推荐的方法,只保存和恢复模型中的参数

# save
torch.save(model.state_dict(), PATH)# load
model = MyModel(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

torch.load("path路径")表示加载已经训练好的模型

而model.load_state_dict(torch.load(PATH))表示将训练好的模型参数重新加载至网络模型中

model.parameters(),model.state_dict(),model .load_state_dict()以及torch.load()相关推荐

  1. model.named_parameters()与model.parameters()

    model.named_parameters() 迭代打印model.named_parameters()将会打印每一次迭代元素的名字和param. 并且可以更改参数的可训练属性 from torch ...

  2. pytorch中的model.named_parameters()与model.parameters()

    参考链接:https://www.cnblogs.com/yqpy/p/12585331.html model.named_parameters() 迭代打印model.named_parameter ...

  3. Pytorch:模型的保存与加载 torch.save()、torch.load()、torch.nn.Module.load_state_dict()

    Pytorch 保存和加载模型后缀:.pt 和.pth 1 torch.save() [source] 保存一个序列化(serialized)的目标到磁盘.函数使用了Python的pickle程序用于 ...

  4. model.state_dict和model.parameters和model.named_parameters区别

    model.state_dict和model.parameters和model.named_parameters区别 在pytorch中,针对model,有上述方法,他们都包含模型参数,但是他们有些区 ...

  5. model.parameters()的理解与使用

    model.parameters()保存的是Weights和Bais参数的值. 首先定义一个模型 #design Model class NeuralNetwork(nn.Module):def __ ...

  6. Keras学习笔记---保存model文件和载入model文件

    Keras学习笔记---保存model文件和载入model文件 保存keras的model文件和载入keras文件的方法有很多.现在分别列出,以便后面查询. keras中的模型主要包括model和we ...

  7. 高斯混合模型Gaussian Mixture Model (GMM)——通过增加 Model 的个数,我们可以任意地逼近任何连续的概率密分布...

    从几何上讲,单高斯分布模型在二维空间应该近似于椭圆,在三维空间上近似于椭球.遗憾的是在很多分类问题中,属于同一类别的样本点并不满足"椭圆"分布的特性.这就引入了高斯混合模型.--可 ...

  8. Terracotta Express Model 和 Terracotta Customized Model

    在网上看到很多关于Terracotta快速安装和自定义安装的文章,我觉得我始终无法明白到底两者有什么区别,今天突然仔细地想想,又好像明白了. 关于Terracotta Express Model 假如 ...

  9. ASP.NET MVC基于标注特性的Model验证:一个Model,多种验证规则

    对于Model验证,理想的设计应该是场景驱动的,而不是Model(类型)驱动的,也就是对于同一个Model对象,在不同的使用场景中可能具有不同的验证规则.举个简单的例子,对于一个表示应聘者的数据对象来 ...

最新文章

  1. ASP.NET MVC以ModelValidator为核心的Model验证体系: ModelValidator
  2. hadoop hive hbase 集群搭建
  3. Quartz-TriggerListener解读
  4. 洛谷P1873 砍树
  5. 从 301 跳转,聊聊边缘规则的那些小妙用
  6. C和指针之字符串编程练习11(统计一串字符包含the的个数)
  7. Docker学习文档之一 安装软件-Windows环境
  8. 虚拟机 之 安装VMTools工具
  9. 最优矩阵链乘(动态规划)
  10. ktv服务器管理系统,KTV收银管理系统.doc
  11. 综合类新闻(APP)
  12. Turbo码(Turbo Codes)
  13. 一个很好用的JS在线格式化工具
  14. 爬取淘宝买家秀,sign值的生成
  15. excel保存csv文件数字失真解决办法
  16. EXCEL如何在一个图上画多条曲线
  17. android 安装界面关闭程序,Android安装apk文件,不弹出安装完成的界面
  18. JS港澳台身份证校验
  19. 游戏辅助 -- 走路call中ecx值分析
  20. linux Ubuntu 报错:No command ‘setenv‘ found

热门文章

  1. boost::hana::hash用法的测试程序
  2. boost::iostreams::grep_filter用法的测试程序
  3. boost::histogram::axis::integer用法的测试程序
  4. boost::hana::to用法的测试程序
  5. ITK:将内核应用于非零图像中的每个像素
  6. VTK:Render之RenderView
  7. VTK:IO之ReadDICOMSeries
  8. VTK:Filtering之ProgrammableSource
  9. Qt Linguist TS文件格式
  10. Qt Creator创建组件