官方文档在这里。

GRU具体不做介绍了,本篇只做pytorch的API使用介绍.


torch.nn.GRU(*args, **kwargs)

公式


下面公式忽略bias,由于输入向量的长度和隐藏层特征值长度不一致,所以每个公式的W都按x和h分开。这跟理论公式部分有一些具体的实践上区别。

  • reset gate, 重置门
    rt=σ(Wirxt+Whrht−1)r_t = \sigma(W_{ir}x_t+W_{hr}h_{t-1})rt​=σ(Wir​xt​+Whr​ht−1​) GRU里的参数是WirW_{ir}Wir​ 和WirW_{ir}Wir​
  • update gate,更新门
    zt=σ(Wizxt+Whzht−1)z_t = \sigma(W_{iz}x_t+W_{hz}h_{t-1})zt​=σ(Wiz​xt​+Whz​ht−1​) GRU里的参数是WizW_{iz}Wiz​ 和WhzW_{hz}Whz​
  • 更新状态阈值
    nt=tanh(Winxt+rt(Whnht−1))n_t = tanh (W_{in}x_t+r_t(W_{hn} h_{t-1}))nt​=tanh(Win​xt​+rt​(Whn​ht−1​)) GRU里的参数是WinW_{in}Win​ 和WhnW_{hn}Whn​
    这里同LSTM里的g(t)g(t)g(t)函数,只是多了重置门对ht−1h_{t-1}ht−1​的影响
  • 更新hth_tht​
    ht=(1−zt)nt+ztht−1h_t = (1-z_t)n_t + z_t h_{t-1}ht​=(1−zt​)nt​+zt​ht−1​

所以从输入张量和隐藏层张量来说,一共有两组参数(忽略bias参数)

  1. input 组 {WirW_{ir}Wir​ WizW_{iz}Wiz​ WinW_{in}Win​}
  2. hidden组 {WirW_{ir}Wir​ WhzW_{hz}Whz​ WhnW_{hn}Whn​ }


因为hidden size为隐藏层特征输出长度,所以每个参数第一维度都是hidden size;然后每一组是把3个张量按照第一维度拼接,所以要乘以3

举例代码

from torch import nngru = nn.GRU(input_size=3, hidden_size=5, num_layers=1, bias=False)print('weight_ih_l0.shape = ', gru.weight_ih_l0.shape, ', weight_hh_l0.shape = ' , gru.weight_hh_l0.shape)

双向GRU

如果要实现双向的GRU,只需要增加参数bidirectional=True

但是参数并没有增加。

from torch import nngru = nn.GRU(input_size=3, hidden_size=5, num_layers=1, bidirectional=True, bias=False)print('weight_ih_l0.shape = ', gru.weight_ih_l0.shape, ', weight_ih_l0_reverse.shape = ', gru.weight_ih_l0_reverse.shape,'\nweight_hh_l0.shape = ' , gru.weight_hh_l0.shape, ', weight_hh_l0_reverse.shape = ', gru.weight_hh_l0_reverse.shape)

多层的概念

可以参考这里 https://blog.csdn.net/mimiduck/article/details/119975080

【pytorch】nn.GRU的使用相关推荐

  1. PyTorch nn.GRU 使用详解

    我们看官方文档一些参数介绍,以及如下一个简单例子: 看完之后,还是一脸懵逼: 输入什么鬼? 输出又什么鬼? (这里我先把官网中 h0 去掉了,便于大家先理解更重要的概念) import torch f ...

  2. pytorch笔记:torch.nn.GRU torch.nn.LSTM

    1 函数介绍 (GRU) 对于输入序列中的每个元素,每一层计算以下函数: 其中是在t时刻的隐藏状态,是在t时刻的输入.σ是sigmoid函数,*是逐元素的哈达玛积 对于多层GRU 第l层的输入(l≥2 ...

  3. pytorch nn.Embedding

    pytorch nn.Embedding class torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None, max_n ...

  4. Pytorch.nn.Linear 解析(数学角度)

    pytorch.nn.Linear 是一个类,下面是它的一些初始化参数 in_features : 输入样本的张量大小 out_features : 输出样本的张量大小 bias : 偏置 它主要是对 ...

  5. Pytorch GRU(详解GRU+torch.nn.GRU()实现)

    pytorch GRU 目录 pytorch GRU 一.GRU简介1 二.GRU简介2 三.pytorch GRU 3.1    定义GRU ()

  6. pytorch nn.LSTM()参数详解

    输入数据格式: input(seq_len, batch, input_size) h0(num_layers * num_directions, batch, hidden_size) c0(num ...

  7. pytorch系列 -- 9 pytorch nn.init 中实现的初始化函数 uniform, normal, const, Xavier, He initialization...

    本文内容: 1. Xavier 初始化 2. nn.init 中各种初始化函数 3. He 初始化 torch.init https://pytorch.org/docs/stable/nn.html ...

  8. Pytorch nn.Transformer的mask理解

    点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 作者丨林小平@知乎(已授权) 来源丨https://zhuanlan ...

  9. PyTorch nn.Module 一些疑问

    在阅读书籍时,遇到了一些不太理解,或者介绍的不太详细的点. 从代码角度学习理解Pytorch学习框架03: 神经网络模块nn.Module的了解. Pytorch 03: nn.Module模块了解 ...

最新文章

  1. bzoj 2375: 疯狂的涂色
  2. 使用 flex 实现 5 种常用布局
  3. 架构模式: 事务日志跟踪
  4. Arm Linux 交叉编译(交叉编译是什么?CROSS_COMPILE)(交叉编译工具链【待更】)
  5. spring boot基础教程:入门程序Hello World的编写
  6. 你的微信,到底「连接」多少人?
  7. dedemodule.class.php,DEDECMS5.7模块/模块管理列表显示空白问题解决方法
  8. 【零基础学Java】—继承的概述(十九)
  9. bzoj 1682: [Usaco2005 Mar]Out of Hay 干草危机(最小生成树)
  10. Hyperledger02
  11. java hd sex_Java学习笔记(十八)——Java DTO
  12. 企业微信好友无上限,私域流量即将迎来春天?
  13. 输入五个国家的名称按字母顺序排列输出
  14. Ceph Peering以及数据均衡的改进思路
  15. 11.12. ACLs
  16. 偏偏在面试的时候踏入一个大坑--360浏览器兼容模式
  17. 在线透明favicon ico图标文件制作 - aTool在线工具
  18. 什么是数字签名?(内含漫画图解)
  19. 51单片机内部外设:实时时钟(SPI)
  20. 线性代数 线性相关与线性表示的理解

热门文章

  1. Genome Biology | 药物基因组学数据库
  2. 广东海洋大学计算机科学与技术排名,最新排名!广东高校22个学科位居全球前50位...
  3. GEO数据挖掘(3)-芯片基础知识
  4. Nature调查 l 中国博士生们的科研围城
  5. 免费申领Bio-protocol单细胞研究实验方法精选集
  6. 高分文章精选 | 纳米孔宏基因组测序的表现
  7. 微生物组-扩增子16S分析研讨会(2020.1)
  8. 一人一天发两篇Science,视频揭秘:植物如何在与病菌的斗争中取胜?
  9. 水稻微生物组时间序列分析精讲1-模式图与主坐标轴分析
  10. 不属于JAVA类中的变量_在Java中,不属于整数类型变量的是( )。_学小易找答案...