目录

1 parameters()

1.1 model.parameters():

1.2 model.named_parameters():

2 state_dict()


torch.nn.Module 模块中的可学习参数都被包含在该模型的parameters 中,可以通过model.parameters()的方法获取;

state_dict()是一 个字典,包含了模型各的参数(tensor类型),多用于保存模型;

1 parameters()

1.1 model.parameters():

源码:

    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:r"""Returns an iterator over module parameters.返回模块参数上的迭代器。This is typically passed to an optimizer.这通常被传递给优化器Args:recurse (bool): if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.如果为True, 则生成该模块 及其所有子模块的参数。否则,只生成该模块的直接成员的形参。Yields:Parameter: module parameterExample::>>> for param in model.parameters():>>>     print(type(param), param.size())<class 'torch.Tensor'> (20L,)<class 'torch.Tensor'> (20L, 1L, 5L, 5L)"""for name, param in self.named_parameters(recurse=recurse):yield param

可以通过Module.parameters()获得网络参数, 迭代的返回模型所有可学习的参数 --  是个生成器

有些layer不包含可学习的参数,比如(relu, maxpool),因此model.parameters()不会输出这些层;

parameters()多见于优化器的初始化;

由于parameters()是生成器,因此需要利用循环或者next()来获取数据:

例子:

>>> import torch
>>> import torch.nn as nn>>> class Net(nn.Module):
...     def __init__(self):
...             super().__init__()
...             self.linear = nn.Linear(2,2)
...     def forward(self,x):
...             out = self.linear(x)
...             return out
...
>>> net = Net()
>>> for para in net.parameters():
...     print(para)
... Parameter containing:
tensor([[-0.1954, -0.2290],[ 0.5897, -0.3970]], requires_grad=True)
Parameter containing:
tensor([-0.1808,  0.2044], requires_grad=True)>>> for para in net.named_parameters():
...     print(para)
...
('linear.weight', Parameter containing:
tensor([[-0.1954, -0.2290],[ 0.5897, -0.3970]], requires_grad=True))
('linear.bias', Parameter containing:
tensor([-0.1808,  0.2044], requires_grad=True))

1.2 model.named_parameters():

是带有layer name的model.parameters(),其以tuple方式输出,其中包含两个元素,分别为layer name和 model.parameters;

layer name有后缀 .weight, .bias用于区分权重和偏置;

源码:

    def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Parameter]]:r"""Returns an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.返回模块参数上的迭代器,生成参数名和参数本身。Args:prefix (str): prefix to prepend to all parameter names.recurse (bool): if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.如果为True,则生成该模块及其所有子模块的参数。否则,只生成该模块的直接成员的形参。Yields:(string, Parameter): Tuple containing the name and parameterExample::>>> for name, param in self.named_parameters():>>>    if name in ['bias']:>>>        print(param.size())"""gen = self._named_members(lambda module: module._parameters.items(),prefix=prefix, recurse=recurse)for elem in gen:yield elem

代码例子,看1.1部分;

2 state_dict()

model.state_dict()能够获得模型所有的参数,包括可学习的参数和不可学习的参数,返回值是一个有序字典OrderedDict.

这部分相当于在model.parameters()基础上,又额外获取了不可学习的参数部分;

例子:

key值是对网络参数的说明,这里是线性层的weight和bias;

>>> class Net(nn.Module):
...     def __init__(self):
...             super().__init__()
...             self.linear = nn.Linear(10,8)
...             self.dropout = nn.Dropout(0.5)
...             self.linear1 = nn.Linear(8,2)
...     def forward(self,x):
...             out = self.dropout(self.linear(x))
...             out = self.linear1(out)
...             return out
...
>>> net = Net()
>>> net.state_dict()
OrderedDict([('linear.weight', tensor([[ 0.1415, -0.2228, -0.1262,  0.0992, -0.1600,  0.0141, -0.1841, -0.1907,0.0295, -0.1853],[-0.0399, -0.2487, -0.3085,  0.1602,  0.3135,  0.1379,  0.0696,  0.0362,-0.1619, -0.0887],[-0.1244, -0.1739,  0.1211, -0.2578, -0.0561,  0.0635, -0.1976, -0.2557,0.1761,  0.2553],[ 0.0912, -0.1469, -0.3012, -0.1583, -0.0028,  0.2697,  0.1947, -0.0596,-0.2144, -0.0785],[-0.1770,  0.0411,  0.1663,  0.1861,  0.2769,  0.0990,  0.1883, -0.1801,0.2727,  0.1219],[-0.1269,  0.0713,  0.2798,  0.1760,  0.0965,  0.1144,  0.2644,  0.0274,0.0034,  0.2702],[ 0.0628,  0.0682, -0.1842,  0.1461,  0.0678, -0.2264, -0.1249, -0.1715,0.1115,  0.2459],[ 0.1198, -0.2584,  0.0234,  0.2756,  0.1174, -0.1212,  0.3024, -0.2304,-0.2950,  0.0970]])), ('linear.bias', tensor([-0.3036, -0.1933,  0.2412,  0.3137, -0.3007,  0.2386, -0.1975,  0.3127])), ('linear1.weight', tensor([[-0.1725,  0.3027,  0.1985,  0.1394, -0.1245,  0.2913,  0.0136,  0.1633],[-0.1558, -0.0865, -0.3032,  0.1374,  0.2967, -0.2886,  0.0430, -0.1246]])), ('linear1.bias', tensor([-0.1232, -0.0690]))])
>>>

参考:PyTorch中model.state_dict(),model.modules(),model.children(),model.named_children()等含义_yaoyz105的博客-CSDN博客_model.state_dict()

model.parameters()与model.state_dict() - 知乎

pytorch - state_dict() , parameters() 详解相关推荐

  1. pytorch MSELoss参数详解

    pytorch MSELoss参数详解 import torch import numpy as np loss_fn = torch.nn.MSELoss(reduce=False, size_av ...

  2. pytorch实战:详解查准率(Precision)、查全率(Recall)与F1

    pytorch实战:详解查准率(Precision).查全率(Recall)与F1 1.概述 本文首先介绍了机器学习分类问题的性能指标查准率(Precision).查全率(Recall)与F1度量,阐 ...

  3. PyTorch Python API详解大全(持续更新ing...)

    诸神缄默不语-个人CSDN博文目录 具体内容以官方文档为准. 最早更新时间:2021.4.23 最近更新时间:2023.1.9 文章目录 0. 常用入参及函数统一解释 1. torch 1.1 Ten ...

  4. 【小白学PyTorch】10.pytorch常见运算详解

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 <<小白学PyTorch>> 参考目录: ...

  5. 3 矩阵运算_小白学PyTorch——pytorch常见运算详解

    公众号关注 "DL-CVer" 设为 "星标",DLCV消息即可送达! 参考目录: 1 矩阵与标量 2 哈达玛积 3 矩阵乘法 4 幂与开方 5 对数运算 6 ...

  6. python训练手势分类器_使用Pytorch训练分类器详解(附python演练)

    [前言]:你已经了解了如何定义神经网络,计算loss值和网络里权重的更新.现在你也许会想数据怎么样? 目录: 一.数据 二.训练一个图像分类器 使用torchvision加载并且归一化CIFAR10的 ...

  7. PyTorch 的 Autograd详解

    ↑ 点击蓝字 关注视学算法 作者丨xiaopl@知乎 来源丨https://zhuanlan.zhihu.com/p/69294347 编辑丨极市平台 PyTorch 作为一个深度学习平台,在深度学习 ...

  8. Pytorch LSTM初识(详解LSTM+torch.nn.LSTM()实现)1

    pytorch  LSTM1初识 目录 pytorch  LSTM1初识 ​​​​​​​​​​​​​​​​​​​​​ 一.LSTM简介1

  9. win10开始不显示python_win10从零安装配置pytorch全过程图文详解

    1.安装anaconda (anaconda内置python在内的许多package,所以不用另外下载python) 可以点击下面的清华开源软件镜像站,在官网下载anaconda不如在这下的快 htt ...

最新文章

  1. 视频色彩校正简介 Introduction to Video Color Correction
  2. Elasticsearch创建雇员目录
  3. ObjectModel QML类型
  4. 2020年物联网网络容量至少是目前的1000倍
  5. SQL Server ldf 丢失的数据库恢复
  6. 数学建模算法:支持向量机_从零开始的算法:支持向量机
  7. 看见到洞见之引子(二)机器学习算法
  8. 黄聪:AngularJS最理想开发工具WebStorm
  9. Linux下文件内容查阅命令
  10. android 崩溃捕获框架,DefenseCrash
  11. UEditor实战分享(二)定制
  12. 泛微oa系统什么框架_泛微OA ecology 二次开发实例 开发完整说明
  13. JN5169_EEPROM_PDM
  14. 角度单位中角分、角秒的进制转换
  15. Si5341时钟芯片使用说明
  16. 1467 A. Wizard of Orz
  17. 新版iTunes如何设置手机铃声
  18. 标准H.460公私网穿越视频解决方案
  19. 数字图像处理与Python实现-图像几何变换-图像金字塔
  20. python类型转换方法_详解python中的类型转换方法

热门文章

  1. ORA-32036: 不支持 WITH 子句中串联式查询名的形式 后台报错问题
  2. 设森林F对应的二叉树为B,它有m个结点,B的根p,p右子树结点个数n,森林F中第一棵树的结点个数
  3. 私域流量时代,微商城的五大影响力
  4. 如水晶般晶莹、匀称的时间——奇特的一生
  5. 侦查与打击型(ajax,遵义县公安局创新侦查模式提升打击效能
  6. 【macOS Qt MenuBar】的显示方法
  7. 莫比乌斯反演--懵逼反演系列
  8. python frame框架抓取_Python抓取框架Scrapy爬虫入门:页面提取
  9. 社科院与杜兰大学金融管理硕士项目——人生没有太晚的开始,不要过早的放弃
  10. ffmpeg图片转视频