代码连链接

这里稍微更改了下《深度学习框架PyTorch:入门与实践》里的demo,去掉稍微繁琐和多次训练的部分,只保留了比较核心的生成连接诗和藏头诗两部分(比较渣太复杂了看不懂)。

目标效果:

连接诗:
机器学习书,局上无酒浆。婆娑珍金盘,缕烂金葳浆。萱草发秋叶,旖旎镂金墙。拳芳既盈薄,禄位不敢匼。揆我不及饱,有时不相并。我为蘧生意,日出狱所宣。徇禄惩未卜,退食何由尝。濡毫若可濯,视事忘忧伤。茍餐偿朽援,亦有糟醨浆。忆来不得意,与我来此方。幽谷荫幽壤,清猨鸣秋霜。因思霖雨霁,始觉风物长。既与物外违,复知辱所并。吾欲顾己仁,岂必死者伤。茍哉吾安适,复道迈一觞。虽遇物外兴,永忻良所侵。藏头诗:
机心俨不见,金马不可量。器食非所古,此心何足伤。学书三十载,相劝一旬长。习隸不可乐,持斧亦葳裳。书来不可再,稅首皆自忙。

主要面main.py:

import torch as t
from data import get_data
from model import PoetryModel"""对象"""
class Config(object):data_path = 'data/'  # 诗歌的文本文件存放路径pickle_path = 'tang.npz'  # 预处理好的二进制文件,包含data,形状为(57580,125),共57580首诗歌,每首诗歌长度125,不够补空格,多余丢弃category = 'poet.tang'  # 类别,唐诗还是宋诗歌(poet.song)max_gen_len = 200  # 生成诗歌最长长度prefix_words = '细雨鱼儿出,微风燕子斜。'  # 不是诗歌的组成部分,用来控制生成诗歌的意境start_words = '闲云潭影日悠悠'  # 诗歌开始acrostic = True  # 是否是藏头诗model_path = None  # 预训练模型路径
opt = Config()"""给定一句诗,继续生成一首完整的诗"""
def generate(model, start_words, ix2word, word2ix, prefix_words=None):"""ix2word:每个序号对应的字word2ix:每个字对应的序号prefix_words:诗歌意境<EOP>:8290<START>:8591</s>:8292<,>:7066<。>:7435"""results = list(start_words)start_word_len = len(start_words)input = t.Tensor([word2ix['<START>']]).view(1, 1).long()    # 手动设置第一个词为<START>hidden = Noneif prefix_words:   #意境诗句存在,这里主要用于训练记忆hidden,生成的output无用for word in prefix_words:output, hidden = model(input, hidden)input = input.data.new([word2ix[word]]).view(1, 1)for i in range(opt.max_gen_len):   #最大的句子长度output, hidden = model(input, hidden)   #前边几个output无用,因为默认使用前缀诗句if i < start_word_len:   #如果小于前缀句子的长度w = results[i]   #取前缀对应位置的字input = input.data.new([word2ix[w]]).view(1, 1)   #取出前缀诗句对应的字的下标,形成1*1的矩阵else:   #已经输出完前缀了,该保存output中计算的字了top_index = output.data[0].topk(1)[1][0].item()   #保存output中概率最大的那个字的下标w = ix2word[top_index]   #取出这个下标对应的字results.append(w)   #生成的诗句序列加入刚生成的结果input = input.data.new([top_index]).view(1, 1)   #把当前字的下标扩充成1*1当作下一次的输入if w == '<EOP>':   #如果预测到了结束,删掉并退出del results[-1]breakreturn results"""给定一句诗,对应生成藏头诗"""
def gen_acrostic(model, start_words, ix2word, word2ix, prefix_words=None):"""ix2word:每个序号对应的字word2ix:每个字对应的序号prefix_words:诗歌意境<EOP>:8290<START>:8591</s>:8292<,>:7066<。>:7435"""results = []   #用于存储生成的诗歌start_word_len = len(start_words)   #用于生成藏头诗的诗句的长度(有几个字生成几句)input = (t.Tensor([word2ix['<START>']]).view(1, 1).long())   #word2ix['<START>']=8291,一开始input先赋值<START>的序号hidden = Noneindex = 0  # 用来指示已经生成了多少句藏头诗pre_word = '<START>'   #用来表示上一个词,第一个词设置为'<START>'if prefix_words:   #如果设置了意境诗句,其中逗号和句号也包括在内for word in prefix_words:   #遍历这句诗的每一个字output, hidden = model(input, hidden)   #input是1*1矩阵,只包含上一个字,用于输进去进行降维,和一个空的hidden记忆,输出关于当前字的记忆hidden和当前的输出output,这个output貌似没用input = (input.data.new([word2ix[word]])).view(1, 1)   #将当前遍历的字的编号取出形成一个1*1矩阵赋值给input,input.data.new([word2ix[word]])是将一个数字变成一维tensor,view(1,1)是变成2维tensorfor i in range(opt.max_gen_len):   #生成诗歌最长长度output, hidden = model(input, hidden)   #对之前意境的最后一个字进行处理,这里其实是句号,output代表对下一个字的预测概率top_index = output.data[0].topk(1)[1][0].item()   #topk(1)取出向量中一个最大值,返回概率和对应下标w = ix2word[top_index]   #获取几率最大的下标对应的字if (pre_word in {'。', '!', '<START>'}): # 如果遇到句号叹号或者开始符号,藏头的词送进去生成if index == start_word_len:   #index用来指示已经生成了多少句藏头诗,如果相等则代表已经生成完了,退出循环breakelse: # 把藏头的词作为输入送入模型w = start_words[index]   #取藏头诗中第index个下标对应的字(这里如果是每一句开头,会抛弃上一句最后取出来的最大概率的w,重新覆盖了)index += 1input = (input.data.new([word2ix[w]])).view(1, 1)   #取对应字的数字编号作为一个1*1矩阵,作为下一个词的输入else: # 否则的话,把上一次预测是词作为下一个词输入input = (input.data.new([word2ix[w]])).view(1, 1)results.append(w)   #结果诗句添加当前字pre_word = w   #将w赋值给变量作为前一个单词return results"""提供接口,处理命令选择生成诗句类型"""
def gen(kwargs):for k, v in kwargs.items():   #遍历出来的k是键,v是值setattr(opt, k, v)   #设置属性值,第一个参数是对象,第二个参数是属性,第三个参数是属性值data, word2ix, ix2word = get_data(opt)   #第一个参数每一行是一首诗对应的字的下标,第二个参数是每个字对应的序号,第三个参数是每个序号对应的字model = PoetryModel(len(word2ix), 128, 256)map_location = lambda s, l: s   #这三句貌似是将模型加载在GPU上state_dict = t.load(opt.model_path, map_location=map_location)model.load_state_dict(state_dict)gen_poetry = gen_acrostic if opt.acrostic else generate   #是不是藏头诗,gen_acrostic为将提示句拆成每句第一个字生成藏头诗,generate为将提示句作为第一句result = gen_poetry(model, opt.start_words, ix2word, word2ix, opt.prefix_words)   #ix2word为每个序号对应的字,word2ix为每个字对应的序号,prefix_words为师哥意境print(''.join(result))"""主方法"""
def test():gen({'model_path': 'checkpoints/tang_199.pth', 'pickle_path': 'tang.npz', 'start_words': '机器学习书','prefix_words': '床前明月光,疑是地上霜。', 'acrostic': True, 'nouse_gpu': False})test()

获取数据data.py:

# coding:utf-8
import os
import numpy as np"""从二进制文件中获取诗歌数据"""
def get_data(opt):   #读取二进制numpy文件"""opt 配置选项 Config对象dict,每个字对应的序号,形如u'月'->100ix2word: dict,每个序号对应的字,形如'100'->'月'data: numpy数组,每一行是一首诗对应的字的下标"""if os.path.exists(opt.pickle_path):   #判断路径是否存在,即tang.npz这个文件是否存在(npz为二进制文件)data = np.load(opt.pickle_path,allow_pickle=True)   #读取tang.npz中的数据存到data中data, word2ix, ix2word = data['data'], data['word2ix'].item(), data['ix2word'].item()   #分类获取数据return data, word2ix, ix2word

模型model.py:

# coding:utf-8
import torch
import torch.nn as nn"""诗歌模型"""
class PoetryModel(nn.Module):def __init__(self, vocab_size, embedding_dim, hidden_dim):"""vocab_size: 所有字的长度8293embedding_dim: 词嵌入的维度128hidden_dim: 隐藏层向量维度,即隐藏层节点个数256"""super(PoetryModel, self).__init__()self.hidden_dim = hidden_dimself.embeddings = nn.Embedding(vocab_size, embedding_dim)   #权重矩阵为vocab_size*embedding_dim,即8293*128self.lstm = nn.LSTM(embedding_dim, self.hidden_dim, num_layers=2)   #第三个参数为网络层数self.linear1 = nn.Linear(self.hidden_dim, vocab_size)   #全连接层,生成一个词汇表,词汇表的数值是概率,代表这句话下一个位置这个单词出现的概率def forward(self, input, hidden=None):seq_len, batch_size = input.size()   #input是一个一个的数字,seq_len表示每一句话多长(有多少单词),batcg_size表示一次处理几个句子if hidden is None:   #一开始设置h_0和c_0都为0h_0 = input.data.new(2, batch_size, self.hidden_dim).fill_(0).float()   #隐藏元,维度为(2*batch_size*hidden_dim),全部填充为0,这里的shape为(2,1,256)c_0 = input.data.new(2, batch_size, self.hidden_dim).fill_(0).float()else:h_0, c_0 = hiddenembeds = self.embeddings(input)   #(1,1,128)output, hidden = self.lstm(embeds, (h_0, c_0))   #output为(1,1,256),hidden中包含了同为(1,1,256)的h_0和c_0output = self.linear1(output.view(seq_len * batch_size, -1))   #全连接后变成(1,8293)return output, hidden

深度学习-使用RNN生成诗相关推荐

  1. 深度学习中的生成对抗网络GAN

    转载:一文看尽深度学习中的生成对抗网络 | CVHub带你看一看GANs架构发展的8年 (qq.com) 导读 生成对抗网络 (Generative Adversarial Networks, GAN ...

  2. TensorFlow2学习:RNN生成古诗词

    本文转自 AI科技大本营 TensorFlow2学习:RNN生成古诗词 文章不见了可以参考这位博主的文章 公众号的文章写得挺好的,这里简单介绍下思路及值得学习的地方 模型简介 模型不算多么复杂高大上, ...

  3. 《Deep Learning Techniques for Music Generation – A Survey》深度学习用于音乐生成——书籍阅读笔记(一)Chapter 1

    <Deep Learning Techniques for Music Generation – A Survey>深度学习用于音乐生成--书籍阅读笔记(一)Chapter 1 关于这本书 ...

  4. 基于深度学习的宋词生成

    <自然语言处理>课程报告 摘 要 宋词是一种相对于古体诗的新体诗歌之一,为宋代儒客文人智慧精华,标志宋代文学的最高成就.宋词生成属于自然语言处理领域的文本生成模块,当前文本生成领域主要包括 ...

  5. 推荐基于深度学习实时同步生成2D动画口型算法

    概述 实时二维动画是一种相当新颖而强大的交流形式,它使表演者可以实时控制卡通人物,同时与其他演员或观众互动和即兴表演. 最近的例子包括史蒂芬·科尔伯特(Stephen Colbert)在<后期秀 ...

  6. 一言不合就想斗图?快用深度学习帮你生成表情包

    源 | AI研习社 AI研习社:斯坦福大学的两个学生 Abel L Peirson V 和 Meltem Tolunay 发表了自己的 CS224n 结业论文-- 用深度神经网络生成表情包(你没有看错 ...

  7. 极限元语音算法专家刘斌:基于深度学习的语音生成问题

    一.深度学习在语音合成中的应用 语音合成主要采用波形拼接合成和统计参数合成两种方式.波形拼接语音合成需要有足够的高质量发音人录音才能够合成高质量的语音,它在工业界中得到了广泛使用.统计参数语音合成虽然 ...

  8. [人工智能-深度学习-59]:生成对抗网络GAN - 基本原理(图解、详解、通俗易懂)

    作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客 本文网址:https://blog.csdn.net/HiWangWenBing/article/detai ...

  9. [人工智能-深度学习-63]:生成对抗网络GAN - 图片创作:普通GAN, pix2pix, CycleGAN和pix2pixHD的演变过程

    作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客 本文网址:https://blog.csdn.net/HiWangWenBing/article/detai ...

最新文章

  1. 2位华人获得加州理工学院计算机、数学博士奖学金,3年近一半由华人获得
  2. 封装构造函数,用canvas写饼状图和柱状图
  3. XGBoost:Python下 安装
  4. 网络营销常用工具与资源
  5. .NET中使用Redis总结 —— 1.Redis搭建
  6. 把你的 VS Code 打造成 C++ 开发利器
  7. C++设计模式-备忘录模式
  8. bootstrap table 主子表 局部数据刷新(刷新子表)
  9. python简单代码-用Python代码实现5种最好的、简单的数据可视化!
  10. 用激光把谷歌的标志投射到月球是否可行?
  11. Gsonformat插件安装与使用
  12. Apache HttpClient4使用教程
  13. matlab选址问题——分级选址定容
  14. Android Car - 开机画面
  15. MFC中利用ListControl制作空表格,由键盘输入数据并保存在数组中
  16. 教你前端如何用js写一个跑酷小游戏
  17. 人工智能的发展历程和未来发展趋势
  18. Python map
  19. python mitmproxy +雷电模拟器 安装
  20. OPENFILER构建软iSCSI multipath实现多路径聚合(一)

热门文章

  1. python sobel滤波_sobel滤波器在imag中的应用
  2. 分享 :统计学概论和医疗临床大数据分析(附PPT下载)
  3. 长时间穿高跟鞋存在健康隐患
  4. 关于银行系统专项测试,你了解多少?学习一下
  5. 免费的系统安全测试软件PC Security Test
  6. 轻而易举的攻过tomcat
  7. springboot集成webjars
  8. 计算机图形学直线段的生成算法
  9. 编译Android源码卡死,编译Android源码过程中出现的错误
  10. 中国达沃斯域名,被大连小伙成功注册www.chinadavos.com(转)