rnn_model.py:
#!/usr/bin/python
# -*- coding: utf-8 -*-

import tensorflow as tf


class TRNNConfig(object):
    """Hyper-parameters for the text-classification RNN."""

    # Model parameters
    embedding_dim = 64       # dimensionality of the word embeddings
    seq_length = 600         # padded input sequence length
    num_classes = 10         # number of target classes
    vocab_size = 5000        # vocabulary size
    num_layers = 2           # number of stacked RNN layers
    hidden_dim = 128         # hidden units per RNN cell
    rnn = 'gru'              # cell type: 'lstm' or 'gru'

    dropout_keep_prob = 0.8  # keep probability for dropout
    learning_rate = 1e-3     # Adam learning rate

    batch_size = 128         # examples per training batch
    num_epochs = 10          # total training epochs

    print_per_batch = 100    # report metrics every N batches
    save_per_batch = 10      # write TensorBoard summaries every N batches


class TextRNN(object):
    """RNN model for text classification (TensorFlow 1.x graph style)."""

    def __init__(self, config):
        self.config = config

        # The three run-time inputs.
        self.input_x = tf.placeholder(tf.int32, [None, self.config.seq_length], name='input_x')
        self.input_y = tf.placeholder(tf.float32, [None, self.config.num_classes], name='input_y')
        self.keep_prob = tf.placeholder(tf.float32, name='keep_prob')

        self.rnn()

    def rnn(self):
        """Build the graph: embedding -> multi-layer RNN -> dense -> softmax."""

        def lstm_cell():
            # A single LSTM cell.
            return tf.contrib.rnn.BasicLSTMCell(self.config.hidden_dim, state_is_tuple=True)

        def gru_cell():
            # A single GRU cell.
            return tf.contrib.rnn.GRUCell(self.config.hidden_dim)

        def dropout():
            # Wrap the configured cell type with output dropout.
            cell = lstm_cell() if self.config.rnn == 'lstm' else gru_cell()
            return tf.contrib.rnn.DropoutWrapper(cell, output_keep_prob=self.keep_prob)

        # Keep the embedding lookup on the CPU.
        with tf.device('/cpu:0'):
            embedding = tf.get_variable('embedding', [self.config.vocab_size, self.config.embedding_dim])
            embedding_inputs = tf.nn.embedding_lookup(embedding, self.input_x)

        with tf.name_scope("rnn"):
            # Stack num_layers dropout-wrapped cells into one multi-layer RNN.
            cells = [dropout() for _ in range(self.config.num_layers)]
            rnn_cell = tf.contrib.rnn.MultiRNNCell(cells, state_is_tuple=True)

            # _outputs: per-step outputs of the top layer, shape [batch, seq_length, hidden_dim];
            # the discarded second result is the final state of every layer.
            _outputs, _ = tf.nn.dynamic_rnn(cell=rnn_cell, inputs=embedding_inputs, dtype=tf.float32)
            last = _outputs[:, -1, :]  # use the final time step as the sequence representation

        with tf.name_scope("score"):
            # Fully connected layer followed by dropout and ReLU.
            fc = tf.layers.dense(last, self.config.hidden_dim, name='fc1')
            fc = tf.contrib.layers.dropout(fc, self.keep_prob)
            fc = tf.nn.relu(fc)

            # Classifier.
            self.logits = tf.layers.dense(fc, self.config.num_classes, name='fc2')
            self.y_pred_cls = tf.argmax(tf.nn.softmax(self.logits), 1)  # predicted class ids

        with tf.name_scope("optimize"):
            # Softmax cross-entropy loss.
            cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=self.logits, labels=self.input_y)
            self.loss = tf.reduce_mean(cross_entropy)
            # Adam optimizer.
            self.optim = tf.train.AdamOptimizer(learning_rate=self.config.learning_rate).minimize(self.loss)

        with tf.name_scope("accuracy"):
            # Fraction of predictions matching the one-hot labels.
            correct_pred = tf.equal(tf.argmax(self.input_y, 1), self.y_pred_cls)
            self.acc = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
run_rnn.py:
# coding: utf-8from __future__ import print_functionimport os
import sys
import time
from datetime import timedeltaimport numpy as np
import tensorflow as tf
from sklearn import metricsfrom rnn_model import TRNNConfig, TextRNN
from data.cnews_loader import read_vocab, read_category, batch_iter, process_file, build_vocabbase_dir = 'data/cnews'
train_dir = os.path.join(base_dir, 'cnews.train.txt')
test_dir = os.path.join(base_dir, 'cnews.test.txt')
val_dir = os.path.join(base_dir, 'cnews.val.txt')
vocab_dir = os.path.join(base_dir, 'cnews.vocab.txt')save_dir = 'checkpoints/textrnn'
save_path = os.path.join(save_dir, 'best_validation')  # 最佳验证结果保存路径def get_time_dif(start_time):"""获取已使用时间"""end_time = time.time()time_dif = end_time - start_timereturn timedelta(seconds=int(round(time_dif)))def feed_data(x_batch, y_batch, keep_prob):feed_dict = {model.input_x: x_batch,model.input_y: y_batch,model.keep_prob: keep_prob}return feed_dictdef evaluate(sess, x_, y_):"""评估在某一数据上的准确率和损失"""data_len = len(x_)batch_eval = batch_iter(x_, y_, 128)total_loss = 0.0total_acc = 0.0for x_batch, y_batch in batch_eval:batch_len = len(x_batch)feed_dict = feed_data(x_batch, y_batch, 1.0)y_pred_class,loss, acc = sess.run([model.y_pred_cls,model.loss, model.acc], feed_dict=feed_dict)total_loss += loss * batch_lentotal_acc += acc * batch_lenreturn y_pred_class,total_loss / data_len, total_acc / data_lendef train():print("Configuring TensorBoard and Saver...")# 配置 Tensorboard,重新训练时,请将tensorboard文件夹删除,不然图会覆盖tensorboard_dir = 'tensorboard/textrnn'if not os.path.exists(tensorboard_dir):os.makedirs(tensorboard_dir)tf.summary.scalar("loss", model.loss)tf.summary.scalar("accuracy", model.acc)merged_summary = tf.summary.merge_all()writer = tf.summary.FileWriter(tensorboard_dir)# 配置 Saversaver = tf.train.Saver()if not os.path.exists(save_dir):os.makedirs(save_dir)print("Loading training and validation data...")# 载入训练集与验证集start_time = time.time()x_train, y_train = process_file(train_dir, word_to_id, cat_to_id, config.seq_length)x_val, y_val = process_file(val_dir, word_to_id, cat_to_id, config.seq_length)time_dif = get_time_dif(start_time)print("Time usage:", time_dif)# 创建sessionsession = tf.Session()session.run(tf.global_variables_initializer())writer.add_graph(session.graph)print('Training and evaluating...')start_time = time.time()total_batch = 0  # 总批次best_acc_val = 0.0  # 最佳验证集准确率last_improved = 0  # 记录上一次提升批次require_improvement = 1000  # 如果超过1000轮未提升,提前结束训练flag = Falsefor epoch in range(config.num_epochs):print('Epoch:', epoch + 1)batch_train = batch_iter(x_train, y_train, 
config.batch_size)for x_batch, y_batch in batch_train:feed_dict = feed_data(x_batch, y_batch, config.dropout_keep_prob)if total_batch % config.save_per_batch == 0:# 每多少轮次将训练结果写入tensorboard scalars = session.run(merged_summary, feed_dict=feed_dict)writer.add_summary(s, total_batch)if total_batch % config.print_per_batch == 0:# 每多少轮次输出在训练集和验证集上的性能feed_dict[model.keep_prob] = 1.0loss_train, acc_train = session.run([model.loss, model.acc], feed_dict=feed_dict)_,loss_val, acc_val = evaluate(session, x_val, y_val)  # todoif acc_val > best_acc_val:# 保存最好结果best_acc_val = acc_vallast_improved = total_batchsaver.save(sess=session, save_path=save_path)improved_str = '*'else:improved_str = ''time_dif = get_time_dif(start_time)msg = 'Iter: {0:>6}, Train Loss: {1:>6.2}, Train Acc: {2:>7.2%},' \+ ' Val Loss: {3:>6.2}, Val Acc: {4:>7.2%}, Time: {5} {6}'print(msg.format(total_batch, loss_train, acc_train, loss_val, acc_val, time_dif, improved_str))session.run(model.optim, feed_dict=feed_dict)  # 运行优化total_batch += 1if total_batch - last_improved > require_improvement:# 验证集正确率长期不提升,提前结束训练print("No optimization for a long time, auto-stopping...")flag = Truebreak  # 跳出循环if flag:  # 同上breakdef test():print("Loading test data...")start_time = time.time()x_test, y_test = process_file(test_dir, word_to_id, cat_to_id, config.seq_length)session = tf.Session()session.run(tf.global_variables_initializer())saver = tf.train.Saver()saver.restore(sess=session, save_path=save_path)  # 读取保存的模型print('Testing...')y_pred,loss_test, acc_test = evaluate(session, x_test, y_test)msg = 'Test Loss: {0:>6.2}, Test Acc: {1:>7.2%}'print(msg.format(loss_test, acc_test))batch_size = 128data_len = len(x_test)num_batch = int((data_len - 1) / batch_size) + 1y_test_cls = np.argmax(y_test, 1)y_pred_cls = np.zeros(shape=len(x_test), dtype=np.int32)  # 保存预测结果for i in range(num_batch):  # 逐批次处理start_id = i * batch_sizeend_id = min((i + 1) * batch_size, data_len)feed_dict = {model.input_x: 
x_test[start_id:end_id],model.keep_prob: 1.0}y_pred_cls[start_id:end_id] = session.run(model.y_pred_cls, feed_dict=feed_dict)# 评估print("Precision, Recall and F1-Score...")print(metrics.classification_report(y_test_cls, y_pred_cls, target_names=categories))# 混淆矩阵print("Confusion Matrix...")cm = metrics.confusion_matrix(y_test_cls, y_pred_cls)print(cm)time_dif = get_time_dif(start_time)print("Time usage:", time_dif)if __name__ == '__main__':print('Configuring RNN model...')config = TRNNConfig()if not os.path.exists(vocab_dir):  # 如果不存在词汇表,重建build_vocab(train_dir, vocab_dir, config.vocab_size)categories, cat_to_id = read_category()words, word_to_id = read_vocab(vocab_dir)config.vocab_size = len(words)model = TextRNN(config)option='train'if option == 'train':train()else:test()
predict_rnn.py:
# coding: utf-8from __future__ import print_functionimport os
import tensorflow as tf
import tensorflow.contrib.keras as krfrom cnn_model import TCNNConfig, TextCNN
from data.cnews_loader import read_category, read_vocab
from rnn_model import TRNNConfig, TextRNNtry:bool(type(unicode))
except NameError:unicode = strbase_dir = 'data/cnews'
vocab_dir = os.path.join(base_dir, 'cnews.vocab.txt')save_dir = 'checkpoints/textrnn'
save_path = os.path.join(save_dir, 'best_validation')  # 最佳验证结果保存路径class RnnModel:def __init__(self):self.config = TRNNConfig()self.categories, self.cat_to_id = read_category()self.words, self.word_to_id = read_vocab(vocab_dir)self.config.vocab_size = len(self.words)self.model = TextRNN(self.config)self.session = tf.Session()self.session.run(tf.global_variables_initializer())saver = tf.train.Saver()saver.restore(sess=self.session, save_path=save_path)  # 读取保存的模型def predict(self, message):# 支持不论在python2还是python3下训练的模型都可以在2或者3的环境下运行content = unicode(message)data = [self.word_to_id[x] for x in content if x in self.word_to_id]feed_dict = {self.model.input_x: kr.preprocessing.sequence.pad_sequences([data], self.config.seq_length),self.model.keep_prob: 1.0}y_pred_cls = self.session.run(self.model.y_pred_cls, feed_dict=feed_dict)return self.categories[y_pred_cls[0]]if __name__ == '__main__':rnn_model = RnnModel()test_rnn_demo = ['三星ST550以全新的拍摄方式超越了以往任何一款数码相机','热火vs骑士前瞻:皇帝回乡二番战 东部次席唾手可得新浪体育讯北京时间3月30日7:00']for i in test_rnn_demo:print(i,":",rnn_model.predict(i))

Tensorflow搭建LSTM对文本进行分类相关推荐

  1. 对文本进行分类方法python_pytorch实现用CNN和LSTM对文本进行分类方式

    model.py: #!/usr/bin/python # -*- coding: utf-8 -*- import torch from torch import nn import numpy a ...

  2. TensorFlow搭建LSTM实现多变量时间序列预测(负荷预测)

    目录 I. 前言 II. 数据处理 III. LSTM模型 IV. 训练/测试 V. 源码及数据 I. 前言 在前面的一篇文章TensorFlow搭建LSTM实现时间序列预测(负荷预测)中,我们利用L ...

  3. TensorFlow搭建LSTM实现时间序列预测(负荷预测)

    目录 I. 前言 II. 数据处理 III. 模型 IV. 训练/测试 V. 源码及数据 I. 前言 前面已经写过不少时间序列预测的文章: 深入理解PyTorch中LSTM的输入和输出(从input输 ...

  4. pytorch实现用CNN和LSTM对文本进行分类

    model.py: #!/usr/bin/python # -*- coding: utf-8 -*-import torch from torch import nn import numpy as ...

  5. TensorFlow搭建双向LSTM实现时间序列预测(负荷预测)

    目录 I. 前言 II. 原理 III. 模型定义 IV. 训练和预测 V. 源码及数据 I. 前言 前面几篇文章中介绍的都是单向LSTM,这篇文章讲一下双向LSTM. 系列文章: 深入理解PyTor ...

  6. TensorFlow搭建CNN实现时间序列预测(风速预测)

    目录 I. 数据集 II. 特征构造 III. 一维卷积 IV. 数据处理 1. 数据预处理 2. 数据集构造 V. CNN模型 1. 模型搭建 2. 模型训练及表现 VI. 源码及数据 时间序列预测 ...

  7. 基于Keras搭建LSTM网络实现文本情感分类

    基于Keras搭建LSTM网络实现文本情感分类 一.语料概况 1.1 数据统计 1.1.1 查看样本均衡情况,对label进行统计 1.1.2 计句子长度及长度出现的频数 1.1.3 绘制句子长度累积 ...

  8. 用tensorflow搭建RNN(LSTM)进行MNIST 手写数字辨识

    用tensorflow搭建RNN(LSTM)进行MNIST 手写数字辨识 循环神经网络RNN相比传统的神经网络在处理序列化数据时更有优势,因为RNN能够将加入上(下)文信息进行考虑.一个简单的RNN如 ...

  9. Tensorflow使用LSTM实现中文文本分类(1)

    前言 使用Tensorflow,利用LSTM进行中文文本的分类. 数据集格式如下: ''' 体育 马晓旭意外受伤让国奥警惕 无奈大雨格外青睐殷家军记者傅亚雨沈阳报道 来到沈阳,国奥队依然没有摆脱雨水的 ...

最新文章

  1. php vs lua,解析LUA与PHP在WEB应用的性能对比
  2. Uncaught TypeError: Cannot read property ‘events‘ of undefined
  3. Android ListView不响应OnItemClickListener解决办法
  4. 43_pytorch nn.Module,模型的创建,构建子模块,API介绍,Sequential(序号),ModuleList,ParameterList,案例等(学习笔记)
  5. 设定游戏背景和英雄登场
  6. 研究表明:胸部大小其实早已.....
  7. qwtqplot用法
  8. 使用Github Actions构建、发布和部署NuGet软件包
  9. Jsp+SSH+Mysql实现的校园课程作业网
  10. Unix网络编程---第三次作业
  11. 拓端tecdat|R语言特征选择——逐步回归
  12. ajax如何传两个不同的参数,ajax 如何从后台传多个data对象(多个参数)string类型的...
  13. Knoll Light Factory 3.2 for mac完整汉化版|灯光工厂 for mac中文版
  14. Octave与MATLAB
  15. java超市管理系统ppt_基于java-web的超市管理系统毕业答辩ppt课件
  16. 加工中心计算机编程自学,自学加工中心编程(简单易学)图文讲解
  17. Winform自动升级系统的设计与实现(源码)
  18. 列表页详情页html源码,UI布局欣赏:文章列表与内容详情页设计
  19. 互联网+废品回收小程序,废品回收小程序,废品回收小程序平台,蚂蚁废收小程序
  20. Chrome实现独立代理

热门文章

  1. Altmetrics(替代计量学):对你的论文影响力的评价方法
  2. erc20钱包下载_以太坊ERC20代币数据集【1000+】
  3. element el-input 只能输入正整数完美解决不闪动
  4. windows安装scoop教程
  5. python加密与解密_Python字符串加密与解密的方法总结
  6. 试讲计算机领域的知识点,【教编面试】小学计算机与技术试讲模板
  7. 苹果手机清灰_手机清灰音频
  8. 微软“免费域名邮箱”Windows Live Custom Domains
  9. Massive MIMO与MU-MIMO的区别?
  10. 写给大忙人看的Keil和Proteus联调使用方法