代码:https://github.com/MONI-JUAN/Tensorflow_Study/15-17——RNN-LSTM-生成文本

先行知识点:

TensorFlow 15——ch12-RNN、LSTM基本结构

TensorFlow 16——ch12-RNN 和 LSTM 的实现方式

目录

  • 一、函数定义
    • 1.定义输入数据
    • 2.定有多层LSTM模型
    • 3.定义损失
  • 二、训练模型
    • 1.生成英文
    • 2.生成诗词
    • 3.生成C代码

一、函数定义

1.定义输入数据

model.py

def build_inputs(self):with tf.name_scope('inputs'):# inputs 的形状和 targets 相同,都为(num_seqs,num_steps)# num_seqs 为一个 batch 内的句子个数# num_steps 为每个句子的长度self.inputs = tf.placeholder(tf.int32, shape=(self.num_seqs, self.num_steps), name='inputs')self.targets = tf.placeholder(tf.int32, shape=(self.num_seqs, self.num_steps), name='targets')# keep_prob 控制了 Dropout 层所需要的概率(训练0.5,测试1.0)self.keep_prob = tf.placeholder(tf.float32, name='keep_prob')# 对于中文,需要使用embedding层,英文不用if self.use_embedding is False:self.lstm_inputs = tf.one_hot(self.inputs, self.num_classes)else:with tf.device("/cpu:0"):embedding = tf.get_variable('embedding', [self.num_classes, self.embedding_size])self.lstm_inputs = tf.nn.embedding_lookup(embedding, self.inputs)

2.定有多层LSTM模型

model.py

def build_lstm(self):# 创建单个cell并堆叠多层,每一层还加入了Dropout减少过拟合def get_a_cell(lstm_size, keep_prob):lstm = tf.nn.rnn_cell.BasicLSTMCell(lstm_size)drop = tf.nn.rnn_cell.DropoutWrapper(lstm, output_keep_prob=keep_prob)return dropwith tf.name_scope('lstm'):cell = tf.nn.rnn_cell.MultiRNNCell([get_a_cell(self.lstm_size, self.keep_prob) for _ in range(self.num_layers)])self.initial_state = cell.zero_state(self.num_seqs, tf.float32)# 通过dynamic_rnn对cell展开时间维度self.lstm_outputs, self.final_state = tf.nn.dynamic_rnn(cell, self.lstm_inputs, initial_state=self.initial_state)# 通过lstm_outputs得到概率seq_output = tf.concat(self.lstm_outputs, 1)x = tf.reshape(seq_output, [-1, self.lstm_size])with tf.variable_scope('softmax'):softmax_w = tf.Variable(tf.truncated_normal([self.lstm_size, self.num_classes], stddev=0.1))softmax_b = tf.Variable(tf.zeros(self.num_classes))# proba_prediction = Softmax(Wx+b)self.logits = tf.matmul(x, softmax_w) + softmax_bself.proba_prediction = tf.nn.softmax(self.logits, name='predictions')

3.定义损失

def build_loss(self):with tf.name_scope('loss'):y_one_hot = tf.one_hot(self.targets, self.num_classes)y_reshaped = tf.reshape(y_one_hot, self.logits.get_shape())loss = tf.nn.softmax_cross_entropy_with_logits(logits=self.logits, labels=y_reshaped)self.loss = tf.reduce_mean(loss)

二、训练模型

1.生成英文

训练生成英文的模型:

python train.py \--input_file data/shakespeare.txt \--name shakespeare \--num_steps 50 \--num_seqs 32 \--learning_rate 0.01 \--max_steps 20000
python train.py  --input_file data/shakespeare.txt --name shakespeare --num_steps 50 --num_seqs 32 --learning_rate 0.01 --max_steps 20000

测试模型:

python sample.py \--converter_path model/shakespeare/converter.pkl \--checkpoint_path model/shakespeare/ \--max_length 1000
python sample.py --converter_path model/shakespeare/converter.pkl --checkpoint_path model/shakespeare/ --max_length 1000

因为每次候选下一个字母都是top5做概率归一后挑出的,所以文本生成的结果都会不同。

top5的代码看这里:python概率选取ndarray的TOP-N

真的好神奇,很难想象才20000步的效果会那么好!

2.生成诗词

训练写诗模型:

python train.py \--use_embedding \--input_file data/poetry.txt \--name poetry \--learning_rate 0.005 \--num_steps 26 \--num_seqs 32 \--max_steps 10000
python train.py --use_embedding --input_file data/poetry.txt --name poetry --learning_rate 0.005 --num_steps 26 --num_seqs 32 --max_steps 10000


测试模型:

python sample.py \--use_embedding \--converter_path model/poetry/converter.pkl \--checkpoint_path model/poetry/ \--max_length 300
python sample.py --use_embedding --converter_path model/poetry/converter.pkl --checkpoint_path model/poetry/ --max_length 300

3.生成C代码

训练生成C代码的模型:

python train.py \--input_file data/linux.txt \--num_steps 100 \--name linux \--learning_rate 0.01 \--num_seqs 32 \--max_steps 20000
python train.py --input_file data/linux.txt --num_steps 100 --name linux --learning_rate 0.01 --num_seqs 32 --max_steps 20000

测试模型:

python sample.py \--converter_path model/linux/converter.pkl \--checkpoint_path model/linux \--max_length 1000
python sample.py --converter_path model/linux/converter.pkl --checkpoint_path model/linux --max_length 1000

TensorFlow 17——ch12-Char RNN 文本生成(莎士比亚/诗词)相关推荐

  1. tensorflow循环神经网络(RNN)文本生成莎士比亚剧集

    tensorflow循环神经网络(RNN)文本生成莎士比亚剧集 我们将使用 Andrej Karpathy 在<循环神经网络不合理的有效性>一文中提供的莎士比亚作品数据集.给定此数据中的一 ...

  2. Tensorflow2.0之文本生成莎士比亚作品

    文章目录 1.导入数据 2.创建模型 3.训练 3.1 编译模型 3.2 配置检查点 3.3 训练模型 4.预测 4.1 重建模型 4.2 生成文本 我们将使用 Andrej Karpathy 在&l ...

  3. 如何用RNN生成莎士比亚风格的句子?(文末赠书)

    作者 | 李理,环信人工智能研发中心vp,十多年自然语言处理和人工智能研发经验.主持研发过多款智能硬件的问答和对话系统,负责环信中文语义分析开放平台和环信智能机器人的设计与研发. 来源 | <深 ...

  4. 使用TensorFlow.js的AI聊天机器人六:生成莎士比亚独白

    目录 设置TensorFlow.js代码 小莎士比亚数据集 通用句子编码器 莎士比亚独白在行动 终点线 总结 下载项目代码-9.9 MB TensorFlow+JavaScript.现在,最流行.最先 ...

  5. 利用GPT2生成莎士比亚写作风格的文本(python实现)

    一:原理 在此仅仅是简单介绍,还需要读者对self-attention.Transformer.GPT有一定的知识储备. 原始的 transformer 论文引入了两种类型的 transformer ...

  6. Tensorflow快餐教程(12) - 用机器写莎士比亚的戏剧

    高层框架:TFLearn和Keras 上一节我们学习了Tensorflow的高层API封装,可以通过简单的几步就生成一个DNN分类器来解决MNIST手写识别问题. 尽管Tensorflow也在不断推进 ...

  7. GRU网络生成莎士比亚小说

    介绍 本文我们将使用GRU网络来学习莎士比亚小说,模型通过学习可以生成与小说风格相似的文本,如图所示: 虽然有些句子并没有实际的意思(目前我们的模型是基于概率,并不是理解语义),但是大多数单词都是有效 ...

  8. 【Github上有趣的项目】基于RNN文本生成器,自动生成莎士比亚的剧本或者shell代码(不是python的是lua的)

    文章目录 下了之后才发现不是python的尴尬得一匹,,ԾㅂԾ,, GitHub 上有哪些有趣的关于 NLP 或者 DL 的项目? - Xiaoran的回答 - 知乎 char-rnn 下了之后才发现 ...

  9. 使用LSTM进行莎士比亚风格诗句生成

    本文章跟本人前面两篇文章(文章1, 文章2)的思路大体相同,都是使用序列化的数据集来训练RNN神经网络模型,然后自动生成相关的序列化.这篇文章使用莎士比亚诗词作为训练集,使用keras和tensorf ...

最新文章

  1. PyTorch深度学习
  2. Dreamoon and Ranking Collection CodeForces - 1330A (贪心)
  3. 分区式存储管理c++_分区机要变形缝,纵横交接卫浴厨:防火阀参数的高效记忆口诀...
  4. 怎样让计算机恢复到桌面上,如何把电脑桌面恢复成原样.怎么办?
  5. 荷兰政府用大数据预测天气预防自然灾害,他们是怎么做的?
  6. mysql 删除 like_MySQL 定时删除数据
  7. 理想汽车市值逼近蔚来,王兴曾多次在饭否为其站台
  8. @media实现网页自适应中的几个关键分辨率
  9. 《深入理解Spark:核心思想与源码分析》——1.3节阅读环境准备
  10. 谨以此文献给才毕业2--5年的朋友(转)
  11. MRP游戏软件常见问题解答以及破解方法!(新手必看)
  12. PointCloudLibrary点云库介绍
  13. 路由器工作原理及其主要部件详解
  14. 李兴华内部JAVA培训视频 (难找啊)
  15. 微信服务号开发-获取用户位置信息
  16. 微信小程序仿淘票票之登录注册讲解
  17. php短信炸弹,php发送短信炸弹
  18. PC串口状态监视软件
  19. Linux系统使用LAMP架构部署Discuz论坛系统,简洁明了
  20. ios开发者账号添加受信任电话号码

热门文章

  1. HSSFWorkbook poi创建锁定的单元格
  2. 没有合同被私人老板拖欠工资要如何处理
  3. 跨境电商独立站的运营技巧都有哪些?
  4. 卓训教育:孩子不爱阅读怎么办,家长可以这样培养孩子的阅读习惯
  5. 【大数据处理技术】实验5
  6. 掌财社:掌握CCI指标捕捉爆发牛股
  7. 国家动物博物馆参观记
  8. 2021三季报业绩预增一览
  9. 你全会算我输,让人直呼卧槽的Python代码!
  10. 在img标签限制图片大小