self.lstm = nn.LSTM(input_size=n_class, hidden_size=n_hidden, num_layers=n_layers)
输入网络的维度数:26,隐层维度:128,lstm层数:n_layers:1

LSTM: 单向LSTM,D=1
input:[3, 10, 26] sequence len=3, batch=10, input_size=26
# 可以理解为,sequence length有多长就有几个上图中的F
h0:[1, 10, 128] n_layers = 1, batch=10, hidden_size=128
c0:[1, 10, 128] n_layers = 1, batch=10, hidden_size=128
输入,3个细胞单元作为输入,每个细胞单元接收26维的输入,一个batch为10条数据

输出:output, (h_n, c_n)
output: [3, 10, 128],3个细胞单元,一批次10个数据,隐层维度为128,输出最后一层每个细胞单元的隐层输出;

output保存了最后一层,每个time step的输出h
h_n保存每一层,最后一个time-step的输出h
c_n保存每一层,最后一个time-step的输出c

官方给定实例:

# Examples::# 初始化模型
rnn = nn.LSTM(10, 20, 2) # 输入特征的维度:10,隐层维度:20 ,LSTM层数:2
# 输入数据格式
input = torch.randn(5, 3, 10)  # 句子长度:5,batch_size:3,输入特征的维度:10
# 隐层结构
h0 = torch.randn(2, 3, 20)  # LSTM层:2,batch_size:3,隐层参数维度:20
# 细胞单元结构
c0 = torch.randn(2, 3, 20)  # 同上
# 输出结构
output, (hn, cn) = rnn(input, (h0, c0))
"""output:LSTM最后一层隐状态的输出,维度:(5, 3, 20)hn:最后一个timestep的隐状态结果,维度:(1, 3, 20)cn:最后一个timestep的细胞单元的结果,维度:(1, 3, 20)"""

BiLSTM-torch调用:

# 单层变双层,单向变为双向
# 举例解决:
# 单层Bilstm, 输入数据,一个batch包含26条数据,每条数据为27x27,类别数为27,每个单元包含隐藏单元为5
# 定义BiLSTM时:
bilstm = nn.LSTM(input_size=27, hidden_size=5, bidirectional=True)
# 输入input维度应为:[sequence length, batch size, input_size]即需通过torch.transpose(0, 1)改变维度
inputs shape: [27, 26, 27]
# 隐藏定义:[Bilstm:2 x 层数默认:1, batch_size:26,  每个单元包含隐藏单元:5]
h0 shape: [2x1, 26, 5]
c0 shape: [2x1, 26, 5] #细胞状态同上output, (h_n, c_n) = bilstm(inputs, (h0, c0)
#同lstm,output包含最后一层所有细胞的隐层输出,
# [sequence length, batch size, H_{out}*2]
output shape: [27, 26, 5*2]
# h_n 包含最后一个细胞的最后一个时间步的隐层输出
# [D * num_layers, batch size, H_{out}]
h_n shape: [2, 26, 5]
# c_n同上
c_n shape: [2, 26, 5]

pytorch实现LSTM例子:

import numpy as np
import torch
import torch.nn as nn
from torch.nn.modules import loss, module
import torch.optim as optimdef make_batch():input_batch, target_batch = [], []for seq in seq_data:input = [word_dict[n] for n in seq[:-1]]  #到最后一个字母target = word_dict[seq[-1]]  # 最后一个标签,预测最后一个单词input_batch.append(np.eye(n_class)[input])target_batch.append(target)print("input:{}, target:{}".format(input, target))return input_batch, target_batch"""np.eye(3):代表一个维度为3的单位矩阵np.eye(x)[N]: 代表取出维度为x的单位矩阵的第N行,用于构造one-hot表示x.transpose(0, 1):交换0,1两个维度例:x为[10, 3, 26] x.transpose(0, 1)为[3, 10 ,26]  # 有啥意义Examples::>>> rnn = nn.LSTM(10, 20, 2)  #input_size: The number of expected features in the input `x`hidden_size: The number of features in the hidden state `h`num_layers: Number of recurrent layers. E.g., setting ``num_layers=2``>>> input = torch.randn(5, 3, 10) :  N,L,H_{in}或L,N,H_{in}输入向量的长度, 一次传入多少条数据,单词向量的维度L ={} & \text{sequence length} \\N ={} & \text{batch size} \\H_{in} ={} & \text{input\_size} \\>>> h0 = torch.randn(2, 3, 20)(D * \text{num\_layers}, N, H_{out}), BiLSTM: D=2, LSTM: D=1H_{out} ={} & \text{proj\_size if } \text{proj\_size}>0 \text{ otherwise hidden\_size} \\\end{aligned}>>> c0 = torch.randn(2, 3, 20)(D * \text{num\_layers}, N, H_{cell})H_{cell} ={} & \text{hidden\_size} \\>>> output, (hn, cn) = rnn(input, (h0, c0))output保存了最后一层,每个time step的输出h,如果是双向LSTM,每个time step的输出h = [h正向, h逆向] (同一个time step的正向和逆向的h连接起来)。h_n保存了每一层,最后一个time step的输出h,如果是双向LSTM,单独保存前向和后向的最后一个time step的输出h。c_n与h_n一致,只是它保存的是c的值。output是一个三维的张量,第一维表示序列长度,第二维表示一批的样本数(batch),第三维是 hidden_size(隐藏层大小) * num_directions
"""class TextLSTM(nn.Module):def __init__(self):super(TextLSTM, self).__init__()self.lstm = nn.LSTM(input_size=n_class, hidden_size=n_hidden)# input_size:26# hidden_size:128# mode:'LSTM'# num_layers:1self.W = nn.Linear(n_hidden, n_class, bias=False)  #LSTM+一个线性层加偏置,最后softmax输出 # out_features:26 # in_features:128self.b = nn.Parameter(torch.ones([n_class]))  # P的大小写有何影响,nn.parameter报错,必须得是nn.Parameterdef forward(self, x):# print('x:{}, x_size:{}'.format(#     x, x.size()))  # shape:torch.Size([10, 3, 26])input = x.transpose(0, 1)  # torch.Size([3, 10, 26])  ****有何意义# print('input:{}, input_size:{}'.format(input, input.size()))hidden_state = torch.zeros(1, len(x),n_hidden)  # shape:torch.Size([1, 10, 128])cell_state = torch.zeros(1, len(x),n_hidden)  # shape:torch.Size([1, 10, 128])# 为啥少了一个维度啊????outputs, (_, _) = self.lstm(input, (hidden_state, cell_state))# _1,2: torch.Size([1, 10, 128])# print("outputs_size:{}".format(#     outputs.size()))  # torch.Size([3, 10, 128])outputs = outputs[-1]  # outputs: torch.Size([10, 128])# 使用outputs的原因:保留最后一层所有隐层的输出,可以尽可能包含所有的局部信息,减少RNN的梯度消失、最后输出的向量与较后的cell有关model = self.W(outputs) + self.b  # shape:torch.Size([10, 26])  batch_size, n_class# print("model:{}".format(model))return modelif __name__ == '__main__':n_step = 3n_hidden = 128char_arr = [c for c in 'abcdefghijklmnopqrstuvwxyz']word_dict = {w: i for i, w in enumerate(char_arr)}number_dict = {i: w for i, w in enumerate(char_arr)}print(number_dict)# print(word_dict)n_class = len(word_dict)seq_data = ['make', 'need', 'coal', 'word', 'love', 'hate', 'live', 'home', 'hash','star']model = TextLSTM()criterion = nn.CrossEntropyLoss()optimizer = optim.Adam(model.parameters(), lr=0.001)input_batch, target_batch = make_batch()  # 原本shape[10, 3], [10]input_batch = torch.FloatTensor(input_batch)  # shape:torch.Size([10, 3, 26])  10条数据,每条数据长度为3,one-hot表示维度为26target_batch = torch.LongTensor(target_batch)for epoch in range(1000):optimizer.zero_grad()output = model(input_batch)loss1 = criterion(output, target_batch)if (epoch + 1) % 100 == 0:print('Epoch:', '%04d' % (epoch + 1), 'cost =','{:.6f}'.format(loss1))loss1.backward()optimizer.step()inputs = input_batch[:4]# inputs = inputs.unsqueeze(0)print('inputs:{}, inputs_size:{}'.format(inputs, np.shape(inputs)))print("mdoel:{}".format(model))  #model就是一个输入26维,输出128维的lstmprint("mdoel(inputs):{}".format(model(inputs)))predict = model(inputs).data.max(1, keepdim=True)[1]# output = model(input_batch)print("predict:{}".format(predict))

pytorch_LSTM:参数相关推荐

  1. 在Dockerfile中设置G1垃圾回收器参数

    在Dockerfile中设置G1垃圾回收器参数 ENV JAVA_OPTS="\ -server \ -XX:SurvivorRatio=8 \ -XX:+DisableExplicitGC ...

  2. Java Calendar.add()方法的使用,参数含义。指定时间差。

    cal.add()方法中的参数含义: 第一个参数如果是1则代表的是对年份操作,2是对月份操作,3是对星期操作,5是对日期操作,11是对小时操作,12是对分钟操作,13是对秒操作,14是对毫秒操作. 第 ...

  3. java带参数的方法笔记_具有Java参数的方法的类声明

    类声明可以包含在Java中具有参数的方法.演示此过程的程序如下: 示例class Message { public void messagePrint(String msg) { System.out ...

  4. Gin 框架学习笔记(02)— 参数自动绑定到结构体

    参数绑定模型可以将请求体自动绑定到结构体中,目前支持绑定的请求类型有 JSON .XML .YAML 和标准表单 form数据 foo=bar&boo=baz 等.换句话说,只要定义好结构体, ...

  5. VS Code 配置调试参数、launch.json 配置文件属性、task.json 变量替换、自动保存并格式化、空格和制表符、函数调用关系、文件搜索和全局搜索、

    1. 生成配置参数 对于大多数的调试都需要在当前项目目录下创建一个 lanch.json 文件,位置是在当前项目目录下生成一个 .vscode 的隐藏文件夹,在里面放置一些配置内容,比如:settin ...

  6. VS Code 安装插件、自定义模板、自定义配置参数、自定义主题、配置参数说明、常用的扩展插件

    1. 下载和官网教程 下载地址:https://code.visualstudio.com/ 官方教程:https://code.visualstudio.com/docs 2. 安装插件 安装扩展插 ...

  7. 浅显易懂 Makefile 入门 (08)— 默认 shell (/bin/sh)、命令回显、make参数(-n 只显示命令但不执行,-s 禁止所有回显)、单行命令、多行命令、并发执行

    1. shell 相关 1.1 默认 shell Makefile 所使用的命令是由 shell 命令行组成,他们是一条一条执行的. 多个命令之间要使用分号隔开,Makefile 中的任何命令都要以 ...

  8. Go 学习笔记(65)— Go 中函数参数是传值还是传引用

    Go 语言中,函数参数传递采用是值传递的方式.所谓"值传递",就是将实际参数在内存中的表示逐位拷贝到形式参数中.对于像整型.数组.结构体这类类型,它们的内存表示就是它们自身的数据内 ...

  9. Go 学习笔记(61)— Go 高阶函数、函数作为一等公民(函数作为输入参数、返回值、变量)的写法

    函数在 Go 语言中属于"一等公民(First-Class Citizen)"拥有"一等公民"待遇的语法元素可以如下使用 可以存储在变量中: 可以作为参数传递给 ...

最新文章

  1. UITableview 多行删除
  2. 第五章5.1 strace
  3. python学好了能干什么-Python语言能做什么,学好能干什么
  4. linux maven .m2文件夹,Maven .m2文件夹创建(示例代码)
  5. linux定时关机命令_win10电脑定时关机命令
  6. WPF捕获未处理的异常
  7. 20191219每日一句
  8. 细数国内无人机的江湖门派
  9. 关于Google Chrome浏览器离线安装包下载方法
  10. 编译go文件时内部包引用受限的问题(use of internal package /PATH/ not allowed)
  11. 驱动程序如何手动卸载与更新
  12. 2021-08-09[RoarCTF2019]黄金6年、从娃娃抓起
  13. matlab的omega0是什么,【铁虫】我喜欢的Omega被别人标记了怎么办(内有嘟嘟)
  14. SpringDataJPA -06- specification的基本使用
  15. Apache Tuscany中文论坛开通: http://groups.google.com/group/tuscany-sca-chinese
  16. OpenGL---GLUT(一)
  17. python 百度搜索结果 浏览器 和终端不一致_python自动爬取百度搜索结果
  18. 韩国渠道接入三星支付(Android 接入 Samsung in app purchase)
  19. 深大uooc大学生心理健康章节答案第八章
  20. 混合云风头正劲 青云QingCloud为何成为领导者?

热门文章

  1. 如何把照片转换成jpg格式呢?
  2. Leetcode 2233. Maximum Product After K Increments
  3. Tomcat 3、4、5、6、7、8、9 各版本下载地址
  4. java数组重置_Java:如何重置数组列表,使其为空
  5. 集成学习(随机森林)
  6. json数组的遍历(获取属性名和属性值)
  7. ​金融风控的护航员——聊聊ERNIE在度小满用户风控的应用
  8. 熬夜整理Java面试笔试题,你还看不懂吗?
  9. matlab 去条带噪声,一种图像条带噪声及坏线消除方法
  10. Endnote 导入enw文件无响应及解决方法