对应的tensorflow版本:2.5.0+

textcnn模型如下:

import tensorflow as tfclass ConvMaxPooling1d(tf.keras.layers.Layer):def __init__(self, filters, kernel):super(ConvMaxPooling1d, self).__init__()self.kernel_size = kernel#(batch_size, step, embedding_size)->(batch_size,step-kernel_size+1,filter_size)self.conv = tf.keras.layers.Conv1D(filters=filters, kernel_size=kernel, activation='relu')# (batch_size,step-kernel_size+1,filter_size)->(batch_size,filter_size)self.pool = tf.keras.layers.GlobalMaxPool1D()tf.random.uniform()def call(self, inputs, masks=None):conv_out = self.conv(inputs)pool_out = self.pool(conv_out)return pool_outclass TextCNN(tf.keras.models.Model):def __init__(self, vocab, embedding_size, hidden_size, filters_list=[50 ,60, 70, 80], kernels=[2,3, 4, 5],dropout=0.5, sentence_length=20):super(TextCNN, self).__init__()ind = tf.feature_column.categorical_column_with_vocabulary_file("sentence_vocab", vocabulary_file=vocab,default_value=0)self.embedding_size = embedding_sizeself.sentence_length = sentence_lengthself.dense_feature_layer = tf.keras.layers.DenseFeatures([tf.feature_column.embedding_column(ind, dimension=embedding_size)])self.conv_maxs = [ConvMaxPooling1d(f, k) for f, k in zip(filters_list, kernels)]self.dropout = tf.keras.layers.Dropout(dropout)self.dense = tf.keras.layers.Dense(hidden_size, activation='relu')self.classifier = tf.keras.layers.Dense(1, activation='sigmoid')# @tf.function(input_signature=(tf.TensorSpec(shape=(None, None), dtype=tf.dtypes.string),))def call(self, inputs):# ***************word token embedding begin***************inputs = tf.convert_to_tensor(inputs)inputs_tensor = tf.reshape(inputs, (-1, 1))embed_word_vectors1 = self.dense_feature_layer({"sentence_vocab": inputs_tensor})embeddings = tf.reshape(embed_word_vectors1, (-1, self.sentence_length, self.embedding_size))# ***************word token embedding end***************#对于每一个layer来说,输入是:(batch_size,step,embedding_size)->(batch_size,step-kernel_size+1,filter_size)conv_outs = [layer(embeddings, None) for layer in self.conv_maxs]# 对于每一个layer来说,输入是:[(batch_size,step-kernel_size+1,filter_size)]->(batch_size,step-kernel_size+1,sum(filter_size))concat_out = tf.concat(conv_outs, axis=-1)dense_out = self.dense(concat_out)drop_out = self.dropout(dense_out)logits = self.classifier(drop_out)return logits

模型训练代码如下:

import tensorflow as tffrom model.TextCNN import TextCNN
from utils.path_utils import get_full_path
from utils.read_batch_data import get_dataclass SingleRNNModelTest:def __init__(self,epoch=3,batch_size=100,embedding_size=256,learning_rate=0.001,model_path="model_version",sentence_vocb_length=20,fill_vocab='TTTTTT',vocab_file_path="data/vocab_clean.txt"):self.optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)self.loss = tf.keras.losses.BinaryCrossentropy(reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE)self.epoch = epochself.batch_size = batch_sizeself.model_path = model_pathself.sentence_vocb_length = sentence_vocb_lengthself.fill_vocab = fill_vocabself.model = TextCNN(vocab=get_full_path(vocab_file_path), embedding_size=embedding_size, hidden_size=20,sentence_length=sentence_vocb_length)self.summary_writer = tf.summary.create_file_writer('./tensorboard/news_label_model/{}'.format(model_path))def train(self):# ========================== Create dataset =======================train_x,train_y = get_data("data/train_data/prepare/train_data_v2.txt", self.sentence_vocb_length, self.fill_vocab)self.model(train_x)board = tf.keras.callbacks.TensorBoard(log_dir=get_full_path("data/fit_log/graph"), write_graph=True)model_save = tf.keras.callbacks.ModelCheckpoint(get_full_path("data/fit_log/fit_model3"), monitor="val_loss",mode="min")self.model.compile(optimizer=self.optimizer, loss=self.loss,metrics=[tf.keras.metrics.BinaryAccuracy(threshold=0.5),tf.keras.metrics.AUC(curve='PR', name='p-r'),tf.keras.metrics.AUC(curve='ROC', name='ROC'),])self.model.fit(x=train_x,y=train_y,batch_size=self.batch_size, epochs=self.epoch,shuffle=True,callbacks=[board, model_save])if __name__ == '__main__':# =============================== GPU ==============================gpu = tf.config.experimental.list_physical_devices(device_type='GPU')print("gpu message:{}".format(gpu))# If you have GPU, and the value is GPU serial number.import osos.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"os.environ["CUDA_VISIBLE_DEVICES"] = "0"epoch = 20batch_size = 1000sentence_vocb_length = 25embedding_size = 216learning_rate = 0.001train_instance = SingleRNNModelTest(epoch=epoch, batch_size=batch_size, sentence_vocb_length=sentence_vocb_length,embedding_size=embedding_size, learning_rate=learning_rate)train_instance.train()

textcnn模型实践相关推荐

  1. [NLP] TextCNN模型原理和实现

    1. 模型原理 1.1 论文 Yoon Kim在论文(2014 EMNLP) Convolutional Neural Networks for Sentence Classification提出Te ...

  2. [NLP] 文本分类之TextCNN模型原理和实现(超详细)

    1. 模型原理 1.1论文 Yoon Kim在论文(2014 EMNLP) Convolutional Neural Networks for Sentence Classification提出Tex ...

  3. 【Pytorch神经网络实战案例】40 TextCNN模型分析IMDB数据集评论的积极与消极

    卷积神经网络不仅在图像视觉领域有很好的效果,而且在基于文本的NLP领域也有很好的效果.TextCN如模型是卷积神经网络用于文本处理方面的一个模型. 在TextCNN模型中,通过多分支卷积技术实现对文本 ...

  4. 基于Text-CNN模型的中文文本分类实战

    七月 上海 | 高性能计算之GPU CUDA培训 7月27-29日三天密集式学习  快速带你入门阅读全文> 正文共5260个字,21张图,预计阅读时间28分钟. Text-CNN 1.文本分类 ...

  5. PySpark︱pyspark.ml 相关模型实践

    文章目录 1 pyspark.ml MLP模型实践 模型存储与加载 9 spark.ml模型评估 MulticlassClassificationEvaluator 1 pyspark.ml MLP模 ...

  6. 1w字详解 ClickHouse漏斗模型实践方案(收藏)

    作者:互联网大数据团队- Wu Yonggang 日常工作中做为数仓开发工程师.数据分析师经常碰到漏斗分析模型,本文详细介绍漏斗模型的概念及基本原理,并阐述了其在平台内部的具体实现.针对实际使用过程的 ...

  7. textcnn文本词向量_基于Text-CNN模型的中文文本分类实战

    1 文本分类 文本分类是自然语言处理领域最活跃的研究方向之一,目前文本分类在工业界的应用场景非常普遍,从新闻的分类.商品评论信息的情感分类到微博信息打标签辅助推荐系统,了解文本分类技术是NLP初学者比 ...

  8. 模型实践 | 高精地图构建模型HDMapNet助力更精准的自动驾驶

    实验 | Freja   算力支持 | 幻方AIHPC 高精地图是自动驾驶系统的关键模块,可以有效提升自动驾驶汽车的行驶安全度,强化自动驾驶系统的整体感知能力和决策能力.然而传统的高精地图构建流程复杂 ...

  9. CVPR2022: Oriented RepPoints论文模型实践(用dota数据集)

    CVPR2022: Oriented RepPoints论文模型实践(用dota数据集) 论文:https://arxiv.org/abs/2105.11111 github:https://gith ...

  10. 声音识别入门经典模型实践-基于大数据训练CNN14网络实现食物咀嚼声音识别

    声音识别入门经典模型实践-基于大数据训练CNN14网络实现食物咀嚼声音识别 项目简介 声音分类是指可以定制识别出当前音频是哪种声音,或者是什么状态/场景的声音.通过声音,人的大脑会获取到大量的信息,其 ...

最新文章

  1. Struts2的拦截器只允许有权限用户访问action
  2. 版本控制8(译文) -(完)
  3. show open tables命令 mysql查看哪些表加锁了
  4. python怎么爬虎牙_使用python爬虫框架scrapy抓取虎牙主播数据
  5. axios 发 post 请求,后端接收不到参数的解决方案
  6. orange实现逻辑回归_逻辑回归模型
  7. 【图像处理】彩色图像处理(Color Image Processing)
  8. 汇编常用DOS命令调用
  9. linux中文件颜色,蓝色,白色等各代表含义
  10. 前端性能优化(四)——网页加载更快的N种方式
  11. 2020电信校园卡已经发售,更新校园卡最新消息及选购建议
  12. 中央财经大学C语言考研真题答案,2017年中央财经大学信息学院901C语言程序设计考研题库...
  13. 黑鲨会升级鸿蒙吗,黑鲨4首批用户评价已出炉,不吹不黑,优缺点都很明显!...
  14. 4.默认参数,不定参数,扩展参数
  15. Linux SPI 子系统(x86平台)
  16. 高德地图看各省分界线_高德地图定位城市区域
  17. 视频怎么做gif表情包?教你一个快速生成的方法
  18. 蓝桥杯试题 算法训练 幂方分解
  19. php记录搜索关键字_PHP记录搜索引擎来路以及搜索输入的关键字
  20. 你咪当我lu lu喎!

热门文章

  1. ECTOUCH系统默认模板是有显示销量的,但是销量一直为0,第二种方法OK
  2. Socket编程模型之完成端口模型
  3. 半正定矩阵的对角元素不小于该矩阵的最小特征值
  4. Badboy下载安装超详细教程
  5. jQuery活动倒计时插件
  6. 全站仪数据导入电脑_怎么把全站仪的数据导到电脑上来,并且成图?
  7. Thinkcell入门与使用
  8. 联想微型计算机Q150,联想Q150E电脑安装攻略
  9. 安卓吃鸡玩家专属:教你电脑玩刺激战场匹配手机最简单的方式
  10. java reader 组合_Java IText 拼接合并PDF的三种方法