文章目录

  • 原理
    • 网络
    • embedding
    • convolution and pooling
  • 模型图
  • 代码

有个需求,给短文本分类,然后看了下文本分类的算法

  • 传统机器学习算法:分为特征提取、分类两部分
  • 深度学习算法:融合特征提取和分类,fastText、TextCNN、TextRNN、TextRCNN以及最近很火的bert算法,本文主要记录一下TextCNN。

参考
深度学习:TextCNN
TextCNN模型原理和实现

原理

将卷积神经网络CNN应用到文本分类任务,利用多个不同size的kernel来提取句子中的关键信息(类似于多窗口大小的ngram),从而能够更好地捕捉局部相关性。

网络

TextCNN包含四部分:词嵌入、卷积、池化、全连接+softmax,其实结构相比于图像领域简单很多。

  • Embedding:第一层是图中最左边的7乘5的句子矩阵,每行是词向量,维度=5,这个可以类比为图像中的原始像素点。
  • Convolution:然后经过 kernel_sizes=(2,3,4) 的一维卷积层,每个kernel_size 有两个输出 channel。
  • MaxPolling:第三层是一个1-max pooling层,这样不同长度句子经过pooling层之后都能变成定长的表示。
  • FullConnection and Softmax:最后接一层全连接的 softmax 层,输出每个类别的概率。

embedding

将每一个词表征为一个向量。可以采用预训练的模型,也可以随机初始化。训练过程中可以是static或non-static。感觉上,预训练的词向量可以先static再non-static。还有一种multichannel,两个channel通过预训练的词向量,一个为static一个为non-static,在fine-tune时只有一个通道更新参数。

convolution and pooling

输入一个包含s个单词的句子,假设每个单词预训练的词向量为d维,则输入为sxd,将该输入看做一幅图像,卷积提取相邻单词的特征,采用一维卷积,卷积核的宽度为词向量的维度d,则卷积核大小为wxd,卷积计算后特征为(s-h+1)x1xfilter_num,filter_num为卷积核的数目。这里h也是一个超参数可以取{2,3,4,5,…}
池化层:
1-max pooling,即提取feature map最大的值(这里有一个缺点是只取最大值,将位置信息就忽略了),很大程度上减少了模型参数数量。
average-pooling每个维度取均值。
k-max pooling取所有特征值top-k,并保留特征的先后顺序。

Dynamic Pooling之Chunk-MaxPooling。把某个Filter对应的Convolution层的所有特征向量进行分段,切割成若干段后,在每个分段里面各自取得一个最大特征值,比如将某个Filter的特征向量切成3个Chunk,那么就在每个Chunk里面取一个最大值,于是获得3个特征值。因为是先划分Chunk再分别取Max值的,所以保留了比较粗粒度的模糊的位置信息;当然,如果多次出现强特征,则也可以捕获特征强度。至于这个Chunk怎么划分,可以有不同的做法,比如可以事先设定好段落个数,这是一种静态划分Chunk的思路;也可以根据输入的不同动态地划分Chunk间的边界位置,可以称之为动态Chunk-Max方法。Event Extraction via Dynamic Multi-Pooling Convolutional Neural Networks这篇论文提出的是一种ChunkPooling的变体,就是动态Chunk-Max Pooling的思路,实验证明性能有提升。Local Translation Prediction with Global Sentence Representation 这篇论文也用实验证明了静态Chunk-Max性能相对MaxPooling Over Time方法在机器翻译应用中对应用效果有提升。
Dynamic Pooling卷积时如果碰到triggle词,可以标记下不同色,max-pooling时按不同标记划分chunk。

模型图

详看下图,画的真好,一目了然。

代码

以下是基于Keras的代码,代码我还没有细看,先贴上来。

import loggingfrom keras import Input
from keras.layers import Conv1D, MaxPool1D, Dense, Flatten, concatenate, Embedding
from keras.models import Model
from keras.utils import plot_modeldef textcnn(max_sequence_length, max_token_num, embedding_dim, output_dim, model_img_path=None, embedding_matrix=None):""" TextCNN: 1. embedding layers, 2.convolution layer, 3.max-pooling, 4.softmax layer. """x_input = Input(shape=(max_sequence_length,))logging.info("x_input.shape: %s" % str(x_input.shape))  # (?, 60)if embedding_matrix is None:x_emb = Embedding(input_dim=max_token_num, output_dim=embedding_dim, input_length=max_sequence_length)(x_input)else:x_emb = Embedding(input_dim=max_token_num, output_dim=embedding_dim, input_length=max_sequence_length,weights=[embedding_matrix], trainable=True)(x_input)logging.info("x_emb.shape: %s" % str(x_emb.shape))  # (?, 60, 300)pool_output = []kernel_sizes = [2, 3, 4] for kernel_size in kernel_sizes:c = Conv1D(filters=2, kernel_size=kernel_size, strides=1)(x_emb)p = MaxPool1D(pool_size=int(c.shape[1]))(c)pool_output.append(p)logging.info("kernel_size: %s \t c.shape: %s \t p.shape: %s" % (kernel_size, str(c.shape), str(p.shape)))pool_output = concatenate([p for p in pool_output])logging.info("pool_output.shape: %s" % str(pool_output.shape))  # (?, 1, 6)x_flatten = Flatten()(pool_output)  # (?, 6)y = Dense(output_dim, activation='softmax')(x_flatten)  # (?, 2)logging.info("y.shape: %s \n" % str(y.shape))model = Model([x_input], outputs=[y])if model_img_path:plot_model(model, to_file=model_img_path, show_shapes=True, show_layer_names=False)model.summary()return model

训练有以下几点:

  • 数据量较大:可以直接随机初始化embeddings,然后基于语料通过训练模型网络来对embeddings进行更新和学习。
  • 数据量较小:可以利用外部语料来预训练(pre-train)词向量,然后输入到Embedding层,用预训练的词向量矩阵初始化embeddings。(通过设置weights=[embedding_matrix])。
  • 静态(static)方式:训练过程中不再更新embeddings。实质上属于迁移学习,特别是在目标领域数据量比较小的情况下,采用静态的词向量效果也不错。(通过设置trainable=False)
  • 非静态(non-static)方式:在训练过程中对embeddings进行更新和微调(fine tune),能加速收敛。(通过设置trainable=True)

模型如如下:

TextCNN原理、结构、代码相关推荐

  1. TextCnn原理及实践

    原理 paper地址:https://arxiv.org/pdf/1408.5882.pdf 对于文本分类问题,常见的方法无非就是抽取文本的特征,比如使用doc2evc或者LDA模型将文本转换成一个固 ...

  2. 视觉SLAM开源算法ORB-SLAM3 原理与代码解析

    来源:深蓝学院,文稿整理者:何常鑫,审核&修改:刘国庆 本文总结于上交感知与导航研究所科研助理--刘国庆关于[视觉SLAM开源算法ORB-SLAM3 原理与代码解析]的公开课. ORB-SLA ...

  3. DeepLearning tutorial(3)MLP多层感知机原理简介+代码详解

    FROM:http://blog.csdn.net/u012162613/article/details/43221829 @author:wepon @blog:http://blog.csdn.n ...

  4. DeepLearning tutorial(4)CNN卷积神经网络原理简介+代码详解

    FROM: http://blog.csdn.net/u012162613/article/details/43225445 DeepLearning tutorial(4)CNN卷积神经网络原理简介 ...

  5. PHP网站安装程序的原理及代码

    原文:PHP网站安装程序的原理及代码 原理: 其实PHP程序的安装原理无非就是将数据库结构和内容导入到相应的数据库中,从这个过程中重新配置连接数据库的参数和文件,为了保证不被别人恶意使用安装文件,当安 ...

  6. 【深度学习】搞懂 Vision Transformer 原理和代码,看这篇技术综述就够了

    作者丨科技猛兽 编辑丨极市平台 导读 本文对Vision Transformer的原理和代码进行了非常全面详细的解读,一切从Self-attention开始.Transformer的实现和代码以及Tr ...

  7. SAGAN原理及代码(B站详解,很值得一看)

    代码地址:https://github.com/heykeetae/Self-Attention-GAN 视频讲解:SAGAN原理及代码_哔哩哔哩_bilibili 目录 1.背景+整体介绍 2.算法 ...

  8. MLP多层感知机(人工神经网络)原理及代码实现

    一.多层感知机(MLP)原理简介 多层感知机(MLP,Multilayer Perceptron)也叫人工神经网络(ANN,Artificial Neural Network),除了输入输出层,它中间 ...

  9. LoRa SX1278/76驱动原理 附代码

    LoRa SX1278/76驱动原理 附代码 原理解释 LoRa 关键参数说明 前导码: 报头: 显式报头模式: 隐式报头模式: LoRa 调制解调: 扩频因子: 编码率: 信号带宽: 代码说明 SP ...

  10. Pytorch|YOWO原理及代码详解(二)

    Pytorch|YOWO原理及代码详解(二) 本博客上接,Pytorch|YOWO原理及代码详解(一),阅前可看. 1.正式训练 if opt.evaluate:logging('evaluating ...

最新文章

  1. 我用kafka两年踩过的一些非比寻常的坑
  2. MapReduce开发总结
  3. 【NOIP1999】导弹拦截
  4. AndroidStudio中调试时提示waiting for debugger的奇葩解决
  5. ATL学习笔记〔一〕
  6. 使用 Azure WAF 羞辱黑客的智商
  7. 关于引入 js 文件
  8. 简单工厂模式、工厂方法模式与抽象工厂模式的区别(转)
  9. 按需要生成你的网站导航栏
  10. JS正则表达式的分组匹配
  11. 使用 HTTPS 方式登录防火墙USG6000设备
  12. Python:pip下载库后导入Pycharm的方法
  13. 十三、K8s SVC相关操作
  14. 网络时代课堂教学模式整合的探索
  15. imdisk虚拟光驱安装linux,ImDisk Virtual Disk Driver
  16. 2022苏州市小学信息学奥赛T2-汉诺塔
  17. Mysql如何跨库查询数据?
  18. 【算法】leetcode887鸡蛋掉落题之方法二解析
  19. 利用Landsat8数据的不同波段组合监测冰雪分布情况
  20. Containerd客户端工具(CLI)介绍ctr,nerdctl,crictl,podman以及docker

热门文章

  1. mysql using btree_mysql 索引中的USING BTREE有什么用
  2. 关于处理无法保存打印机设置的问题
  3. shell中的 中文和英文 双引号
  4. 用Python找出了删除自己微信的所有人并将他们自动化删除了
  5. 手机不小心把计算机隐藏了怎么恢复出厂设置,手机不小心恢复出厂设置后怎么找回丢失的文件?...
  6. 鸿蒙系统评论简单分析(nlp)
  7. 淘晶驰串口屏_页面事件详解
  8. 暗灰色android代码,Android实现制作灰色图片
  9. html div添加天气,web前端入门到实战:纯CSS写一个动态太阳的天气图标
  10. 苹果漏洞 Siri会泄露你的个人资料