参考   state_dict - 云+社区 - 腾讯云

在pytorch中,torch.nn.Module模块中的state_dict变量存放训练过程中需要学习的权重和偏执系数,state_dict作为python的字典对象将每一层的参数映射成tensor张量,需要注意的是torch.nn.Module模块中的state_dict只包含卷积层和全连接层的参数,当网络中存在batchnorm时,例如vgg网络结构,torch.nn.Module模块中的state_dict也会存放batchnorm's running_mean。

torch.optim模块中的Optimizer优化器对象也存在一个state_dict对象,此处的state_dict字典对象包含state和param_groups的字典对象,而param_groups key对应的value也是一个由学习率,动量等参数组成的一个字典对象。因为state_dict本质上Python字典对象,所以可以很好地进行保存、更新、修改和恢复操作(python字典结构的特性),从而为PyTorch模型和优化器增加了大量的模块化。

Sample

通过一个简单的案例来输出state_dict字典对象中存放的变量。

#encoding:utf-8import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import numpy as mp
import matplotlib.pyplot as plt
import torch.nn.functional as F#define model
class TheModelClass(nn.Module):def __init__(self):super(TheModelClass,self).__init__()self.conv1=nn.Conv2d(3,6,5)self.pool=nn.MaxPool2d(2,2)self.conv2=nn.Conv2d(6,16,5)self.fc1=nn.Linear(16*5*5,120)self.fc2=nn.Linear(120,84)self.fc3=nn.Linear(84,10)def forward(self,x):x=self.pool(F.relu(self.conv1(x)))x=self.pool(F.relu(self.conv2(x)))x=x.view(-1,16*5*5)x=F.relu(self.fc1(x))x=F.relu(self.fc2(x))x=self.fc3(x)return xdef main():# Initialize modelmodel = TheModelClass()#Initialize optimizeroptimizer=optim.SGD(model.parameters(),lr=0.001,momentum=0.9)#print model's state_dictprint('Model.state_dict:')for param_tensor in model.state_dict():#打印 key value字典print(param_tensor,'\t',model.state_dict()[param_tensor].size())#print optimizer's state_dictprint('Optimizer,s state_dict:')for var_name in optimizer.state_dict():print(var_name,'\t',optimizer.state_dict()[var_name])if __name__=='__main__':main()Output:
-----------------------------------------------------------------------------------------
Model.state_dict:
conv1.weight     torch.Size([6, 3, 5, 5])
conv1.bias   torch.Size([6])
conv2.weight     torch.Size([16, 6, 5, 5])
conv2.bias   torch.Size([16])
fc1.weight   torch.Size([120, 400])
fc1.bias     torch.Size([120])
fc2.weight   torch.Size([84, 120])
fc2.bias     torch.Size([84])
fc3.weight   torch.Size([10, 84])
fc3.bias     torch.Size([10])
Optimizer,s state_dict:
state    {}
param_groups     [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [367949288, 367949432, 376459056, 381121808, 381121952, 381122024, 381121880, 381122168, 381122096, 381122312]}]
-----------------------------------------------------------------------------------------

state_dict详解相关推荐

  1. 【PyTorch】state_dict详解

    Introduce 在pytorch中,torch.nn.Module模块中的state_dict变量存放训练过程中需要学习的权重和偏执系数,state_dict作为python的字典对象将每一层的参 ...

  2. 【CV】Pytorch一小时入门教程-代码详解

    目录 一.关键部分代码分解 1.定义网络 2.损失函数(代价函数) 3.更新权值 二.训练完整的分类器 1.数据处理 2. 训练模型(代码详解) CPU训练 GPU训练 CPU版本与GPU版本代码区别 ...

  3. 【深度学习】ResNet——CNN经典网络模型详解(pytorch实现)

    建议大家可以实践下,代码都很详细,有不清楚的地方评论区见~ 1.前言 ResNet(Residual Neural Network)由微软研究院的Kaiming He等四名华人提出,通过使用ResNe ...

  4. GoogLeNet——CNN经典网络模型详解(pytorch实现)

    一.前言 论文地址:http://arxiv.org/abs/1602.07261 2014年,GoogLeNet和VGG是当年ImageNet挑战赛(ILSVRC14)的双雄,GoogLeNet获得 ...

  5. Pytorch|YOWO原理及代码详解(二)

    Pytorch|YOWO原理及代码详解(二) 本博客上接,Pytorch|YOWO原理及代码详解(一),阅前可看. 1.正式训练 if opt.evaluate:logging('evaluating ...

  6. AlexNet网络结构详解与代码复现

    参考内容来自up:3.1 AlexNet网络结构详解与花分类数据集下载_哔哩哔哩_bilibili up主的CSDN博客:太阳花的小绿豆的博客_CSDN博客-深度学习,软件安装,Tensorflow领 ...

  7. ResNet网络结构详解,网络搭建,迁移学习

    前言: 参考内容来自up:6.1 ResNet网络结构,BN以及迁移学习详解_哔哩哔哩_bilibili up的代码和ppt:https://github.com/WZMIAOMIAO/deep-le ...

  8. [Pytorch系列-61]:循环神经网络 - 中文新闻文本分类详解-3-CNN网络训练与评估代码详解

    作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客 本文网址:https://blog.csdn.net/HiWangWenBing/article/detai ...

  9. pytorch模型(.pt)转onnx模型(.onnx)的方法详解(1)

    1. pytorch模型转换到onnx模型 2.运行onnx模型 3.比对onnx模型和pytorch模型的输出结果 我这里重点是第一点和第二点,第三部分  比较容易 首先你要安装 依赖库:onnx ...

最新文章

  1. 2018-3-26论文(GWO和WOA)中Table1--Table3中的benchmark函数F1-F23图形
  2. ueditor编辑器和at.js集成
  3. Linux学习命令汇总三——Linux用户组管理,文件权限管理,文本搜索命令grep及正则表达式...
  4. 鸿蒙股票深度分析,本月华为鸿蒙概念股市回顾分析(3月31日)
  5. 服务器性能测试典型工具介绍
  6. 联想Z6 Pro测评:斗战圣佛?很能打!
  7. 为了在简历上写掌握【Java多线程和并发编程】,做了两万字总结
  8. cocos2d-x学习(一) HelloWorld
  9. 唐宇迪学习笔记9:逻辑回归与梯度下降策略
  10. Linux的ssh登录命令,Linux SSH登录命令总结
  11. 同款视频一键制作生成微信小程序源码下载恶搞视频,特效视频,唯美视频等等
  12. 小学身高体重测试软件,儿童身高体重在线测评
  13. axure 倒计时_Axure倒计时效果
  14. 作为无人机方面做嵌入式编写的飞控总结6--IMU惯性系统和GPS导航系统融合小结1(惯性导航算法)
  15. Linux中基于eBPF的恶意利用与检测机制
  16. Win11系统设置自动关机的方法分享
  17. 高中数学数列解题技巧及常用高考数学解题方法
  18. ubuntu snap 安装的nextcloud 忘记管理员密码,重新设置密码。
  19. 手机性能指标详细测试步骤【Android/IOS】
  20. python 桑基图_3行代码基于python的matplotlib绘制桑基图

热门文章

  1. mysql王者晋级 电子书_“MySQL王者晋级之路”读书笔记-结构与引擎
  2. 在线生成装逼图片引流源码
  3. c 语言解析png图片文件信息,利用C/C++二进制读写png文件的方法示例
  4. 红牛农场java代码_实验题目 Java语言概述.doc
  5. python实现简单的猜数字游戏
  6. c语言电脑上有吗,为什么我的C语言程序在我的电脑上有错误,在别人的电脑上没错?...
  7. Silverlight 2.5D RPG游戏技巧与特效处理:(八)无限缩放空间系统
  8. AP广播多VLAN SSID
  9. 南师大计算机学院博雅课的要求,南师大《博雅选课指南》网络热传
  10. 【面试实战】Java面试的时候,你能这么回答,就基本都可以过了!