利用莎士比亚数据集进行RNN文本生成的训练

import tensorflow as tf
import numpy as np
from tensorflow import keras
import pandas as pd
import sklearn
import sys
import os
import matplotlib.pyplot as plt
import matplotlib as mpl
from sklearn.preprocessing import StandardScalerprint(tf.__version__)
print(sys.version_info)
for module in mpl,np,pd,sklearn,tf,keras:print(module.__name__,module.__version__)#莎士比亚数据集:https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt
input_filepath = "./shakespeare.txt"
text = open(input_filepath,'r').read()
print(len(text))
print(text[0:100])#1.生成词表
#2.映射 char -->id
#3.data -->id_data
#4.abcd -->bcd<eos>:预测下一个字符
vocab = sorted(set(text))
print(len(vocab))
print(vocab)char2idx = {char:idx for idx, char in enumerate(vocab)}
print(char2idx)
idx2char = np.array(vocab)
print(idx2char)#对text中每个字符都做一个映射
text_as_int = np.array([char2idx[c] for c in text])
print(text_as_int[0:10])
print(text[0:10])# 定义输入输出函数
def split_input_target(id_text):"""abcde -->输入abcd,输出bcde"""return id_text[0:-1], id_text[1:]# 将it_text转为dataset
char_dataset = tf.data.Dataset.from_tensor_slices(text_as_int)
seq_length = 100
seq_dataset = char_dataset.batch(seq_length + 1, drop_remainder=True)  # 当做batch操作时,如果最后一个长度不都,就丢掉
# 取出ch_id对应的字符
for ch_id in char_dataset.take(2):print(ch_id, idx2char[ch_id.numpy()])
# 取出seq_id对应的字符
for seq_id in seq_dataset.take(2):print(seq_id)print(repr(' '.join(idx2char[seq_id.numpy()])))seq_dataset = seq_dataset.map(split_input_target)
for item_input,item_output in seq_dataset:print(item_input.numpy())print(item_output.numpy())batch_size = 64
buffer_size = 10000
seq_dataset = seq_dataset.shuffle(buffer_size).batch(batch_size,drop_remainder=True)#定义模型
vocab_size = len(vocab)
embedding_dim = 256
rnn_units = 1024
#模型函数
def build_model(vocab_size,embedding_dim,rnn_units,batch_size):model = keras.models.Sequential([keras.layers.Embedding(vocab_size,embedding_dim,batch_input_shape = [batch_size,None]),keras.layers.SimpleRNN(units = rnn_units,return_sequences=True),keras.layers.Dense(vocab_size),])return modelmodel = build_model(vocab_size=vocab_size,embedding_dim=embedding_dim,rnn_units=rnn_units,batch_size=batch_size)model.summary()for input_example_batch,target_example_batch in seq_dataset.take(1):example_batch_predictions = model(input_example_batch)print(example_batch_predictions.shape)#随机采样
#在计算分类任务softmax之前的那个值就是logits
sample_indices = tf.random.categorical(logits=example_batch_predictions[0],num_samples = 1)
print(sample_indices)
#将(100,1)转换为(100,)形式
sample_indices = tf.squeeze(sample_indices,axis=-1)
print(sample_indices)#定义模型的损失函数
def loss(labels,logits):return keras.losses.sparse_categorical_crossentropy(labels,logits,from_logits=True)model.compile(optimizer = 'adam', loss = loss)
example_loss = loss(input_example_batch,example_batch_predictions)
print(example_loss.shape)
print(example_loss.numpy().mean())# 保存模型
output_dir = "./text_generation_checkpoints"
if not os.path.exists(output_dir):os.mkdir(output_dir)checkpoint_prefix = os.path.join(output_dir, 'ckpt_{epochs}')
checkpoint_callback = keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix,save_weights_only=True, )epochs = 100
history = model.fit(seq_dataset, epochs=epochs,callbacks=[checkpoint_callback])#导入模型
model2 = build_model(vocab_size,embedding_dim,rnn_units,batch_size=1)
model2.load_weights(tf.train.latest_checkpoint(output_dir))
# 1:指一个样本
model2.build(tf.TensorShape([1,None]))# 文本生成的流程
# start ch sequence A,
# A -->model -->b
# A.append(b) -->B -->model -->c -->B.appden(c) -->C(abc).....
def generate_text(model, start_string, num_generate=1000):input_eval = [char2idx[ch] for ch in start_string]# 维度扩展,因为模型的输入时一个[1,None]的矩阵,而此时是一维的input_eval = tf.expand_dims(input_eval, 0)text_generated = []model.reset_states()for _ in range(num_generate):# 1.model inference --> prediction# 2.sample --> ch --> text_generated# 3.update input_eval# predictions : [batch_size,input_eval_len,vocab_size]predictions = model(input_eval)# 去掉第一维: [input_eval_len,vocab_size]predictions = tf.squeeze(predictions, 0)# predictions : [input_eval_len,1]predicted_id = tf.random.categorical(predictions, num_samples=1)[-1, 0].numpy()text_generated.append(idx2char[predicted_id])input_eval = tf.expand_dims([predicted_id], 0)return start_string + ' '.join(text_generated)new_text = generate_text(models, "All: ")
print(new_text)

利用莎士比亚数据集进行RNN文本生成的训练相关推荐

  1. 莎士比亚数据集_如何使用深度学习写莎士比亚

    莎士比亚数据集 "开玩笑地说了许多真实的话." ―威廉·莎士比亚, 李尔王 "噢,主啊,当心嫉妒: 是嘲笑的绿眼怪物 它以肉为食." ― 奥赛罗的威廉·莎士比亚 ...

  2. 分布式环境下的莎士比亚数据集处理

    项目要求 对莎士比亚语料库处理,输出统计数据: 语料库中唯一(或不同)术语的数量 语料库中以字母T / t开头的单词数 出现少于5次的术语数量 整体读取的文件数 最常出现的5个术语及其词频 实现思路 ...

  3. 利用GPT2生成莎士比亚写作风格的文本(python实现)

    一:原理 在此仅仅是简单介绍,还需要读者对self-attention.Transformer.GPT有一定的知识储备. 原始的 transformer 论文引入了两种类型的 transformer ...

  4. tensorflow循环神经网络(RNN)文本生成莎士比亚剧集

    tensorflow循环神经网络(RNN)文本生成莎士比亚剧集 我们将使用 Andrej Karpathy 在<循环神经网络不合理的有效性>一文中提供的莎士比亚作品数据集.给定此数据中的一 ...

  5. 如何使用深度学习写莎士比亚

    "开玩笑地说了许多真实的话." ―威廉·莎士比亚, 李尔王 "噢,主啊,当心嫉妒: 是嘲笑的绿眼怪物 它以肉为食." ― 奥赛罗的威廉·莎士比亚 "有 ...

  6. 使用TensorFlow.js的AI聊天机器人六:生成莎士比亚独白

    目录 设置TensorFlow.js代码 小莎士比亚数据集 通用句子编码器 莎士比亚独白在行动 终点线 总结 下载项目代码-9.9 MB TensorFlow+JavaScript.现在,最流行.最先 ...

  7. 如何用RNN生成莎士比亚风格的句子?(文末赠书)

    作者 | 李理,环信人工智能研发中心vp,十多年自然语言处理和人工智能研发经验.主持研发过多款智能硬件的问答和对话系统,负责环信中文语义分析开放平台和环信智能机器人的设计与研发. 来源 | <深 ...

  8. 基于tflearn的RNN模仿莎士比亚写作

    生成类似莎士比亚写作的文章 1.安装准备: 安装tflearn,是一个封装高的TensorFlow高层框架 pip install -I tflearn 2.实现过程 第一步:下载莎士比亚写作文本 i ...

  9. 深度学习,使用RNN的NLP,您可以成为下一个莎士比亚吗?

    是否想过智能键盘上的预测键盘之类的工具如何工作?在本文中,探讨了使用先验信息生成文本的想法.具体来说,将使用Google Colab上的递归神经网络(RNN)和自然语言处理(NLP),从16世纪文献中 ...

最新文章

  1. 物联网电子标签助力无人便利店
  2. MATLAB 的条件分支语句
  3. 软件生成目录没有图框_图纸目录和编号
  4. 阿里开源富容器引擎 PouchContainer 的 network 连接机制
  5. mysql去除重复数据 重建表_删除掉mysql 的.ibd,.frm,ibdata1,ib_logfile0和ib_logfile1文件后再drop表。然后重建此表,有问题吗...
  6. Python 实现单例模式
  7. collections模块的Counter类
  8. 第二百六十四节,Tornado框架-基于正则的动态路由映射分页数据获取计算
  9. python基础刷题_数据结构与算法LeetCode刷题(Python)
  10. Pycharm快捷键设置(鼠标滚动控制字体大小)
  11. IDEA 不识别的MAVEN 项目应如何处理
  12. java右移位_Java移位运算符详解实例
  13. Android Studio个人使用记录
  14. select函数介绍
  15. 必读 | 一文看尽2019-2020各大顶会GNN论文(附链接)
  16. 10000小时=1万小时
  17. matlab求球的体积,【matlab计算不规则物体体积资讯】matlab计算不规则物体体积足球知识与常识 - 足球百科 - 599比分...
  18. 数据类型---C语言变量的定义与初始化
  19. 【数值模型系列】link_grib.csh脚本解读
  20. windows installer 窗口一直”正在取消“,无法关闭

热门文章

  1. 基本算法练习-约德尔测试
  2. 开板季滑雪热度暴涨,小红书“滑雪”搜索量涨150%
  3. 软件设计师之像素点计算
  4. VisualFreeBasic链接mysql数据库用法
  5. 外键约束(foreign key) [MySQL][数据库]
  6. ElementUI:表格table列宽度压缩出现空白
  7. python里else中文意思_Python中被忽略的else
  8. aspose插入word
  9. SELinux is preventing /usr/sbin/httpd from name_bind access on the tcp_socket port
  10. 英雄联盟对战,为求公平需要选取两组分值相差最低的队伍