上文我们利用pytorch构建了BP神经网络,LeNet,这次我们利用LSTM网络实现对MNIST数据集的分类,具体的数据获取方法本文不详细介绍,这里只要讲解搭建LSTM网络的方法以及参数设置。

这里我们只用一层LSTM网络+全连接层实现对模型的构建。

# 输入为图片 (batch, seq_len, feature) 照片的每一行看作一个特征,一个特征的长度为32
INPUT_SIZE=32
HIDDEN_SIZE=10
LAYERS=2
DROP_RATE=0.2
TIME_STEP = 32class LSTM(nn.Module):def __init__(self):super(LSTM, self).__init__()# 这里构建LSTM 还可以构建RNN、GRU等方法类似self.rnn = nn.LSTM(input_size=INPUT_SIZE,hidden_size=HIDDEN_SIZE,num_layers=LAYERS,dropout=DROP_RATE,batch_first=True  # 如果为True,输入输出数据格式是(batch, seq_len, feature)# 为False,输入输出数据格式是(seq_len, batch, feature),)self.hidden_out = nn.Linear(320, 10) #拼接隐藏层self.sig = nn.Sigmoid() #分类需要利用Sigmod激活函数def forward(self, x):r_out, (h_s, h_c)  = self.rnn(x)out = r_out.reshape(-1,320) # 这里隐藏层设置为10,故得到结果[-1,32,10]展开out = self.hidden_out(out) # 全连接层进行分类out = self.sig(out)return out

每一层的输出:

torch.Size([20, 32, 32])
torch.Size([20, 32, 10])
torch.Size([20, 320])
torch.Size([20, 10])

2、利用MNIST数据集训练模型

利用数据集训练模型,其余内容与之前内容相同,这里不在赘述,直接上代码。

import torch
from torchvision import datasets, transforms
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import numpy as npclass Config:batch_size = 128epoch = 10alpha = 1e-3print_per_step = 100  # 控制输出device = torch.device('cuda:0')INPUT_SIZE=32
HIDDEN_SIZE=10
LAYERS=2
DROP_RATE=0.2
TIME_STEP = 32class LSTM(nn.Module):def __init__(self):super(LSTM, self).__init__()self.rnn = nn.LSTM(input_size=INPUT_SIZE,hidden_size=HIDDEN_SIZE,num_layers=LAYERS,dropout=DROP_RATE,batch_first=True  # 如果为True,输入输出数据格式是(batch, seq_len, feature)# 为False,输入输出数据格式是(seq_len, batch, feature),)self.hidden_out = nn.Linear(320, 10)self.sig = nn.Sigmoid()def forward(self, x):r_out, (h_s, h_c)  = self.rnn(x)out = r_out.reshape(-1,320)out = self.hidden_out(out)out = self.sig(out)return outclass TrainProcess:def __init__(self):self.train, self.test = self.load_data()self.net = LSTM().to(device)self.criterion = nn.CrossEntropyLoss()  # 定义损失函数self.optimizer = optim.Adam(self.net.parameters(), lr=Config.alpha)@staticmethoddef load_data():train_data = datasets.MNIST(root='./data/',train=True,transform=transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor()]),download=True)test_data = datasets.MNIST(root='./data/',train=False,transform=transforms.Compose([transforms.Resize((32, 32)), transforms.ToTensor()]))# 返回一个数据迭代器# shuffle:是否打乱顺序train_loader = torch.utils.data.DataLoader(dataset=train_data,batch_size=Config.batch_size,shuffle=True)test_loader = torch.utils.data.DataLoader(dataset=test_data,batch_size=Config.batch_size,shuffle=False)return train_loader, test_loaderdef train_step(self):print("Training & Evaluating based on LSTM......")file = 'result/train_mnist.txt'fp = open(file,'w',encoding='utf-8')fp.write('epoch\tbatch\tloss\taccuracy\n')for epoch in range(Config.epoch):print("Epoch {:3}.".format(epoch + 1))for batch_idx,(data,label) in enumerate(self.train):data, label = Variable(data.cuda()), Variable(label.cuda())data = data.squeeze(dim=1)self.optimizer.zero_grad()outputs = self.net(data)loss =self.criterion(outputs, label)loss.backward()self.optimizer.step()# 每100次打印一次结果if batch_idx % Config.print_per_step == 0:_, predicted = torch.max(outputs, 1)correct = 0for _ in predicted == label:if _:correct += 1accuracy = correct / Config.batch_sizemsg = "Batch: {:5}, Loss: {:6.2f}, Accuracy: {:8.2%}."print(msg.format(batch_idx, loss, accuracy))fp.write('{}\t{}\t{}\t{}\n'.format(epoch,batch_idx,loss,accuracy))fp.close()test_loss = 0.test_correct = 0for data, label in self.test:data, label = Variable(data.cuda()), Variable(label.cuda())data = data.squeeze(dim=1)outputs = self.net(data)loss = self.criterion(outputs, label)test_loss += loss * Config.batch_size_, predicted = torch.max(outputs, 1)correct = 0for _ in predicted == label:if _:correct += 1test_correct += correctaccuracy = test_correct / len(self.test.dataset)loss = test_loss / len(self.test.dataset)print("Test Loss: {:5.2f}, Accuracy: {:6.2%}".format(loss, accuracy))torch.save(self.net.state_dict(), './result/raw_train_mnist_model.pth')if __name__ == "__main__":p = TrainProcess()p.train_step()

有一点需要注意的是,我们模型的输入是[-1,32,32],但是数据集的shape是[-1,1,32,32],因此这里利用pytorch的方法压缩维度:

# 输入data.shape=[1,32,32] -> [32,32]
data = data.squeeze(dim=0)# 输入data.shape=[32,32] -> [1,32,32]
data = data.unsqueeze(dim=0)

基于pytorch的LSTM模型构建相关推荐

  1. 基于PyTorch的LSTM模型的IMBD情感分类遇到的问题

    今天想学LSTM的情感分类,结果碰到了一系列问题,耽误了很多时间.特此记录! 一.项目来源 lesson53-情感分类实战 B站视频 二.碰到的问题 1.报错AttributeError: modul ...

  2. 【金融】【pytorch】使用深度学习预测期货收盘价涨跌——LSTM模型构建与训练

    [金融][pytorch]使用深度学习预测期货收盘价涨跌--LSTM模型构建与训练 LSTM 创建模型 模型训练 查看指标 LSTM 创建模型 指标函数参考<如何用keras/tf/pytorc ...

  3. 【图像分类】基于PyTorch搭建LSTM实现MNIST手写数字体识别(双向LSTM,附完整代码和数据集)

    写在前面: 首先感谢兄弟们的关注和订阅,让我有创作的动力,在创作过程我会尽最大能力,保证作品的质量,如果有问题,可以私信我,让我们携手共进,共创辉煌. 在https://blog.csdn.net/A ...

  4. R语言基于自定义函数构建xgboost模型并使用LIME解释器进行模型预测结果解释:基于训练数据以及模型构建LIME解释器解释一个iris数据样本的预测结果、LIME解释器进行模型预测结果解释并可视化

    R语言基于自定义函数构建xgboost模型并使用LIME解释器进行模型预测结果解释:基于训练数据以及模型构建LIME解释器解释一个iris数据样本的预测结果.LIME解释器进行模型预测结果解释并可视化 ...

  5. R语言基于自定义函数构建xgboost模型并使用LIME解释器进行模型预测结果解释:基于训练数据以及模型构建LIME解释器解释多个iris数据样本的预测结果、使用LIME解释器进行模型预测结果解释

    R语言基于自定义函数构建xgboost模型并使用LIME解释器进行模型预测结果解释:基于训练数据以及模型构建LIME解释器解释多个iris数据样本的预测结果.使用LIME解释器进行模型预测结果解释并可 ...

  6. R使用LSTM模型构建深度学习文本分类模型(Quora Insincere Questions Classification)

    R使用LSTM模型构建深度学习文本分类模型(Quora Insincere Questions Classification) Long Short Term 网络-- 一般就叫做 LSTM --是一 ...

  7. RDKit | 基于scikit-learn将pytorch用于QSAR模型构建

    将PyTorch用于深度学习框架.PyTorch非常灵活,并且有很多文章将其用于实现. 利用scikit-learn一样调用fit来训练pytorch模型,使用skorch可以使训练过程变得简单. s ...

  8. 基于神经网络算法LSTM模型对股票指数进行预测

    资源下载地址:https://download.csdn.net/download/sheziqiong/86813208 资源下载地址:https://download.csdn.net/downl ...

  9. 为基于树的机器学习模型构建更好的建模数据集的10个小技巧!

    https://www.toutiao.com/a6680019995100971531/ 为了使模型更准确 - 只需对所有分类特征进行独热编码并将所有缺失值归零都可能是不够的. 假设有一个业务问题可 ...

最新文章

  1. 今晚直播 | 深入浅出理解A3C强化学习
  2. Java_bytecode
  3. C++中string查找和取子串和整形转化
  4. 为什么姜黄素+胡椒碱会让姜黄素吸收率增加2000%以上
  5. python 什么是原类_Python 什么是元类(metaclasses)?
  6. php网页脚本代码大全,PHP编写脚本代码的详细教程
  7. java和c++的区别大吗_莫桑钻和钻石外观区别大吗 莫桑钻和真的钻石有什么区别...
  8. 阿里云工程师用机器学习破解雾霾成因
  9. ios 类别(category)
  10. 计算机游戏软件制作,游戏制作软件,制作游戏的软件
  11. java wsimport 调用_java使用wsimport调用wcf接口
  12. XML解析——Java中XML的四种解析方式
  13. pandas按照多列排序-ascending
  14. 组态王与Modbus协议的地址对应规则
  15. win10软件拒绝访问删不掉_文件拒绝访问,详细教您win10文件访问被拒绝怎么解决...
  16. bat批处理删除日志文件
  17. 【C/C++】输入一个整数的二目运算式的字符串,如100+20,332-19,200*2333,44/33二目运算取”加减乘除“中的一种输出运算式的整数结果值
  18. Xilinx RFSOC GEN1 ADC和DAC简单测试
  19. PID闭环控制算法解析(最透彻)
  20. torch.addcdiv 和 torch.tensor.addcdiv_

热门文章

  1. 神微计算机世界排名,2018US News意大利大学学科排名
  2. 学习笔记:在WIN11及UBUNTU平台下的基于Tkinter、pydub、pyaudio的音乐播放器
  3. 我的2005,我的梦
  4. 【平面设计】Pro/E3.0 软件安装教程
  5. 云栖大会三大事,立成电商不谋而合深耕人才计划
  6. ​美国科技记者秘访Uber雇员:一代独角兽的衰落,谁来买单?
  7. 2012年度十大杰出IT博客之 罗升阳
  8. 【入门指导】C语言难吗?最难啃的三块硬骨头
  9. lgg7深度详细参数_混音笔记(十一)——混响器(2)混响器的参数
  10. 计算机用户使用年龄表,笔记本电脑用户年龄以21-30岁为主_联想 Y50-70AM-ISE_调研中心专项研究-中关村在线...