```python
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# @version: v1.0
# @Author   : Meng Li
# @contact: 925762221@qq.com
# @FILE     : torch_seq2seq.py
# @Time     : 2022/6/8 11:11
# @Software : PyCharm
# @site:
# @Description : Build the Seq2Seq network from two separate classes, an Encoder and a Decoder
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class my_dataset(Dataset):
    def __init__(self, enc_input, dec_input, dec_output):
        super().__init__()
        self.enc_input = enc_input
        self.dec_input = dec_input
        self.dec_output = dec_output

    def __getitem__(self, index):
        return self.enc_input[index], self.dec_input[index], self.dec_output[index]

    def __len__(self):
        return self.enc_input.size(0)


class Encoder(nn.Module):
    def __init__(self, in_features, hidden_size):
        super().__init__()
        self.in_features = in_features
        self.hidden_size = hidden_size
        self.encoder = nn.LSTM(input_size=in_features, hidden_size=hidden_size, num_layers=1)

    def forward(self, enc_input):
        # enc_input: [seq_len, batch_size, embedding_size]
        seq_len, batch_size, embedding_size = enc_input.size()
        h_0 = torch.rand(1, batch_size, self.hidden_size)
        c_0 = torch.rand(1, batch_size, self.hidden_size)
        # encode_ht / encode_ct: [num_layers * num_directions, batch_size, hidden_size]
        encode_output, (encode_ht, encode_ct) = self.encoder(enc_input, (h_0, c_0))
        return encode_output, (encode_ht, encode_ct)


class Decoder(nn.Module):
    def __init__(self, in_features, hidden_size):
        super().__init__()
        self.in_features = in_features
        self.hidden_size = hidden_size
        self.decoder = nn.LSTM(input_size=in_features, hidden_size=hidden_size, num_layers=1)

    def forward(self, enc_state, dec_input):
        # enc_state: final (h, c) of the encoder, used to initialise the decoder
        (h_0, c_0) = enc_state
        de_output, (_, _) = self.decoder(dec_input, (h_0, c_0))
        return de_output


class Seq2seq(nn.Module):
    def __init__(self, encoder, decoder, in_features, hidden_size):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.in_features = in_features
        self.hidden_size = hidden_size
        self.fc = nn.Linear(hidden_size, in_features)
        self.crition = nn.CrossEntropyLoss()

    def forward(self, enc_input, dec_input, dec_output):
        enc_input = enc_input.permute(1, 0, 2)  # [seq_len, batch_size, embedding_size]
        dec_input = dec_input.permute(1, 0, 2)  # [seq_len, batch_size, embedding_size]
        # ht, ct: [num_layers * num_directions, batch_size, hidden_size]
        _, (ht, ct) = self.encoder(enc_input)
        de_output = self.decoder((ht, ct), dec_input)  # [seq_len, batch_size, hidden_size]
        output = self.fc(de_output)                    # [seq_len, batch_size, in_features]
        output = output.permute(1, 0, 2)               # [batch_size, seq_len, in_features]
        loss = 0
        for i in range(len(output)):  # accumulate cross-entropy loss sample by sample over the batch
            loss += self.crition(output[i], dec_output[i])
        return output, loss


def make_data(seq_data):
    enc_input_all, dec_input_all, dec_output_all = [], [], []
    vocab = [i for i in "SE?abcdefghijklmnopqrstuvwxyz上下人低国女孩王男白色高黑"]
    word2idx = {j: i for i, j in enumerate(vocab)}
    V = np.max([len(j) for i in seq_data for j in i])  # length of the longest token
    for seq in seq_data:
        for i in range(2):
            seq[i] = seq[i] + '?' * (V - len(seq[i]))  # pad, e.g. 'man??', 'women'
        enc_input = [word2idx[n] for n in (seq[0] + 'E')]
        dec_input = [word2idx['?']] * len(enc_input)   # decoder input is placeholder tokens only
        dec_output = [word2idx[n] for n in (seq[1] + 'E')]
        enc_input_all.append(np.eye(len(vocab))[enc_input])   # one-hot
        dec_input_all.append(np.eye(len(vocab))[dec_input])   # one-hot
        dec_output_all.append(dec_output)                     # class indices, not one-hot
    # make tensors
    return (torch.Tensor(np.array(enc_input_all)), torch.Tensor(np.array(dec_input_all)),
            torch.LongTensor(dec_output_all))


def translate(word):
    vocab = [i for i in "SE?abcdefghijklmnopqrstuvwxyz上下人低国女孩王男白色高黑"]
    idx2word = {i: j for i, j in enumerate(vocab)}
    V = 5
    x, y, z = make_data([[word, "?" * V]])
    if not os.path.exists("translate.pt"):
        train()
    net = torch.load("translate.pt")
    pre, loss = net(x, y, z)
    pre = torch.argmax(pre, 2)[0]
    pre_word = [idx2word[i] for i in pre.numpy()]
    pre_word = "".join([i.replace("?", "") for i in pre_word])
    print(word, "->  ", pre_word[:pre_word.index('E')])


def train():
    vocab = [i for i in "SE?abcdefghijklmnopqrstuvwxyz上下人低国女孩王男白色高黑"]
    seq_data = [['man', '男人'], ['black', '黑色'], ['king', '国王'], ['girl', '女孩'], ['up', '上'],
                ['high', '高'], ['women', '女人'], ['white', '白色'], ['boy', '男孩'], ['down', '下'],
                ['low', '低'], ['queen', '女王']]
    enc_input, dec_input, dec_output = make_data(seq_data)
    batch_size = 3
    in_features = len(vocab)
    hidden_size = 128
    train_data = my_dataset(enc_input, dec_input, dec_output)
    train_iter = DataLoader(train_data, batch_size, shuffle=True)
    encoder = Encoder(in_features, hidden_size)
    decoder = Decoder(in_features, hidden_size)
    net = Seq2seq(encoder, decoder, in_features, hidden_size)
    learning_rate = 0.001
    optimizer = optim.Adam(net.parameters(), lr=learning_rate)
    loss = 0
    for i in range(1000):
        for en_input, de_input, de_output in train_iter:
            output, loss = net(en_input, de_input, de_output)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        if i % 100 == 0:
            print("step {0} loss {1}".format(i, loss))
    torch.save(net, "translate.pt")


if __name__ == '__main__':
    before_test = ['man', 'black', 'king', 'girl', 'up', 'high', 'women', 'white', 'boy', 'down', 'low',
                   'queen', 'mman', 'woman']
    [translate(i) for i in before_test]
    # train()
```

As in the previous article, the code comes first. Here the Seq2Seq model is assembled from two separate classes, an Encoder and a Decoder, which are then combined in the Seq2seq wrapper.
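Before training, a quick shape check helps confirm the wiring. The snippet below is a sketch of my own, not part of the original script: it reuses the `Encoder`, `Decoder` and `Seq2seq` classes defined above with arbitrary sizes and random dummy data (42 is the size of the vocabulary used in the code).

```python
import torch

vocab_size, hidden_size, batch, seq_len = 42, 128, 2, 6   # 42 = len(vocab) in the code above

enc = Encoder(vocab_size, hidden_size)
dec = Decoder(vocab_size, hidden_size)
model = Seq2seq(enc, dec, vocab_size, hidden_size)

# random one-hot inputs [batch, seq_len, vocab_size] and integer targets [batch, seq_len]
enc_x = torch.eye(vocab_size)[torch.randint(0, vocab_size, (batch, seq_len))]
dec_x = torch.eye(vocab_size)[torch.randint(0, vocab_size, (batch, seq_len))]
target = torch.randint(0, vocab_size, (batch, seq_len))

out, loss = model(enc_x, dec_x, target)
print(out.shape, loss.item())   # expected: torch.Size([2, 6, 42]) plus a scalar loss
```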

This is mainly groundwork for the attention mechanism that comes later.
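As a rough preview of that direction, here is a minimal sketch of my own (not the code from the follow-up article) showing how the `Decoder` could be extended with Luong-style dot-product attention; the class name `AttnDecoder` and all details are assumptions for illustration only.

```python
import torch
import torch.nn as nn


class AttnDecoder(nn.Module):
    """Hypothetical sketch: a decoder that attends over all encoder outputs at every step."""

    def __init__(self, in_features, hidden_size):
        super().__init__()
        self.decoder = nn.LSTM(input_size=in_features, hidden_size=hidden_size, num_layers=1)
        self.fc = nn.Linear(hidden_size * 2, in_features)  # maps [decoder state ; context] to vocab logits

    def forward(self, enc_output, enc_state, dec_input):
        # enc_output: [src_len, batch, hidden], enc_state: (h, c), dec_input: [tgt_len, batch, in_features]
        dec_output, _ = self.decoder(dec_input, enc_state)              # [tgt_len, batch, hidden]
        scores = torch.einsum('tbh,sbh->tbs', dec_output, enc_output)   # dot-product alignment scores
        weights = torch.softmax(scores, dim=-1)                         # attention weights over source positions
        context = torch.einsum('tbs,sbh->tbh', weights, enc_output)     # weighted sum of encoder outputs
        return self.fc(torch.cat([dec_output, context], dim=-1))        # [tgt_len, batch, in_features]
```

Compared with the plain `Decoder` above, the only change is that `forward` also receives the full `encode_output` from the `Encoder` and mixes a context vector into each decoding step via the attention weights.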
