Pytorch+LSTM+Encoder+Decoder实现Seq2Seq模型
# !/usr/bin/env Python3 # -*- coding: utf-8 -*- # @version: v1.0 # @Author : Meng Li # @contact: 925762221@qq.com # @FILE : torch_seq2seq.py # @Time : 2022/6/8 11:11 # @Software : PyCharm # @site: # @Description : 将Seq2Seq网络采用编码器和解码器两个类进行融合 import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim import torchsummary from torch.utils.data import Dataset, DataLoader import numpy as np import osdevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')class my_dataset(Dataset):def __init__(self, enc_input, dec_input, dec_output):super().__init__()self.enc_input = enc_inputself.dec_input = dec_inputself.dec_output = dec_outputdef __getitem__(self, index):return self.enc_input[index], self.dec_input[index], self.dec_output[index]def __len__(self):return self.enc_input.size(0)class Encoder(nn.Module):def __init__(self, in_features, hidden_size):super().__init__()self.in_features = in_featuresself.hidden_size = hidden_sizeself.encoder = nn.LSTM(input_size=in_features, hidden_size=hidden_size, dropout=0.5, num_layers=1) # encoderdef forward(self, enc_input):seq_len, batch_size, embedding_size = enc_input.size()h_0 = torch.rand(1, batch_size, self.hidden_size)c_0 = torch.rand(1, batch_size, self.hidden_size)# en_ht:[num_layers * num_directions,Batch_size,hidden_size]encode_output, (encode_ht, decode_ht) = self.encoder(enc_input, (h_0, c_0))return encode_output, (encode_ht, decode_ht)class Decoder(nn.Module):def __init__(self, in_features, hidden_size):super().__init__()self.in_features = in_featuresself.hidden_size = hidden_sizeself.crition = nn.CrossEntropyLoss()self.fc = nn.Linear(hidden_size, in_features)self.decoder = nn.LSTM(input_size=in_features, hidden_size=hidden_size, dropout=0.5, num_layers=1) # encoderdef forward(self, enc_output, dec_input):(h0, c0) = enc_output# en_ht:[num_layers * num_directions,Batch_size,hidden_size]de_output, (_, _) = self.decoder(dec_input, (h0, c0))return de_outputclass Seq2seq(nn.Module):def __init__(self, encoder, decoder, in_features, hidden_size):super().__init__()self.encoder = encoderself.decoder = decoderself.in_features = in_featuresself.hidden_size = hidden_sizeself.fc = nn.Linear(hidden_size, in_features)self.crition = nn.CrossEntropyLoss()def forward(self, enc_input, dec_input, dec_output):enc_input = enc_input.permute(1, 0, 2) # [seq_len,Batch_size,embedding_size]dec_input = dec_input.permute(1, 0, 2) # [seq_len,Batch_size,embedding_size]# output:[seq_len,Batch_size,hidden_size]_, (ht, ct) = self.encoder(enc_input) # en_ht:[num_layers * num_directions,Batch_size,hidden_size]de_output = self.decoder((ht, ct), dec_input) # de_output:[seq_len,Batch_size,in_features]output = self.fc(de_output)output = output.permute(1, 0, 2)loss = 0for i in range(len(output)): # 对seq的每一个输出进行二分类损失计算loss += self.crition(output[i], dec_output[i])return output, lossdef make_data(seq_data):enc_input_all, dec_input_all, dec_output_all = [], [], []vocab = [i for i in "SE?abcdefghijklmnopqrstuvwxyz上下人低国女孩王男白色高黑"]word2idx = {j: i for i, j in enumerate(vocab)}V = np.max([len(j) for i in seq_data for j in i]) # 求最长元素的长度for seq in seq_data:for i in range(2):seq[i] = seq[i] + '?' * (V - len(seq[i])) # 'man??', 'women'enc_input = [word2idx[n] for n in (seq[0] + 'E')]dec_input = [word2idx[i] for i in [i for i in len(enc_input) * '?']]dec_output = [word2idx[n] for n in (seq[1] + 'E')]enc_input_all.append(np.eye(len(vocab))[enc_input])dec_input_all.append(np.eye(len(vocab))[dec_input])dec_output_all.append(dec_output) # not one-hot# make tensorreturn torch.Tensor(enc_input_all), torch.Tensor(dec_input_all), torch.LongTensor(dec_output_all)def translate(word):vocab = [i for i in "SE?abcdefghijklmnopqrstuvwxyz上下人低国女孩王男白色高黑"]idx2word = {i: j for i, j in enumerate(vocab)}V = 5x, y, z = make_data([[word, "?" * V]])if not os.path.exists("translate.pt"):train()net = torch.load("translate.pt")pre, loss = net(x, y, z)pre = torch.argmax(pre, 2)[0]pre_word = [idx2word[i] for i in pre.numpy()]pre_word = "".join([i.replace("?", "") for i in pre_word])print(word, "-> ", pre_word[:pre_word.index('E')])def train():vocab = [i for i in "SE?abcdefghijklmnopqrstuvwxyz上下人低国女孩王男白色高黑"]word2idx = {j: i for i, j in enumerate(vocab)}idx2word = {i: j for i, j in enumerate(vocab)}seq_data = [['man', '男人'], ['black', '黑色'], ['king', '国王'], ['girl', '女孩'], ['up', '上'],['high', '高'], ['women', '女人'], ['white', '白色'], ['boy', '男孩'], ['down', '下'], ['low', '低'],['queen', '女王']]enc_input, dec_input, dec_output = make_data(seq_data)batch_size = 3in_features = len(vocab)hidden_size = 128train_data = my_dataset(enc_input, dec_input, dec_output)train_iter = DataLoader(train_data, batch_size, shuffle=True)encoder = Encoder(in_features, hidden_size)decoder = Decoder(in_features, hidden_size)net = Seq2seq(encoder, decoder, in_features, hidden_size)learning_rate = 0.001optimizer = optim.Adam(net.parameters(), lr=learning_rate)loss = 0for i in range(1000):for en_input, de_input, de_output in train_iter:output, loss = net(en_input, de_input, de_output)pre = torch.argmax(output, 2)optimizer.zero_grad()loss.backward()optimizer.step()if i % 100 == 0:print("step {0} loss {1}".format(i, loss))torch.save(net, "translate.pt")if __name__ == '__main__':before_test = ['man', 'black', 'king', 'girl', 'up', 'high', 'women', 'white', 'boy', 'down', 'low', 'queen','mman', 'woman'][translate(i) for i in before_test]# train()
仍然先上代码,接上一篇文章,这里将Seq2Seq模型个构建采用Encoder类和Decoder类融合起来
主要是为了后面的Attention作铺垫
Pytorch+LSTM+Encoder+Decoder实现Seq2Seq模型相关推荐
- 从Encoder到Decoder实现Seq2Seq模型
首发于机器不学习 关注专栏 写文章 从Encoder到Decoder实现Seq2Seq模型 天雨粟 模型师傅 / 果粉 关注他 300 人赞同了该文章 更新:感谢@Gang He指出的代码错误.g ...
- Seq2Seq模型及Attention机制
Seq2Seq模型及Attention机制 Seq2Seq模型 Encoder部分 Decoder部分 seq2seq模型举例 LSTM简单介绍 基于CNN的seq2seq Transformer A ...
- PyTorch学习(7)-Seq2Seq与 Attention
Seq2Seq与 Attention import os import sys import math from collections import Counter import numpy as ...
- 【PyTorch】11 聊天机器人实战——Cornell Movie-Dialogs Corpus电影剧本数据集处理、利用Global attention实现Seq2Seq模型
聊天机器人教程 1. 下载数据文件 2. 加载和预处理数据 2.1 创建格式化数据文件 2.2 加载和清洗数据 3.为模型准备数据 4.定义模型 4.1 Seq2Seq模型 4.2 编码器 4.3 解 ...
- Seq2Seq模型PyTorch版本
Seq2Seq模型介绍以及Pytorch版本代码详解 一.Seq2Seq模型的概述 Seq2Seq是一种循环神经网络的变种,是一种端到端的模型,包括 Encoder编码器和 Decoder解码器部分, ...
- Seq2Seq模型实现(Decoder部分)
0.引言: 承接上一篇,现在继续对于seq2seq模型进行讲解,decoder部分是和encoder部分对应的,层数.隐藏层.单元数都要对应. 1.LSTM Seq2Seq Decoder Decod ...
- 深度学习的seq2seq模型——本质是LSTM,训练过程是使得所有样本的p(y1,...,yT‘|x1,...,xT)概率之和最大...
from:https://baijiahao.baidu.com/s?id=1584177164196579663&wfr=spider&for=pc seq2seq模型是以编码(En ...
- Pytorch+LSTM+Attention 实现 Seq2Seq
# !/usr/bin/env Python3 # -*- coding: utf-8 -*- # @version: v1.0 # @Author : Meng Li # @contact: 925 ...
- encoder decoder 模型理解
encoder decoder 模型是比较难理解的,理解这个模型需要清楚lstm 的整个源码细节,坦率的说这个模型我看了近十天,不敢说完全明白. 我把细胞的有丝分裂的图片放在开头,我的直觉细胞的有丝分 ...
- pytorch seq2seq模型中加入teacher_forcing机制
在循环内加的teacher forcing机制,这种为目标确定的时候,可以这样加. 目标不确定,需要在循环外加. decoder.py 中的修改 """ 实现解码器 &q ...
最新文章
- 易语言静态连接器提取_易语言静态编译链接器切换工具
- SpringMVC 如何实现将消息的websocket
- archlinux 安装 Windows 字体
- 如何避免jquery库和其它库的冲突
- Jmeter + Grafana + InfluxDB 性能测试监控
- h3c交换机 查看二层交换机端口ip_【分享】项目中如何选到称心如意的交换机?...
- ghost后自动修改IP和计算机名的VBS脚本
- 走好达叔!每年“癌症”新增400万例,数据分析揭示“癌症”到底有多可怕
- C#实现把科学计数法(E)转化为正常数字值 (转)
- 在input标签里只能输入数字
- PyQT5 QtWidgets 设置单元格不可编辑/可编辑 恢复单元格默认设置
- delphi 连接网口打印机 发送指令打印二维码
- 树莓派交叉编译USB转网卡驱动_incomplete
- 铃木雅臣晶体管电路设计学习笔记1
- Intent跳转页面大全
- 13-4Happy Mid-Autumn Festival
- 金钱和私有制哪个才是万恶之源?
- Centos逻辑卷扩容、合并
- Python--变量
- 国家电力项目思路总结
热门文章
- 关于错误“未能加载文件或程序集”的错误的若干处理办法——对GAC的简单应用
- html期末大作业~自制崩坏3网站(附原码)
- 企业为什么选择软件定制开发?
- 陈省身文集51——闭黎曼流形高斯-博内公式的一个简单的内蕴证明
- 资源、角色、用户、岗位的关系(工作中用到的)
- 谷哥学术2022年2月资源分享下载列表 15/20
- HyperLedger Fabric 查询机制
- RFID固定资产盘点的解决方案
- 17_AOP入门准备_Salay案例(利用动态代理)
- 《P2SGrad Refined Gradients for Optimizing Deep Face Models》论文阅读