Tensorflow搭建LSTM对文本进行分类
rnn_model.py:
#!/usr/bin/python
# -*- coding: utf-8 -*-

import tensorflow as tf


class TRNNConfig(object):
    """Hyperparameter configuration for the text RNN classifier."""

    # Model parameters
    embedding_dim = 64        # word-embedding dimension
    seq_length = 600          # padded input sequence length (tokens per sample)
    num_classes = 10          # number of output classes
    vocab_size = 5000         # vocabulary size
    num_layers = 2            # number of stacked RNN layers
    hidden_dim = 128          # hidden units per RNN layer
    rnn = 'gru'               # cell type: 'lstm' or 'gru'

    dropout_keep_prob = 0.8   # dropout keep probability used during training
    learning_rate = 1e-3      # Adam learning rate

    batch_size = 128          # samples per training batch
    num_epochs = 10           # total training epochs
    print_per_batch = 100     # report train/val metrics every N batches
    save_per_batch = 10       # write a TensorBoard summary every N batches


class TextRNN(object):
    """Multi-layer RNN (LSTM/GRU) model for text classification.

    Builds a TF1 static graph in __init__; exposes the placeholders
    input_x / input_y / keep_prob and the ops loss, acc, optim,
    y_pred_cls for use by the training/evaluation scripts.
    """

    def __init__(self, config):
        self.config = config

        # The three feed inputs: token-id matrix, one-hot labels, dropout keep prob.
        self.input_x = tf.placeholder(tf.int32, [None, self.config.seq_length], name='input_x')
        self.input_y = tf.placeholder(tf.float32, [None, self.config.num_classes], name='input_y')
        self.keep_prob = tf.placeholder(tf.float32, name='keep_prob')

        self.rnn()

    def rnn(self):
        """Assemble the RNN computation graph."""

        def lstm_cell():  # LSTM cell
            return tf.contrib.rnn.BasicLSTMCell(self.config.hidden_dim, state_is_tuple=True)

        def gru_cell():  # GRU cell
            return tf.contrib.rnn.GRUCell(self.config.hidden_dim)

        def dropout():  # wrap each RNN cell with an output-dropout layer
            if (self.config.rnn == 'lstm'):
                cell = lstm_cell()
            else:
                cell = gru_cell()
            return tf.contrib.rnn.DropoutWrapper(cell, output_keep_prob=self.keep_prob)

        # Word-embedding lookup (kept on CPU, the usual placement for large embeddings).
        with tf.device('/cpu:0'):
            embedding = tf.get_variable('embedding', [self.config.vocab_size, self.config.embedding_dim])
            embedding_inputs = tf.nn.embedding_lookup(embedding, self.input_x)

        with tf.name_scope("rnn"):
            # Stack num_layers independent cells into a multi-layer RNN.
            cells = [dropout() for _ in range(self.config.num_layers)]
            rnn_cell = tf.contrib.rnn.MultiRNNCell(cells, state_is_tuple=True)

            # Run the RNN. _outputs is the top layer's output for every time
            # step, shape [batch, seq_length, hidden_dim]; the discarded second
            # value is the final state of each layer.
            _outputs, _ = tf.nn.dynamic_rnn(cell=rnn_cell, inputs=embedding_inputs, dtype=tf.float32)
            # Use only the last time step's output as the sequence representation.
            last = _outputs[:, -1, :]

        with tf.name_scope("score"):
            # Fully connected layer followed by dropout and ReLU.
            fc = tf.layers.dense(last, self.config.hidden_dim, name='fc1')
            fc = tf.contrib.layers.dropout(fc, self.keep_prob)
            fc = tf.nn.relu(fc)

            # Classifier head: logits and the predicted class index.
            self.logits = tf.layers.dense(fc, self.config.num_classes, name='fc2')
            self.y_pred_cls = tf.argmax(tf.nn.softmax(self.logits), 1)

        with tf.name_scope("optimize"):
            # Softmax cross-entropy loss averaged over the batch.
            cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=self.logits, labels=self.input_y)
            self.loss = tf.reduce_mean(cross_entropy)
            # Adam optimizer.
            self.optim = tf.train.AdamOptimizer(learning_rate=self.config.learning_rate).minimize(self.loss)

        with tf.name_scope("accuracy"):
            # Fraction of samples whose predicted class matches the label.
            correct_pred = tf.equal(tf.argmax(self.input_y, 1), self.y_pred_cls)
            self.acc = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
run_rnn.py:
# coding: utf-8from __future__ import print_functionimport os
import sys
import time
from datetime import timedeltaimport numpy as np
import tensorflow as tf
from sklearn import metricsfrom rnn_model import TRNNConfig, TextRNN
from data.cnews_loader import read_vocab, read_category, batch_iter, process_file, build_vocabbase_dir = 'data/cnews'
train_dir = os.path.join(base_dir, 'cnews.train.txt')
test_dir = os.path.join(base_dir, 'cnews.test.txt')
val_dir = os.path.join(base_dir, 'cnews.val.txt')
vocab_dir = os.path.join(base_dir, 'cnews.vocab.txt')save_dir = 'checkpoints/textrnn'
save_path = os.path.join(save_dir, 'best_validation') # 最佳验证结果保存路径def get_time_dif(start_time):"""获取已使用时间"""end_time = time.time()time_dif = end_time - start_timereturn timedelta(seconds=int(round(time_dif)))def feed_data(x_batch, y_batch, keep_prob):feed_dict = {model.input_x: x_batch,model.input_y: y_batch,model.keep_prob: keep_prob}return feed_dictdef evaluate(sess, x_, y_):"""评估在某一数据上的准确率和损失"""data_len = len(x_)batch_eval = batch_iter(x_, y_, 128)total_loss = 0.0total_acc = 0.0for x_batch, y_batch in batch_eval:batch_len = len(x_batch)feed_dict = feed_data(x_batch, y_batch, 1.0)y_pred_class,loss, acc = sess.run([model.y_pred_cls,model.loss, model.acc], feed_dict=feed_dict)total_loss += loss * batch_lentotal_acc += acc * batch_lenreturn y_pred_class,total_loss / data_len, total_acc / data_lendef train():print("Configuring TensorBoard and Saver...")# 配置 Tensorboard,重新训练时,请将tensorboard文件夹删除,不然图会覆盖tensorboard_dir = 'tensorboard/textrnn'if not os.path.exists(tensorboard_dir):os.makedirs(tensorboard_dir)tf.summary.scalar("loss", model.loss)tf.summary.scalar("accuracy", model.acc)merged_summary = tf.summary.merge_all()writer = tf.summary.FileWriter(tensorboard_dir)# 配置 Saversaver = tf.train.Saver()if not os.path.exists(save_dir):os.makedirs(save_dir)print("Loading training and validation data...")# 载入训练集与验证集start_time = time.time()x_train, y_train = process_file(train_dir, word_to_id, cat_to_id, config.seq_length)x_val, y_val = process_file(val_dir, word_to_id, cat_to_id, config.seq_length)time_dif = get_time_dif(start_time)print("Time usage:", time_dif)# 创建sessionsession = tf.Session()session.run(tf.global_variables_initializer())writer.add_graph(session.graph)print('Training and evaluating...')start_time = time.time()total_batch = 0 # 总批次best_acc_val = 0.0 # 最佳验证集准确率last_improved = 0 # 记录上一次提升批次require_improvement = 1000 # 如果超过1000轮未提升,提前结束训练flag = Falsefor epoch in range(config.num_epochs):print('Epoch:', epoch + 1)batch_train = batch_iter(x_train, y_train, 
config.batch_size)for x_batch, y_batch in batch_train:feed_dict = feed_data(x_batch, y_batch, config.dropout_keep_prob)if total_batch % config.save_per_batch == 0:# 每多少轮次将训练结果写入tensorboard scalars = session.run(merged_summary, feed_dict=feed_dict)writer.add_summary(s, total_batch)if total_batch % config.print_per_batch == 0:# 每多少轮次输出在训练集和验证集上的性能feed_dict[model.keep_prob] = 1.0loss_train, acc_train = session.run([model.loss, model.acc], feed_dict=feed_dict)_,loss_val, acc_val = evaluate(session, x_val, y_val) # todoif acc_val > best_acc_val:# 保存最好结果best_acc_val = acc_vallast_improved = total_batchsaver.save(sess=session, save_path=save_path)improved_str = '*'else:improved_str = ''time_dif = get_time_dif(start_time)msg = 'Iter: {0:>6}, Train Loss: {1:>6.2}, Train Acc: {2:>7.2%},' \+ ' Val Loss: {3:>6.2}, Val Acc: {4:>7.2%}, Time: {5} {6}'print(msg.format(total_batch, loss_train, acc_train, loss_val, acc_val, time_dif, improved_str))session.run(model.optim, feed_dict=feed_dict) # 运行优化total_batch += 1if total_batch - last_improved > require_improvement:# 验证集正确率长期不提升,提前结束训练print("No optimization for a long time, auto-stopping...")flag = Truebreak # 跳出循环if flag: # 同上breakdef test():print("Loading test data...")start_time = time.time()x_test, y_test = process_file(test_dir, word_to_id, cat_to_id, config.seq_length)session = tf.Session()session.run(tf.global_variables_initializer())saver = tf.train.Saver()saver.restore(sess=session, save_path=save_path) # 读取保存的模型print('Testing...')y_pred,loss_test, acc_test = evaluate(session, x_test, y_test)msg = 'Test Loss: {0:>6.2}, Test Acc: {1:>7.2%}'print(msg.format(loss_test, acc_test))batch_size = 128data_len = len(x_test)num_batch = int((data_len - 1) / batch_size) + 1y_test_cls = np.argmax(y_test, 1)y_pred_cls = np.zeros(shape=len(x_test), dtype=np.int32) # 保存预测结果for i in range(num_batch): # 逐批次处理start_id = i * batch_sizeend_id = min((i + 1) * batch_size, data_len)feed_dict = {model.input_x: 
x_test[start_id:end_id],model.keep_prob: 1.0}y_pred_cls[start_id:end_id] = session.run(model.y_pred_cls, feed_dict=feed_dict)# 评估print("Precision, Recall and F1-Score...")print(metrics.classification_report(y_test_cls, y_pred_cls, target_names=categories))# 混淆矩阵print("Confusion Matrix...")cm = metrics.confusion_matrix(y_test_cls, y_pred_cls)print(cm)time_dif = get_time_dif(start_time)print("Time usage:", time_dif)if __name__ == '__main__':print('Configuring RNN model...')config = TRNNConfig()if not os.path.exists(vocab_dir): # 如果不存在词汇表,重建build_vocab(train_dir, vocab_dir, config.vocab_size)categories, cat_to_id = read_category()words, word_to_id = read_vocab(vocab_dir)config.vocab_size = len(words)model = TextRNN(config)option='train'if option == 'train':train()else:test()
predict_rnn.py:
# coding: utf-8from __future__ import print_functionimport os
import tensorflow as tf
import tensorflow.contrib.keras as krfrom cnn_model import TCNNConfig, TextCNN
from data.cnews_loader import read_category, read_vocab
from rnn_model import TRNNConfig, TextRNNtry:bool(type(unicode))
except NameError:unicode = strbase_dir = 'data/cnews'
vocab_dir = os.path.join(base_dir, 'cnews.vocab.txt')save_dir = 'checkpoints/textrnn'
save_path = os.path.join(save_dir, 'best_validation') # 最佳验证结果保存路径class RnnModel:def __init__(self):self.config = TRNNConfig()self.categories, self.cat_to_id = read_category()self.words, self.word_to_id = read_vocab(vocab_dir)self.config.vocab_size = len(self.words)self.model = TextRNN(self.config)self.session = tf.Session()self.session.run(tf.global_variables_initializer())saver = tf.train.Saver()saver.restore(sess=self.session, save_path=save_path) # 读取保存的模型def predict(self, message):# 支持不论在python2还是python3下训练的模型都可以在2或者3的环境下运行content = unicode(message)data = [self.word_to_id[x] for x in content if x in self.word_to_id]feed_dict = {self.model.input_x: kr.preprocessing.sequence.pad_sequences([data], self.config.seq_length),self.model.keep_prob: 1.0}y_pred_cls = self.session.run(self.model.y_pred_cls, feed_dict=feed_dict)return self.categories[y_pred_cls[0]]if __name__ == '__main__':rnn_model = RnnModel()test_rnn_demo = ['三星ST550以全新的拍摄方式超越了以往任何一款数码相机','热火vs骑士前瞻:皇帝回乡二番战 东部次席唾手可得新浪体育讯北京时间3月30日7:00']for i in test_rnn_demo:print(i,":",rnn_model.predict(i))
Tensorflow搭建LSTM对文本进行分类相关推荐
- 对文本进行分类方法python_pytorch实现用CNN和LSTM对文本进行分类方式
model.py: #!/usr/bin/python # -*- coding: utf-8 -*- import torch from torch import nn import numpy a ...
- TensorFlow搭建LSTM实现多变量时间序列预测(负荷预测)
目录 I. 前言 II. 数据处理 III. LSTM模型 IV. 训练/测试 V. 源码及数据 I. 前言 在前面的一篇文章TensorFlow搭建LSTM实现时间序列预测(负荷预测)中,我们利用L ...
- TensorFlow搭建LSTM实现时间序列预测(负荷预测)
目录 I. 前言 II. 数据处理 III. 模型 IV. 训练/测试 V. 源码及数据 I. 前言 前面已经写过不少时间序列预测的文章: 深入理解PyTorch中LSTM的输入和输出(从input输 ...
- pytorch实现用CNN和LSTM对文本进行分类
model.py: #!/usr/bin/python # -*- coding: utf-8 -*-import torch from torch import nn import numpy as ...
- TensorFlow搭建双向LSTM实现时间序列预测(负荷预测)
目录 I. 前言 II. 原理 III. 模型定义 IV. 训练和预测 V. 源码及数据 I. 前言 前面几篇文章中介绍的都是单向LSTM,这篇文章讲一下双向LSTM. 系列文章: 深入理解PyTor ...
- TensorFlow搭建CNN实现时间序列预测(风速预测)
目录 I. 数据集 II. 特征构造 III. 一维卷积 IV. 数据处理 1. 数据预处理 2. 数据集构造 V. CNN模型 1. 模型搭建 2. 模型训练及表现 VI. 源码及数据 时间序列预测 ...
- 基于Keras搭建LSTM网络实现文本情感分类
基于Keras搭建LSTM网络实现文本情感分类 一.语料概况 1.1 数据统计 1.1.1 查看样本均衡情况,对label进行统计 1.1.2 计句子长度及长度出现的频数 1.1.3 绘制句子长度累积 ...
- 用tensorflow搭建RNN(LSTM)进行MNIST 手写数字辨识
用tensorflow搭建RNN(LSTM)进行MNIST 手写数字辨识 循环神经网络RNN相比传统的神经网络在处理序列化数据时更有优势,因为RNN能够将加入上(下)文信息进行考虑.一个简单的RNN如 ...
- Tensorflow使用LSTM实现中文文本分类(1)
前言 使用Tensorflow,利用LSTM进行中文文本的分类. 数据集格式如下: ''' 体育 马晓旭意外受伤让国奥警惕 无奈大雨格外青睐殷家军记者傅亚雨沈阳报道 来到沈阳,国奥队依然没有摆脱雨水的 ...
最新文章
- php vs lua,解析LUA与PHP在WEB应用的性能对比
- Uncaught TypeError: Cannot read property ‘events‘ of undefined
- Android ListView不响应OnItemClickListener解决办法
- 43_pytorch nn.Module,模型的创建,构建子模块,API介绍,Sequential(序号),ModuleList,ParameterList,案例等(学习笔记)
- 设定游戏背景和英雄登场
- 研究表明:胸部大小其实早已.....
- qwtqplot用法
- 使用Github Actions构建、发布和部署NuGet软件包
- Jsp+SSH+Mysql实现的校园课程作业网
- Unix网络编程---第三次作业
- 拓端tecdat|R语言特征选择——逐步回归
- ajax如何传两个不同的参数,ajax 如何从后台传多个data对象(多个参数)string类型的...
- Knoll Light Factory 3.2 for mac完整汉化版|灯光工厂 for mac中文版
- Octave与MATLAB
- java超市管理系统ppt_基于java-web的超市管理系统毕业答辩ppt课件
- 加工中心计算机编程自学,自学加工中心编程(简单易学)图文讲解
- Winform自动升级系统的设计与实现(源码)
- 列表页详情页html源码,UI布局欣赏:文章列表与内容详情页设计
- 互联网+废品回收小程序,废品回收小程序,废品回收小程序平台,蚂蚁废收小程序
- Chrome实现独立代理
热门文章
- Altmetrics(替代计量学):对你的论文影响力的评价方法
- erc20钱包下载_以太坊ERC20代币数据集【1000+】
- element el-input 只能输入正整数完美解决不闪动
- windows安装scoop教程
- python加密与解密_Python字符串加密与解密的方法总结
- 试讲计算机领域的知识点,【教编面试】小学计算机与技术试讲模板
- 苹果手机清灰_手机清灰音频
- 微软“免费域名邮箱”Windows Live Custom Domains
- Massive MIMO与MU-MIMO的区别?
- 写给大忙人看的Keil和Proteus联调使用方法