文章目录

  • 前言
  • 1. 模型介绍
    • 1.1 Encoder-Decoder框架
    • 1.2 Attention机制
    • 1.3 代码实现
  • 2. 安装依赖库
  • 3. 模型部署
  • 4. 测试

前言

  哈哈,重头戏终于来了,经过两天的服务器配置、模型训练,今天终于在微信公众号上部署了自己使用TensorFlow训练的聊天机器人。
  本篇博客主要介绍一下Seq2Seq模型,以及模型训练后的部署,使用的深度学习框架为TensorFlow2.1,GPU为Tesla P100(白嫖Kaggle的),由于网站有时间限制,只训练了两个epoch就先部署了哈,所以机器人目前还很沙雕。

  有关腾讯云服务器配置流程和Django对接微信公众号以实现消息自动回复可以参考这两篇博客。

1. 模型介绍

  Seq2SeqSeq2SeqSeq2Seq的全称是SequenceSequenceSequence tototo SequenceSequenceSequence,也就是我们常说的序列到序列模型,它是基于Encoder−DecoderEncoder-DecoderEncoderDecoder框架的RNN(RecurrentRNN(RecurrentRNN(Recurrent NeuralNeuralNeural Network,循环神经网络)Network,循环神经网络)Network,)变种。Seq2SeqSeq2SeqSeq2Seq引入Encoder−DecoderEncoder-DecoderEncoderDecoder框架,提高了神经网络对长文本信息的提取能力,取得了比单纯使用LSTM(LongLSTM(LongLSTM(Long Short−TermShort-TermShortTerm Memory,长短期记忆神经网络)Memory,长短期记忆神经网络)Memory,)更好的效果。Seq2SeqSeq2SeqSeq2Seq中有两个很重要的概念,一个就是上面提到的Encoder−DecoderEncoder-DecoderEncoderDecoder框架,另一个就是AttentionAttentionAttention机制。这里简单介绍一下这两个概念。

1.1 Encoder-Decoder框架

  Encoder−DecoderEncoder-DecoderEncoderDecoder又称为编码器-解码器模型,顾名思义,它有两部分组成,即编码器和解码器。它是一种处理输入、输出长短不一的多对多文本预测问题的框架,其提供了有效的文本特征提取、输出预测的机制。
  编码器的作用是对输入的文本信息进行有效的编码后,将其作为解码器的输入数据,其目的是对输入的文本信息进行特征提取,尽量准确高效地表征该文本的特征信息。
  解码器的作用是从上下文的文本信息中获取尽可能多的特征,然后输出预测文本。根据对文本信息的获取方式不同,解码器一般分为4种结构,分别是直译式解码、循环式解码、增强式解码和注意力机制解码。

  • 直译式解码:按照编码器的费那事进行逆操作得到的预测文本
  • 循环式解码:将编码器输出的编码向量作为第一时刻的输入,然后将得到的输出作为下一个时刻的输入,依次进行循环解码
  • 增强循环式解码:在循环式解码的基础上,每一时刻增加一个编码器输出的编码向量作为输入
  • 注意力机制解码:在增强式循环解码的基础上增加注意力机制,这样可以有效地训练解码器在繁多的输入中重点关注某些有效特征信息,以增加解码器的特征获取能力,进而得到更好的解码效果。

1.2 Attention机制

  虽然Encoder−DecoderEncoder-DecoderEncoderDecoder结构的模型在机器翻译、语音识别以及文本生成等诸多领域均取得了非常不错的效果,但同时也存在着不足之处。编码器将输入的序列编码成一个固定长度的向量,再由解码器将其解码,得到输出序列。但个固定长度的向量所具有的表征能力是有限的,解码器又受限于这个固定长度的向量,当输入的文本序列较长时,编码器很难将所有的重要信息都编码到这个定长的向量中,从而使得模型的输出结果大大折扣。
  AttentionAttentionAttention机制有效解决了输入长序列信息时真实含义难以获取的问题。在进行长文本序列处理的任务中,影响当前时刻状态的信息可能隐藏在前面的时刻里,根据马尔可夫假设,这些信息有可能就会被忽略掉。比如,在“我快饿死了,今天搬了一天的砖,我要大吃一顿”这句话中,我们知道“我要大吃一顿”是因为“我快饿死了”,但是基于马尔可夫假设,“今天搬了一天的砖”“我要大吃一顿”在时序上离得更近,相比于“我快饿死了”“今天搬了一天的砖”“我要大吃一顿”的影响力更强,但是在真实的NLP(NaturalNLP(NaturalNLP(Natural LanguageLanguageLanguage Processing,自然语言处理)Processing,自然语言处理)Processing,)中不是这样的。从这个例子中可以看出,神经网络模型没有办法很好地准确获取倒装时序的语言信息,要解决这个问题就需要经过训练自动建立起“我要大吃一顿”“我快饿死了”的关联关系,这就是AttentionAttentionAttention机制,即注意力机制。

1.3 代码实现

   class Encoder(tf.keras.Model):"""编码器"""def __init__(self, vocab_size, embedding_dim, enc_units, batch_size):super(Encoder, self).__init__()self.batch_size = batch_sizeself.enc_units = enc_unitsself.embedding = tf.keras.layers.Embedding(input_dim=vocab_size, output_dim=embedding_dim)self.gru = tf.keras.layers.GRU(units=self.enc_units, recurrent_initializer='glorot_uniform',return_sequences=True, return_state=True)def call(self, x, hidden):# 此处添加模型调用的代码(处理输入并返回输出)x = self.embedding(x)output, state = self.gru(inputs=x, initial_state=hidden)return output, statedef initialize_hidden_state(self):return tf.zeros(shape=(self.batch_size, self.enc_units))class BahdanauAttention(tf.keras.Model):"""Bahdanau Attention"""def __init__(self, units):super(BahdanauAttention, self).__init__()self.W1 = tf.keras.layers.Dense(units=units)self.W2 = tf.keras.layers.Dense(units=units)self.V = tf.keras.layers.Dense(units=1)def call(self, query, values):# query为Encoder最后一个时间步的隐状态(hidden), shape为(batch_size, hidden_size)# values为Encoder部分的输出,即每个时间步的隐状态,shape为(batch_size, max_length, hidden_size)# 为方便后续计算,需将query的shape转为(batch_size, 1, hidden_size)# 给query增加一个维度query = tf.expand_dims(input=query, axis=1)# 计算score(相似度), 使用MLP网络,即再引入一个神经网络来专门计算score# score的shape为(batch_size, max_length, 1)score = self.V(inputs=tf.nn.tanh(self.W1(inputs=query) + self.W2(inputs=values)))# 计算attention_weights# 计算attention_weights的shape为(batch_size, max_length, 1)attention_weights = tf.nn.softmax(logits=score, axis=1)# 计算context vector# context vector的shape为(batch_size, max_length, hidden_size)context_vector = attention_weights * values# 加权求和# 求和之后的shape为(batch_size, hidden_size)context_vector = tf.reduce_sum(input_tensor=context_vector, axis=1)return context_vector, attention_weightsclass Decoder(tf.keras.Model):"""解码器"""def __init__(self, vocab_size, embedding_dim, dec_units, batch_size):super(Decoder, self).__init__()self.batch_size = batch_sizeself.dec_units = dec_unitsself.embedding = tf.keras.layers.Embedding(input_dim=vocab_size, output_dim=embedding_dim)self.gru = tf.keras.layers.GRU(units=self.dec_units, recurrent_initializer='glorot_uniform',return_sequences=True, return_state=True)self.fc = tf.keras.layers.Dense(units=vocab_size)self.attention = BahdanauAttention(units=self.dec_units)def call(self, x, hidden, enc_output):# 获取context vector和attention weightscontext_vector, attention_weights = self.attention(hidden, enc_output)# 编码之后x的shape为(batch_size, 1, embedding_dim)x = self.embedding(inputs=x)# 将context_vector与输入x进行拼接# 拼接后的shape为(batch_size, 1, embedding_dim + hidden_size)# 这里的hidden_size即context_vector向量的长度x = tf.concat(values=[tf.expand_dims(input=context_vector, axis=1), x], axis=-1)# 拼接后输入GRU网络output, state = self.gru(inputs=x)# print("Decoder output shape: {}".format(output.shape))# print("Decoder state shape: {}".format(state.shape))# (batch_size, 1, hidden_size) ==> (batch_size, hidden_size)output = tf.reshape(tensor=output, shape=(-1, output.shape[2]))# x的shape为(batch_size, vocab_size)x = self.fc(inputs=output)return x, state, attention_weights

  我也是这学期才开始入手TensorFlow2,以前用的都是TensorFlow 1.13.1,代码不明白的地方可以查看《简单粗暴 TensorFlow 2》文档。

2. 安装依赖库

  • 安装TensorFlow 2.1
 pip3 install tensorflow==2.1.0
  • 安装jieba
 pip3 install jieba


3. 模型部署

  腾讯云服务器用的是学生版的1核2G,感觉不一定能够支撑模型运行,先尝试一下吧。在此之前还是在本地通过Postman进行一下测试:


  还是OK的,就是模型加载的较慢,下面把模型文件以及相关代码上传到服务器的项目目录,目录内容更新为如下:


  上传到服务器之后,大致等到模型差不多加载好就可以准备测试了,测试结果如下:


  查看一下日志文件,发现了一些端倪:

  进程被杀死了,查了一下相关文件,说是超时了,enmmmmm,貌似有些道理【虽然不是很确定,但是模型确实是被重新加载了,更改了相关uwsgi的参数之后依旧是这个结果】,于是我直接上传了一个更改后的测试模型文件CR.py,直接在环境中运行,果不其然:



  这应该是内存不够吧~OK,暂时到此结束。



  昨天出了一点意外,1核2G的腾讯云服务器运行不了这个模型,所以今天换成了2核4G的阿里云服务器【有一说一,阿里云的这个学生套餐还是挺实惠的,又成功白嫖】,阿里云的配置过程同腾讯云的一样,可参考我的这篇博客。
  服务器配置完成之后,把项目文件上传到阿里云服务器的wwwroot文件夹下,然后进入pyweb虚拟环境,再次运行一下CR.py文件,看看模型能不能运行起来。结果如下:


  还是很nice的,模型能够运行,OK,接入到微信公众号上,配置代码很简单,只需要把微信公众号发送过来的消息送入到模型即可,代码如下:

 # views.py# 导入模型的接口from tencent.chatRobot import predictinput_info = recMsg.Content.decode('utf-8')try:content = predict(sentence=input_info)except Exception as err:content = '小悠没理解主银的意思~'replyMsg = TextMsg(toUser, fromUser, content)

  当时,还考虑了很久,模型如何先被加载,因为模型加载的时间稍长,不能等到微信公众号消息来了再加载模型,那肯定会超时的,而且每次都加载,肯定还很麻烦。当时还考虑到用线程等方法来加载,enmmmmm,后来嘛,就突然想到,为何不用全局变量的形式来加载,就是Python执行的时候是顺序执行嘛,像函数、类之类的这种对象,虽然定义了,但只要不被调用,这些代码就不会被运行,而函数、类之外的代码会正常按顺序执行,相当于就是全局变量了嘛。

 # chatRobot.py# -*- coding: utf-8 -*-# @Time    : 2021/1/4 22:47# @Author  : XiaYouRan# @Email   : youran.xia@foxmail.com# @File    : chatRobot.py# @Software: PyCharmimport tensorflow as tfimport jiebaimport osdef preprocess_sentence(sentence):"""给句子添加开始和结束标记:param sentence::return:"""sentence = '<start> ' + sentence + ' <end>'return sentencedef max_length(tensor):"""计算数据集中问句和答句中最长的句子长度:param tensor::return:"""return max([len(t) for t in tensor])def tokenize(sentences):"""分词器函数:param sentence::return:"""# 初始化分词器,并生成词典sentence_tokenizer = tf.keras.preprocessing.text.Tokenizer(filters='')sentence_tokenizer.fit_on_texts(sentences)# 利用字典将文本数据转为id# 也是二维的tensor = sentence_tokenizer.texts_to_sequences(texts=sentences)# 将数据填充成统一长度# 默认统一为最长句子长度# 将长为nb_samples的序列(标量序列)转化为形如(nb_samples,nb_timesteps) 2D numpy arraytensor = tf.keras.preprocessing.sequence.pad_sequences(tensor, maxlen=30, padding='post')return tensor, sentence_tokenizerdef load_dataset(file_path):with open(file_path, 'r', encoding='utf-8') as f:lines = f.readlines()q = ''a = ''qa_pairs = []# len(lines) 总行数for i in range(len(lines)):if i % 3 == 0:q = ' '.join(jieba.cut(lines[i].strip()))elif i % 3 == 1:a = ' '.join(jieba.cut(lines[i].strip()))else:# 问句与答句进行组合pair = [preprocess_sentence(q), preprocess_sentence(a)]qa_pairs.append(pair)# zip 拆解q_sentences, a_sentences = zip(*qa_pairs)# question数据集(id)及其分类器词汇表q_tensor, q_tokenizer = tokenize(q_sentences)# answer数据集(id)及其分类器词汇表a_tensor, a_tokenizer = tokenize(a_sentences)return q_tensor, a_tensor, q_tokenizer, a_tokenizerclass Encoder(tf.keras.Model):"""编码器"""class BahdanauAttention(tf.keras.Model):"""Bahdanau Attention"""class Decoder(tf.keras.Model):"""解码器"""# 使用Adam优化器optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)def predict(sentence):"""模型测试"""# 加载模型checkpoint = tf.train.Checkpoint(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),encoder=encoder,decoder=decoder)checkpoint.restore(save_path=tf.train.latest_checkpoint(checkpoint_dir=checkpoint_dir))sentence = ' '.join(jieba.cut(sentence.strip()))sentence = preprocess_sentence(sentence=sentence)inputs = [q_tokenizer.word_index[i] for i in sentence.split(' ')]inputs = tf.keras.preprocessing.sequence.pad_sequences(sequences=[inputs], maxlen=30, padding='post')inputs = tf.convert_to_tensor(value=inputs)result = ''hidden = [tf.zeros(shape=(1, units))]enc_out, enc_hidden = encoder(inputs, hidden)dec_hidden = enc_hiddendec_input = tf.expand_dims(input=[a_tokenizer.word_index['<start>']], axis=0)for t in range(q_tesor_length):predictions, dec_hidden, attention_weights = decoder(dec_input, dec_hidden, enc_out)predicted_id = tf.argmax(predictions[0]).numpy()result += a_tokenizer.index_word[predicted_id] + ' 'if a_tokenizer.index_word[predicted_id] == '<end>':breakdec_input = tf.expand_dims(input=[predicted_id], axis=0)# print("Q: %s" % sentence[8:-6].replace(' ', ''))# print("A: {}".format(result[:-6].replace(' ', '')))# print("A: {}".format(result.replace(' ', '')))return result[:-6].replace(' ', '')file_path = os.path.dirname(__file__)corpus_path = os.path.join(file_path, 'dataset/corpus.txt')checkpoint_dir = os.path.join(file_path, 'model/train_checkpoints')q_tensor, a_tensor, q_tokenizer, a_tokenizer = load_dataset(file_path=corpus_path)q_tesor_length = max_length(q_tensor)a_tesor_length = max_length(a_tensor)buffer_size = len(q_tensor)batch_size = 32steps_per_epoch = len(q_tensor) // batch_sizeembedding_dim = 128units = 256# q_tokenizer.word_index 字典类型(word, id)vocab_q_size = len(q_tokenizer.word_index) + 1vocab_a_size = len(a_tokenizer.word_index) + 1# 模型初始化encoder = Encoder(vocab_size=vocab_q_size, embedding_dim=embedding_dim, enc_units=units, batch_size=batch_size)attention_layer = BahdanauAttention(units=10)decoder = Decoder(vocab_size=vocab_a_size, embedding_dim=embedding_dim, dec_units=units, batch_size=batch_size)if __name__ == '__main__':input_sentence = "Start chatting..."while input_sentence != "stop":print("请输入:")input_sentence = input()try:predict(input_sentence)print("----------------------")except Exception as err:print('Test model error info: ', err)

4. 测试

  首先要把微信公众号的基本配置改一下,把那个服务器地址更改成阿里云的公网IP,然后启动服务器就可以了(大致需要五六分钟)。
  测试的结果如下:


  目前来看,机器人还很沙雕,毕竟只训练了两个epoch,准备再多训练几次,不过整体来看还蛮好的,部署的流程成功的走了一下,接下来就开始继续训练模型了。
  在阿里云后台看了一下服务器,模型确实比较吃内存,4G内存占用了近80%,怪不得2G内存不够用!


  总的来说,很OK,很nice!!!!想体验的小伙伴们,欢迎来玩哦,关注微信公众号夏悠然

微信公众号上部署自己训练的聊天机器人(腾讯云服务器+TensorFlow2.1+Django3.1)相关推荐

  1. 微信小程序详细图文教程-10分钟完成微信小程序开发部署发布(3元获取腾讯云服务器带小程序支持系统)...

    很多朋友都认为微信小程序申请.部署.发布很难,需要很长时间. 实际上,微信和腾讯云同是腾讯产品,已经提供了10分钟(根据准备资源情况,已完成小程序申请认证)完成小程序开发.部署.发布的方式.当然,实现 ...

  2. 时序分解股票数据并部署在微信公众号上

    目的 将股票价格进行时序分解,得到趋势图.周期图和误差图.然后放到微信公众号上,让用户输入"002581.SZ"等股票代码,即可自动回复以上的图片. 主要思路 用tushare获得 ...

  3. 为什么我不在微信公众号上写文章

    作者: 陈浩 原文: https://coolshell.cn/articles/17391.html 很多朋友问我为什么不在微信公众号上写文章.我都没有直接回答,老实说,我也是扭扭捏捏的,才去开了个 ...

  4. 工具类产品适合在微信公众号上运营吗?

    1. 工具类产品适合在微信公众号上运营吗? 问题描述:如果工具类的产品,主要的核心服务功能搬到微信公众号上运营会有些什么优势和劣势,以及如何在公众号上实现流量变现? 答:微信公众号之前有三种分类:服务 ...

  5. CSDN的文章如何快速转移到微信公众号上

    简单做个介绍,因为需要同时维护CSDN和微信公众号上的文章.所以就涉及到如何不做重复的工作. 所以这里推荐下我个人刚刚发现的比较好用的一个Chrome上用的一个插件,叫"Markdown N ...

  6. 【安信可A9G专题②】A9G在微信公众号上的定位功能笔记分享;

    本系列博客学习由 安信可科技 - 官方博客 技术分享,如有疑问请留言或联系邮箱. 1.A9G环境在windows上搭建并编译,串口打印 Hello GPRS 2.A9G在微信公众号上的定位功能笔记分享 ...

  7. 在微信公众号上显示指定位置的地图

    在微信公众号上显示指定位置的地图 需求:公众号获取用户上报的位置,展示出当前位置的地图 解决方案:通过公众号消息事件存储用户经纬度后,将经纬度作为参数打开腾讯地图一个可以自定义地图标记的url 链接, ...

  8. 封装微信公众号上传照片方法

    1.微信公众号上传照片方法 wxPic.js // 弹出提示消息的组件 import { Toast } from "vant"; // 微信JS-SDK文件,微信开发者官方有 i ...

  9. Java 微信公众号上传永久素材的方法

    用 Java 实现微信公众号上传永久素材,代码如下: /*** 上传其他永久素材(图片素材的上限为5000,其他类型为1000)* @param appid* @param secret* @retu ...

最新文章

  1. 2014西安 H 有向图博弈 UVALive-7042
  2. 内存不足导致mysql关闭,CentOS6.5增加swap分区
  3. 支付系统整体设计:整体架构设计以及注意要点(一)
  4. NS2安装笔记---SUSE Linux
  5. 【CodeForces 577C】Vasya and Petya’s Game
  6. 没有J2EE容器的JNDI和JPA
  7. ARGB和PARGB
  8. 图形学 射线相交算法_计算机图形学中的阴极射线管(CRT)
  9. 本科阶段计算机专业的科学体系,【学习方法】一位大三本科生的计算机科学与技术学习反思录...
  10. 电子计算机发展为第五代,电子计算机的发展历程是怎样的?
  11. 个人Androidstudio快捷键及常用设置配置
  12. win10计算机管理看不见蓝牙,解决win10蓝牙开关不见了的方法
  13. 分布式计算、统计学习与ADMM算法
  14. VS2013使用技巧汇总
  15. VT是什么?怎么打开教程
  16. maven项目的Archetype常用选择
  17. 什么是AP,胖瘦AP如何区分?
  18. ORA-64203: 目标缓冲区太小, 无法容纳字符集转换之后的 CLOB 数据
  19. 智行火车票免费加速到VIP最高速抢票(不用朋友积攒或者购买加速包)
  20. C# ListBox 控件

热门文章

  1. 做了几个Firefox的主题
  2. NoSQL数据库入门概述
  3. 电子行业:万物互联,开启智能新时代
  4. 中国AI第一深度学习平台飞桨再迎一系列升级,百度打造“现代化中央厨房”
  5. MathCast 免费开源 数学公式 演算编辑器
  6. css背景的设置及属性
  7. GoLang之Go语言优点
  8. php什么事阿帕奇,apache到底是什么
  9. 【Python学习】http网站发送请求
  10. php将敏感词替换为*的方法