文章目录

  • 精简代码
  • 模型可视化
    • 训练器
    • 编码器(从属于训练器)
    • 解码器
  • 训练结果展示
  • 附录

精简代码

本文参考官网栗子,汉化注释并精简100行代码(直接复制可用)

"""
《长恨歌》——白居易
汉皇重色思倾国,御宇多年求不得。
杨家有女初长成,养在深闺人未识。
天生丽质难自弃,一朝选在君王侧。
回眸一笑百媚生,六宫粉黛无颜色。
春寒赐浴华清池,温泉水滑洗凝脂。
侍儿扶起娇无力,始是新承恩泽时。
云鬓花颜金步摇,芙蓉帐暖度春宵。
春宵苦短日高起,从此君王不早朝。
承欢侍宴无闲暇,春从春游夜专夜。
后宫佳丽三千人,三千宠爱在一身。
金屋妆成娇侍夜,玉楼宴罢醉和春。
姊妹弟兄皆列土,可怜光彩生门户。
遂令天下父母心,不重生男重生女。
骊宫高处入青云,仙乐风飘处处闻。
缓歌慢舞凝丝竹,尽日君王看不足。
渔阳鼙鼓动地来,惊破霓裳羽衣曲。
九重城阙烟尘生,千乘万骑西南行。
翠华摇摇行复止,西出都门百余里。
六军不发无奈何,宛转蛾眉马前死。
花钿委地无人收,翠翘金雀玉搔头。
君王掩面救不得,回看血泪相和流。
黄埃散漫风萧索,云栈萦纡登剑阁。
峨嵋山下少人行,旌旗无光日色薄。
蜀江水碧蜀山青,圣主朝朝暮暮情。
行宫见月伤心色,夜雨闻铃肠断声。
天旋地转回龙驭,到此踌躇不能去。
马嵬坡下泥土中,不见玉颜空死处。
君臣相顾尽沾衣,东望都门信马归。
归来池苑皆依旧,太液芙蓉未央柳。
芙蓉如面柳如眉,对此如何不泪垂。
春风桃李花开日,秋雨梧桐叶落时。
西宫南内多秋草,落叶满阶红不扫。
梨园弟子白发新,椒房阿监青娥老。
夕殿萤飞思悄然,孤灯挑尽未成眠。
迟迟钟鼓初长夜,耿耿星河欲曙天。
鸳鸯瓦冷霜华重,翡翠衾寒谁与共。
悠悠生死别经年,魂魄不曾来入梦。
临邛道士鸿都客,能以精诚致魂魄。
为感君王辗转思,遂教方士殷勤觅。
排空驭气奔如电,升天入地求之遍。
上穷碧落下黄泉,两处茫茫皆不见。
忽闻海上有仙山,山在虚无缥渺间。
楼阁玲珑五云起,其中绰约多仙子。
中有一人字太真,雪肤花貌参差是。
金阙西厢叩玉扃,转教小玉报双成。
闻道汉家天子使,九华帐里梦魂惊。
揽衣推枕起徘徊,珠箔银屏迤逦开。
云鬓半偏新睡觉,花冠不整下堂来。
风吹仙袂飘飖举,犹似霓裳羽衣舞。
玉容寂寞泪阑干,梨花一枝春带雨。
含情凝睇谢君王,一别音容两渺茫。
昭阳殿里恩爱绝,蓬莱宫中日月长。
回头下望人寰处,不见长安见尘雾。
惟将旧物表深情,钿合金钗寄将去。
钗留一股合一扇,钗擘黄金合分钿。
但教心似金钿坚,天上人间会相见。
临别殷勤重寄词,词中有誓两心知。
七月七日长生殿,夜半无人私语时。
在天愿作比翼鸟,在地愿为连理枝。
天长地久有时尽,此恨绵绵无绝期。
"""import numpy as np
from keras.utils import to_categorical, plot_model
from keras.models import Model
from keras.layers import Dense, Input, LSTM"""配置"""
units = 200  # LSTM神经元数量
len_input = 7  # 输入序列长度(输出序列长度=16-7=9)
epochs = 2000"""语料加载"""
seqs = __doc__.replace('《长恨歌》——白居易', '').strip().split('\n')"""构建序列和字库"""
seqs_input, seqs_output = [], []  # 输入、输出序列
chr_set_input, chr_set_output = set(), set()  # 字库
for seq in seqs:inputs, outputs = seq[:len_input], seq[len_input:]seqs_input.append(inputs)seqs_output.append(outputs)chr_set_input |= set(inputs)chr_set_output |= set(outputs)num_classes_input = len(chr_set_input)
num_classes_output = len(chr_set_output)
print('字库量(输入)', num_classes_input, '字库量(输出)', num_classes_output)"""构建字符和索引间的映射"""
chr2id_input = {c: i for i, c in enumerate(chr_set_input)}
chr2id_output = {c: i for i, c in enumerate(chr_set_output)}
id2chr_output = {i: c for c, i in chr2id_output.items()}
id_start = chr2id_output[',']  # 起点ID
id_end = chr2id_output['。']  # 终点ID"""构建输入层和输出层"""
x_encoder = [[chr2id_input[c] for c in chrs] for chrs in seqs_input]
x_decoder = [[chr2id_output[c] for c in chrs[:-1]] for chrs in seqs_output]  # 起点+序列
y = [[chr2id_output[c] for c in chrs[1:]] for chrs in seqs_output]  # 序列+终点x_encoder = to_categorical(x_encoder, num_classes=num_classes_input)
x_decoder = to_categorical(x_decoder, num_classes=num_classes_output)
y = to_categorical(y, num_classes=num_classes_output)
print('输入维度', x_encoder.shape, x_decoder.shape, '输出维度', y.shape)"""创建联合模型"""
encoder_input = Input(shape=(None, num_classes_input))  # 编码器输入层
encoder_lstm = LSTM(units, return_state=True)  # 编码器LSTM层
_, encoder_h, encoder_c = encoder_lstm(encoder_input)  # 编码器LSTM输出
model_encoder = Model(encoder_input, [encoder_h, encoder_c])  # 【编码器模型】decoder_input = Input(shape=(None, num_classes_output))  # 解码器输入层
decoder_lstm = LSTM(units, return_sequences=True, return_state=True)  # 解码器LSTM层
decoder_lstm_output, _, _ = decoder_lstm(decoder_input, initial_state=[encoder_h, encoder_c])  # 解码器LSTM输出
decoder_softmax = Dense(num_classes_output, activation='softmax')  # 解码器softmax层
decoder_output = decoder_softmax(decoder_lstm_output)  # 解码器输出model = Model([encoder_input, decoder_input], decoder_output)  # 【联合模型】
model.compile('adam', 'categorical_crossentropy')
model.fit([x_encoder, x_decoder], y, epochs=epochs, verbose=2)"""创建解码模型"""
decoder_h_input = Input(shape=(units,))  # 解码器状态输入层h
decoder_c_input = Input(shape=(units,))  # 解码器状态输入层c
decoder_lstm_output, decoder_h, decoder_c = decoder_lstm(decoder_input, initial_state=[decoder_h_input, decoder_c_input])  # 解码器LSTM输出
decoder_output = decoder_softmax(decoder_lstm_output)  # 解码器输出
model_decoder = Model([decoder_input, decoder_h_input, decoder_c_input],[decoder_output, decoder_h, decoder_c])  # 【解码器模型】"""模型可视化"""
plot_model(model, show_shapes=True, show_layer_names=False)
plot_model(model_encoder, 'encoder.png', show_shapes=True, show_layer_names=False)
plot_model(model_decoder, 'decoder.png', show_shapes=True, show_layer_names=False)"""序列生成序列"""
def seq2seq(x_encoder_pred):h, c = model_encoder.predict(x_encoder_pred)id_pred = id_startseq = ''while id_pred != id_end:y_pred = to_categorical([[[id_pred]]], num_classes_output)output, h, c = model_decoder.predict([y_pred, h, c])id_pred = np.argmax(output[0])seq += id2chr_output[id_pred]return seq[:-1]"""模型评估"""
for i in range(len(seqs)):seq = seq2seq(x_encoder[i: i + 1])print('原输入输出:%s' % seqs[i])print('模型输出:%s\n' % seq)while True:chrs = input('输入:').strip()  # 输入7个字x_encoder_pred = to_categorical([[chr2id_input[c] for c in chrs]], num_classes_input)seq = seq2seq(x_encoder_pred)print('输出:%s\n' % seq)

模型可视化

训练器

编码器(从属于训练器)

解码器

训练结果展示


附录

另附【字符级英译中seq2seq】代码和语料下载地址

import numpy as np, os
from collections import Counter
from keras.preprocessing.sequence import pad_sequences
from keras.utils import to_categorical, plot_model
from keras.models import Model, load_model
from keras.layers import Dense, Input, LSTM"""配置"""
corpus_path = 'en2cn.txt'
num_classes_input = 32 + 1  # 英文低频字过滤
num_classes_output = 1000 + 3  # 中文低频字过滤
chr_pad = ''  # 填充字符
chr_start = '['  # 起始字符
chr_end = ']'  # 结束字符
id_pad = 0  # 填充字ID
id_start = 1  # 起点ID
id_end = 2  # 终点IDunits = 400  # LSTM神经元数量
batchsize = 512
epochs = 1000prefix = 'model/'  # 保存模型的文件夹
path_hdf5 = prefix + 'model.hdf5'
path_hdf5_encoder = prefix + 'encoder.hdf5'
path_hdf5_decoder = prefix + 'decoder.hdf5'
path_png = prefix + 'model.png'
path_png_encoder = prefix + 'encoder.png'
path_png_decoder = prefix + 'decoder.png'def preprocess_data():"""语料加载"""with open(corpus_path, encoding='utf-8') as f:seqs = f.read().lower().split('\n')"""构建序列和字库"""seqs_input, seqs_output = [], []  # 输入、输出序列counter_input, counter_output = Counter(), Counter()  # 字库for seq in seqs:inputs, outputs = seq.split('\t')counter_input += Counter(list(inputs))counter_output += Counter(list(outputs))outputs = chr_start + outputs + chr_end  # 加入起终点seqs_input.append(inputs)seqs_output.append(outputs)# 过滤低频词counter_input = counter_input.most_common(num_classes_input - 1)counter_output = counter_output.most_common(num_classes_output - 3)# 加入字符(填充、起点、终点)到字库counter_input = [chr_pad] + [i[0] for i in counter_input]counter_output = [chr_pad, chr_start, chr_end] + [i[0] for i in counter_output]"""字符和索引间的映射"""chr2id_input = {c: i for i, c in enumerate(counter_input)}chr2id_output = {c: i for i, c in enumerate(counter_output)}c2i_input = lambda c: chr2id_input.get(c, 0)c2i_output = lambda c: chr2id_output.get(c, 0)id2chr_output = {i: c for c, i in chr2id_output.items()}yield c2i_input, c2i_output, id2chr_output"""输入层和输出层"""# 输入序列x_encoder = [[c2i_input(c) for c in chrs if c2i_input(c)] for chrs in seqs_input]# 起点 + 输出序列x_decoder = [[c2i_output(c) for c in chrs[:-1] if c2i_output(c)] for chrs in seqs_output]# 输出序列 + 终点y = [[c2i_output(c) for c in chrs[1:] if c2i_output(c)] for chrs in seqs_output]# 输入输出序列最大长度maxlen_input = max(len(i) for i in x_encoder)maxlen_output = max(len(i) for i in x_decoder)yield maxlen_input, maxlen_output# 序列截断或补齐为等长x_encoder = pad_sequences(x_encoder, maxlen_input, padding='post', truncating='post')x_decoder = pad_sequences(x_decoder, maxlen_output, padding='post', truncating='post')y = pad_sequences(y, maxlen_output, padding='post', truncating='post')# 独热码x_encoder = to_categorical(x_encoder, num_classes=num_classes_input)x_decoder = to_categorical(x_decoder, num_classes=num_classes_output)y = to_categorical(y, num_classes=num_classes_output)print('输入维度', x_encoder.shape, x_decoder.shape, '输出维度', y.shape)yield x_encoder, x_decoder, y[(c2i_input, c2i_output, id2chr_output),(maxlen_input, maxlen_output),(x_encoder, x_decoder, y)] = list(preprocess_data())if os.path.exists(prefix):"""加载已训练模型"""model = load_model(path_hdf5)model_encoder = load_model(path_hdf5_encoder)model_decoder = load_model(path_hdf5_decoder)
else:"""编码模型"""encoder_input = Input(shape=(None, num_classes_input))  # 编码器输入层encoder_lstm = LSTM(units, return_state=True)  # 编码器LSTM层_, encoder_h, encoder_c = encoder_lstm(encoder_input)  # 编码器LSTM输出model_encoder = Model(encoder_input, [encoder_h, encoder_c])  # 【编码模型】# 解码器decoder_input = Input(shape=(None, num_classes_output))  # 解码器输入层decoder_lstm = LSTM(units, return_sequences=True, return_state=True)  # 解码器LSTM层decoder_output, _, _ = decoder_lstm(decoder_input, initial_state=[encoder_h, encoder_c])  # 解码器LSTM输出decoder_dense = Dense(num_classes_output, activation='softmax')  # 解码器softmax层decoder_output = decoder_dense(decoder_output)  # 解码器输出"""训练模型"""model = Model([encoder_input, decoder_input], decoder_output)  # 【训练模型】model.compile('adam', 'categorical_crossentropy')model.fit([x_encoder, x_decoder], y, batchsize, epochs, verbose=2)"""解码模型"""decoder_h_input = Input(shape=(units,))  # 解码器状态输入层hdecoder_c_input = Input(shape=(units,))  # 解码器状态输入层cdecoder_output, decoder_h, decoder_c = decoder_lstm(decoder_input, initial_state=[decoder_h_input, decoder_c_input])  # 解码器LSTM输出decoder_output = decoder_dense(decoder_output)  # 解码器输出model_decoder = Model([decoder_input, decoder_h_input, decoder_c_input],[decoder_output, decoder_h, decoder_c])  # 【解码模型】# 模型保存os.mkdir(prefix)plot_model(model, path_png, show_shapes=True, show_layer_names=False)plot_model(model_encoder, path_png_encoder, show_shapes=True, show_layer_names=False)plot_model(model_decoder, path_png_decoder, show_shapes=True, show_layer_names=False)model.save(path_hdf5)model_encoder.save(path_hdf5_encoder)model_decoder.save(path_hdf5_decoder)"""序列生成序列"""
def seq2seq(x_encoder_pred):h, c = model_encoder.predict(x_encoder_pred)id_pred = id_startseq = ''for _ in range(maxlen_output):y_pred = to_categorical([[[id_pred]]], num_classes=num_classes_output)output, h, c = model_decoder.predict([y_pred, h, c])id_pred = np.argmax(output[0])seq += id2chr_output[id_pred]if id_pred == id_end:breakreturn seq[:-1]while True:chrs = input('输入:').strip().lower()x_encoder_pred = [[c2i_input(c) for c in chrs]]x_encoder_pred = pad_sequences(x_encoder_pred, maxlen_input, padding='post', truncating='post')x_encoder_pred = to_categorical(x_encoder_pred, num_classes_input)seq = seq2seq(x_encoder_pred)print('输出:%s\n' % seq)

Keras【极简】seq2seq相关推荐

  1. Keras【极简】ACGAN

    文章目录 1.序言 2.网络结构 2.1.生 成 器 2.2.审 判 者 2.3.欺 诈 者 3.代码(直接复制可用) 4.伪 造 图 像 展 示 1.序言 GAN升级版:辅助分类器对抗生成式网络 ( ...

  2. TF-Lite极简参考-模型转换

    TF-Lite极简参考-模型转换 <TF-Lite极简参考-模型转换>   TensorFlow Lite 可以很方便的把基于TensorFlow训练的模型进行转换,然后推理,在Tenso ...

  3. RepVGG:极简架构,SOTA性能,论文解读

    ** RepVGG:极简架构,SOTA性能,论文解读 ** 更新:RepVGG的更深版本达到了83.55%正确率!PyTorch代码和模型已经在GitHub上放出.DingXiaoH/RepVGG 2 ...

  4. 业务逻辑组件化android,AppJoint 极简 Android 组件化方案

    AppJoint 极简 Android 组件化方案.仅包含 3 个注解加 1 个 API,超低学习成本,支持渐进式组件化. 开始接入 在项目根目录的 build.gradle 文件中添加 AppJoi ...

  5. Spring Boot 极简集成 Shiro

    点击关注公众号,Java干货及时送达 1. 前言 Apache Shiro是一个功能强大且易于使用的Java安全框架,提供了认证,授权,加密,和会话管理. Shiro有三大核心组件: Subject: ...

  6. 7句话让Codex给我做了个小游戏,还是极简版塞尔达,一玩简直停不下来

    点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 梦晨 萧箫 发自 凹非寺 量子位 | 公众号 QbitAI 什么,7 ...

  7. 30个Python常用极简代码,拿走就用

    点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 作者丨Fatos Morina 来源丨Python 技术 编辑丨极市 ...

  8. 30 段极简 Python 代码:这些小技巧你都 Get 了么?

    选自 | towardsdatascienc 编译 | 机器之心 学 Python 怎样才最快,当然是实战各种小项目,只有自己去想与写,才记得住规则.本文是 30 个极简任务,初学者可以尝试着自己实现 ...

  9. 《Kotlin极简教程》第三章 Kotlin基本数据类型

    正式上架:<Kotlin极简教程>Official on shelves: Kotlin Programming minimalist tutorial 京东JD:https://item ...

  10. 10分钟手撸极简版ORM框架!

    最近很多小伙伴对ORM框架的实现很感兴趣,不少读者在冰河的微信上问:冰河,你知道ORM框架是如何实现的吗?比如像MyBatis和Hibernte这种ORM框架,它们是如何实现的呢? 为了能够让小伙伴们 ...

最新文章

  1. 如何通过五个简单步骤成为更好的Stack Overflow用户
  2. plsql查询数据中文乱码
  3. visual studio 安装Entity framework失败
  4. STL学习笔记(数值算法)
  5. 【分享】WeX5的正确打开方式(5)——绑定机制
  6. 六年打磨!阿里开源混沌工程工具 ChaosBlade
  7. 转为yaml python_python 如何使用HttpRunner做接口自动化测试
  8. Win7 单机Spark和PySpark安装
  9. 有啥不同?来看看Spring Boot 基于 JUnit 5 实现单元测试
  10. java面向对象编程基础实验报告_20155313 实验三《Java面向对象程序设计》实验报告...
  11. pandas重置索引的几种方法探究
  12. 开博啦——半路出家做运维以来的一些杂感
  13. 孙鑫VC学习笔记:第十六讲 (三) 用异步套接字编写聊天程序
  14. 在vs中进行qt桌面应用开发时,编译器堆溢出的编译错误(error C1060编译器堆内存不足)
  15. Qt + 运动控制 (固高运动控制卡)【1】环境准备,框架搭建
  16. 74LS85 比較器 【数字电路】
  17. 2003计算机应用基础题答案,计算机应用基础(Windows_XP+Office_2003)课后题答案
  18. C语言实现输出最长的名字
  19. SpringBoot+Vue项目实现身体健康诊疗系统
  20. electron入门——安装及创建项目

热门文章

  1. 基于MATLAB的500kV LCC-HVDC 输电仿真 两侧交流系统电压为345kV,交流侧分别设计了相应的滤波器
  2. 隐式超级构造函数Fu()未定义。
  3. (已更新)看图猜成语小程序源码+详细搭建教程
  4. 网狐荣耀之微星棋牌系列,NET_PW_AgentBalance存储过程源码
  5. 文字转换表格大法、批量去除选择题选项
  6. 4K分辨率搭配光学变焦功能,极米H6成旗舰家用投影首选
  7. php添加超链接到html,总结几种实现超链接html代码
  8. 渗透之——利用Metasploit找出SCADA服务器
  9. 画江湖盟主侠岚篇怎么在电脑上玩 画江湖盟主电脑版玩法教程
  10. Clion开发STM32之OTA升级模块(最新完整版)