1 导入库函数

import torch
import numpy as np
import matplotlib.pyplot as plt

2 设置超参数

TIME_STEP=10
INPUT_SIZE=1
HIDDEN_SIZE=32
LR=0.02

3  定义RNN

class RNN(torch.nn.Module):def __init__(self):super(RNN,self).__init__()self.rnn=torch.nn.RNN(input_size=INPUT_SIZE,hidden_size=HIDDEN_SIZE,num_layers=1,batch_first=True)
#设置batch_first为True,那么输入数据的维度为(batch_size,time_step,input_size)
#如果不设置这个值,或者设置为False,那么输入数据的维度为(time_step,batch_size,input_size)self.out=torch.nn.Linear(HIDDEN_SIZE,1)#将隐藏层输出转化为需要的输出def forward(self,x,h_state):#因为在RNN中,下一个时间片隐藏层状态的计算需要上一个时间片的隐藏层状态,所以我们要一直传递这个h_state#x (batch_size,time_step,INPUT_SIZE)r_out,h_state=self.rnn(x,h_state)#h_state也要作为RNN的一个输入和一个输出#r_out:(batch_size,time_step,HIDDEN_SIZE)#h_state:(batch_size,time_step,HIDDEN_SIZE)outs=[]for time_step in range(r_out.size()[1]):outs.append(self.out(r_out[:,time_step,:]))#每一个要被self.out运算的元素[batch_size,1,HIDDEN_SIZE]#每个计算完,被append到outs的元素[batch_size,1,1]return torch.stack(outs,dim=1),h_state#返回的第一个元素[batch_size,time_step,1]#torch.stack函数的维度和axis不一样,dim=1的意思是在第一个维度处叠加rnn=RNN()
print(rnn)
'''
RNN((rnn): RNN(1, 32, batch_first=True)(out): Linear(in_features=32, out_features=1, bias=True)
)
'''

或者foward函数也可以这么写:

class RNN(torch.nn.Module):def __init__(self):super(RNN,self).__init__()self.rnn=torch.nn.RNN(input_size=INPUT_SIZE,hidden_size=HIDDEN_SIZE,num_layers=1,batch_first=True)
#设置batch_first为True,那么输入数据的维度为(batch,time_step,input_size)
#如果不设置这个值,或者设置为False,那么输入数据的维度为(time_step,batch,input_size)self.out=torch.nn.Linear(HIDDEN_SIZE,1)def forward(self,x,h_state):r_out,h_state=self.rnn(x,h_state)#在此之前的部分不动r_out=r_out.view(-1,HIDDEN_SIZE)out=self.out(r_out)out=out.view(-1,TIME_STEP,1)return(out,h_state)rnn=RNN()
print(rnn)

4 设置优化器和损失函数

optimizer=torch.optim.Adam(rnn.parameters(),lr=LR)
loss_func=torch.nn.MSELoss()

5 训练RNN

我们这里希望用sin函数预测cos函数

h_state=Nonefor step in range(100):start=step*np.piend=(step+1)*np.pisteps=np.linspace(start,end,TIME_STEP,dtype=np.float32)#这里dtype这一部分一定要加,不然的话会报错,RuntimeError: expected scalar type Double but found Floatx_np=np.sin(steps).reshape(1,TIME_STEP,INPUT_SIZE)y_np=np.cos(steps).reshape(1,TIME_STEP,1)#目标:用sin预测cosx=torch.from_numpy(x_np)y=torch.from_numpy(y_np)prediction,h_state=rnn(x,h_state)#每一组input,都对应了一个h_state和一个predictionh_state=h_state.data#将对应的h_state向后传loss=loss_func(prediction,y)optimizer.zero_grad()#清空上一步的参与更新参数值loss.backward()#误差反向传播,计算参数更新值optimizer.step()#将参数更新值施加到rnn的parameters上if(step % 10==0):plt.plot(steps,prediction.data.numpy().flatten(),'g*')plt.plot(steps,y_np.flatten(),'r-')plt.show()

6 实验结果

一开始

最终

7 整体函数

import torch
import numpy as np
import matplotlib.pyplot as pltTIME_STEP=10
INPUT_SIZE=1
HIDDEN_SIZE=32
LR=0.02class RNN(torch.nn.Module):def __init__(self):super(RNN,self).__init__()self.rnn=torch.nn.RNN(input_size=INPUT_SIZE,hidden_size=HIDDEN_SIZE,num_layers=1,batch_first=True)
#设置batch_first为True,那么输入数据的维度为(batch_size,time_step,input_size)
#如果不设置这个值,或者设置为False,那么输入数据的维度为(time_step,batch_size,input_size)self.out=torch.nn.Linear(HIDDEN_SIZE,1)#将隐藏层输出转化为需要的输出def forward(self,x,h_state):#因为在RNN中,下一个时间片隐藏层状态的计算需要上一个时间片的隐藏层状态,所以我们要一直传递这个h_state#x (batch_size,time_step,INPUT_SIZE)r_out,h_state=self.rnn(x,h_state)#h_state也要作为RNN的一个输入和一个输出#r_out:(batch_size,time_step,HIDDEN_SIZE)#h_state:(batch_size,time_step,HIDDEN_SIZE)outs=[]for time_step in range(r_out.size()[1]):outs.append(self.out(r_out[:,time_step,:]))#每一个要被self.out运算的元素[batch_size,1,HIDDEN_SIZE]#每个计算完,被append到outs的元素[batch_size,1,1]return torch.stack(outs,dim=1),h_state#返回的第一个元素[batch_size,time_step,1]#torch.stack函数的维度和axis不一样,dim=1的意思是在第一个维度处叠加rnn=RNN()
print(rnn)
'''
RNN((rnn): RNN(1, 32, batch_first=True)(out): Linear(in_features=32, out_features=1, bias=True)
)
'''optimizer=torch.optim.Adam(rnn.parameters(),lr=LR)
loss_func=torch.nn.MSELoss()h_state=Nonefor step in range(100):start=step*np.piend=(step+1)*np.pisteps=np.linspace(start,end,TIME_STEP,dtype=np.float32)#这里dtype这一部分一定要加,不然的话会报错,RuntimeError: expected scalar type Double but found Floatx_np=np.sin(steps).reshape(1,TIME_STEP,INPUT_SIZE)y_np=np.cos(steps).reshape(1,TIME_STEP,1)#目标:用sin预测cosx=torch.from_numpy(x_np)y=torch.from_numpy(y_np)prediction,h_state=rnn(x,h_state)#每一组input,都对应了一个h_state和一个predictionh_state=h_state.data#将对应的h_state向后传loss=loss_func(prediction,y)optimizer.zero_grad()#清空上一步的参与更新参数值loss.backward()#误差反向传播,计算参数更新值optimizer.step()#将参数更新值施加到rnn的parameters上if(step % 10==0):plt.plot(steps,prediction.data.numpy().flatten(),'g*')plt.plot(steps,y_np.flatten(),'r-')plt.show()

用pytorch实现简易RNN相关推荐

  1. (pytorch-深度学习)使用pytorch框架nn.RNN实现循环神经网络

    使用pytorch框架nn.RNN实现循环神经网络 首先,读取周杰伦专辑歌词数据集. import time import math import numpy as np import torch f ...

  2. pytorch中如何处理RNN输入变长序列padding

    一.为什么RNN需要处理变长输入 假设我们有情感分析的例子,对每句话进行一个感情级别的分类,主体流程大概是下图所示: 思路比较简单,但是当我们进行batch个训练数据一起计算的时候,我们会遇到多个训练 ...

  3. Pytorch实战——基于RNN的新闻分类

    目录 一.项目介绍 二.基于RNN的新闻分类 Step1 加载数据集 Step2 分词和构建词汇表 Step3 构建数据加载器 dataloader Step4 定义神经网络模型 Step5 定义模型 ...

  4. PyTorch入门笔记——RNN写诗(含藏头诗)程序代码学习笔记

    目录 注意事项 一.数据介绍 二.opt对象 三.data.py 四.☆model.py 五.utils.py封装可视化操作,略 六.main.py 了解`torch.utils.data.DataL ...

  5. 用 pytorch 实现 一个rnn

    原文链接 import torch torch.__version__class RNN(object):def __init__(self,input_size,hidden_size):super ...

  6. pytorch torch.nn.RNN

    应用 rnn = nn.RNN(10, 20, 2) input = torch.randn(5, 3, 10) h0 = torch.randn(2, 3, 20) output, hn = rnn ...

  7. pytorch实现简易分类模型

    1 导入库 import torch import matplotlib.pyplot as plt import torch.nn.functional as F 2 数据处理 n_data=tor ...

  8. pytorch笔记——简易回归问题

    1 数据集部分 #导入库 import torch import matplotlib.pyplot as plt#建议数据集 x=torch.linspace(-1,1,100) x=x.view( ...

  9. 对自注意力(self-attention)的理解以及基于pytorch的简易示例

    简介 自注意力(self-attention):一个seq2seq的映射运算,具体而言,也就是将输入向量通过映射(或者说某种函数运算)输出对应的结果. 向量的维度都为. 对于每个输出,就是用自注意力运 ...

最新文章

  1. 33关Python游戏,测试你的爬虫能力到底及格不?
  2. UA MATH523A 实分析3 积分理论例题 证明函数列L1收敛的一个题目
  3. 使用docker安装部署Spark集群来训练CNN(含Python实例)
  4. Scala 读取文件
  5. 【已解决】Error occurred during loading data. Trying to use cache server_Python系列学习笔记
  6. Vijos p1484 ISBN号码
  7. QT中的滚动条QScrollArea
  8. maven 项目上传私服pom配置
  9. Linux安装RabbitMQ及问题
  10. Caused by: java.lang.IllegalArgumentException: Result Maps collection does not contain value for...
  11. 转载 SpringMVC详解(三)------基于注解的入门实例
  12. 《剑指offer》面试题4——替换空格 C++编程
  13. linux内核类型lagency,使用u盘安装linux(manjaro)时Grub报错
  14. (最新版 易卷)自动阅卷系统|自动阅卷机|网上阅卷系统
  15. 大气压力换算公式_大气压强计算新方法
  16. 读后:水浒的水有多深
  17. 数学笔记25——弧长和曲面面积
  18. 奔富bef407价格_Penfolds Bin 707 Cabernet Sauvignon, South Australia, Australia
  19. 兔子与狐狸c语言,狐狸和兔子
  20. 6月 CSDN 创作者之夜:获奖名单公布

热门文章

  1. codeforces 665B Shopping
  2. SQL Tips:兼顾检索速度和精确性
  3. MSN工具条不兼容IE7
  4. Charpter 8:Declarative Middleware Using AOP:expert one-on-one J2EE Development without EJB.(读后感)...
  5. stm32f401 i2s 时序图
  6. Linux 802.11 Driver Developer’s Guide
  7. html5应用测试方法,详解html5的video标签测试应用
  8. linux如何自动清buff,centos7
  9. python django部署docker_如何Docker化Python Django应用程序
  10. python生成器杨辉三角_python 生成器生成杨辉三角的方法(必看)