目录

前言

模型结构

模型原理

模型训练

特别说明

模型效果

参考

前言

论文全名:SimCSE: Simple Contrastive Learning of Sentence Embeddings

论文地址:https://aclanthology.org/2021.emnlp-main.552.pdf

论文收录于EMNLP2021。

论文提供的代码(pytorch):GitHub - princeton-nlp/SimCSE: EMNLP'2021: SimCSE: Simple Contrastive Learning of Sentence Embeddings

苏神的代码(bert4keras):GitHub - bojone/SimCSE: SimCSE在中文任务上的简单实验

模型结构

SimCSE模型是一种简单的对比句向量表征的框架,包含无监督和有监督两种方法。

无监督学习:会采用Dropout技术,对原始文本进行数据增强,构造出正样本,用于对比学习训练;

监督学习:由于本身有正样本(相近样本),故无需使用Dropout技术,直接训练即可。

模型结构如下:

模型原理

SimCSE模型的核心是对比学习,对比学习是通过拉近相似数据的距离,拉远不相似数据的距离为目标,更好地学习数据的表征。使得其在文本匹配任务中产生更好的效果。

论文中,在一个batch中,样本 i 的训练目标如下:

\LARGE \ell_{i}=-\log \frac{e^{\operatorname{sim}\left(\mathbf{h}_{i}^{z_{i}}, \mathbf{h}_{i}^{z_{i}^{\prime}}\right) / \tau}}{\sum_{j=1}^{N} e^{\operatorname{sim}\left(\mathbf{h}_{i}^{z_{i}}, \mathbf{h}_{j}^{z_{j}^{\prime}}\right) / \tau}}

\tau 是一个控制 softmax 分布的一个超参数,通常设为0.05,(\tau 越大,分布越平滑;\tau 越小,正负样本差距就越大),可以加快收敛。

N 为一个batch的大小,sim 函数表示的是余弦相似度计算,{z_{i}^{\prime}} 表示 {z_{i}} 的增强样本(不同dropout得到的)。

但上面的训练目标的分母仅仅是对原样本和所有的增强样本进行了累加,但没有将一个batch中不同的原样本进行累加,因此还是比较喜欢苏神提供的训练目标,如下:

\huge -\sum_{i=1}^{N} \sum_{\alpha=0,1} \log \frac{e^{\cos \left(\boldsymbol{h}_{i}^{(\alpha)}, \boldsymbol{h}_{i}^{(1-\alpha)}\right) / \tau}}{\sum_{j=1, j \neq i}^{N} e^{\cos \left(\boldsymbol{h}_{i}^{(\alpha)}, \boldsymbol{h}_{j}^{(\alpha)}\right) / \tau}+\sum_{j}^{N} e^{\cos \left(\boldsymbol{h}_{i}^{(\alpha)}, \boldsymbol{h}_{j}^{(1-\alpha)}\right) / \tau}}

再来一个图就更清楚了,下图中 a 表示的是原句子,p 表示的是增强后的句子。

上图红色区域是不计算的部分,因为自己和自己计算相似度是没有意义的。

模型训练

无监督训练

无监督训练过程,究竟是如何对句子进行dropout的呢?

首先,将原句子直接复制一份,得到 sent_a 与 sent_b输入到 Bert 中,得到pool_output,然后通过下面代码将形状为torch.Size ([16, 768])的 pooler_output 被重新整形为torch.Size([ 8, 2, 768]),其中8是batch_size。

pooler_output = pooler_output.view((batch_size, args.num_sent, pooler_output.size(-1)))

原论文提供的代码适合英文数据集,如果是要做中文数据集任务,可以使用苏神提供的代码。

要使用SimCSE模型的话,可以直接加载预训练好的模型,如'BERT', 'RoBERTa', 'WoBERT', 'RoFormer', 'BERT-large', 'RoBERTa-large', 'SimBERT', 'SimBERT-tiny', 'SimBERT-small' 等等,具体效果可以看苏神博客的效果对比;

然后得到数据对应预训练模型的encoder向量表示,注意苏神虽然用的是有监督的数据,只用了句子,没有用标签,还是无监督的。

经过SimCSE模型无监督训练后,得到对应的embedding表示。

有监督训练

论文中,有监督学习是三个句子为一组(x, x+, x-),其中 x+ 作为正样本,x- 与其他句子的x +, x-作为负样本。

特别说明

在使用无监督训练时,dropout 通常取较小的数,如0.1,0.05等,也可以参考苏神取的0.3。

随机选了1万条任务数据训练,效果就很好,不一定需要使用所有的数据。(随机选取的样本量也可以作为一个参数来进行调整,例如选取8000 ~ 12000不等的数据)

batch_size取64(或128),学习率取1e5,供参考。

模型效果

关于SimCSE模型的对比实验

该实验选用SNLI和STS-B数据,对比了有监督和无监督、4种不同的预训练模型、4种不同的向量表示,进行了共计 32次训练(2 * 4 * 4)。

预训练模型如下:

  • BERT

  • BERT-wwm-ext

  • RoBERTa-wwm-ext

  • SimBert

4种不同的向量表示如下:

cls:取 output 最后一层hidden_state第0个位置的hidden,也就是CLS的hidden

pooler:pooler表示的是对[CLS]过了一层nn.Linear层,又过了tanh激活函数,得到的hidden

last-avg:取 output 最后一层hidden_state,先进行位置变换,然后对最后一个维度进行平均池化

first-last-avg:取 output 的第一层和最后一层hidden_state,先分别进行位置变换,然后分别对最后一个维度进行平均池化,池化后进行拼接,拼接后再进行一次池化。

测评指标为spearman相关系数

有监督对比实验

训练集数据为SNLI,测试集和验证集数据为STS-B。

参数设置:batch_size=64,lr=1e-5,droupout_rate=0.3

(设置100个 batch 作为早停)

模型 向量表示 STS-B dev/test/sample
BERT cls 0.8017/0.7589/25600
pooler 0.7734/0.7272/43520
last-avg 0.8004/0.7521/11520
first-last-avg 0.7985/0.7577/26240
BERT-wwm-ext cls 0.8088/0.7608/11520
pooler 0.7714/0.7193/20480
last-avg 0.8087/0.7690/26240
first-last-avg 0.8064/0.7580/27520
RoBERTa-wwm-ext cls 0.8073/0.7693/27520
pooler 0.7755/0.7296/28160
last-avg 0.8047/0.7675/26240
first-last-avg 0.8031/0.7650/46080
SimBert cls 0.8173/0.7675/15360
pooler 0.8148/0.7574/4480
last-avg 0.8154/0.7630/15360
first-last-avg 0.8117/0.7582/15360

无监督对比实验

STS-B 数据的训练数据为SNLI 的'origin' + STS-B的第一句话,测试集和验证集为 STS-B数据。

参数设置:batch_size=64,lr=1e-5,droupout_rate=0.3,pooling=cls, 随机抽样100000样本

设置100个batch早停。

模型 向量表示 STS-B dev/test/sample
BERT cls 0.7324/0.6776/24320
pooler 0.6331/0.5797/39040
last-avg 0.7272/0.6769/7680
first-last-avg 0.7136/0.6707/4480
BERT-wwm-ext cls 0.7260/0.6683/7680
pooler 0.6395/0.5864/1280
last-avg 0.7270/0.6693/4480
first-last-avg 0.7056/0.6540/7680
RoBERTa-wwm-ext cls 0.7552/0.7139/640
pooler 0.6840/0.6549/640
last-avg 0.7140/0.6641/1920
first-last-avg 0.6988/0.6522/2560
SimBert cls 0.7930/0.7278/640
pooler 0.7868/0.7208/640
last-avg 0.7739/0.7155/5760
first-last-avg 0.7597/0.7056/6400

由上面数据看出,一般CLS效果要比其他三种向量表示方法好,对于BERT-wwm-ext模型,last-avg的效果更好一些。有监督下,RoBERTa-wwm-ext模型效果更好,无监督下,SimBert模型效果更好。(当然,这只是对当前数据集下的情况,对于不同数据集,效果可能不一样)。

参考

苏神博客:中文任务还是SOTA吗?我们给SimCSE补充了一些实验 - 科学空间|Scientific Spaceshttps://spaces.ac.cn/archives/8348

刘聪大佬:SimCSE论文精读 - 知乎「句向量表征技术」一直都是NLP领域的热门话题,在BERT前时代,一般都采用word2vec训练出的word-embedding结合pooling策略进行句向量表征,或者在有训练数据情况下,采用TextCNN/BiLSTM结合Siamese network策略进…https://zhuanlan.zhihu.com/p/452761704

特别推荐一篇文章:在 Pytorch 中为无监督方法实现 SimCSE,描述的很详细。

https://bhuvana-kundumani.medium.com/implementation-of-simcse-for-unsupervised-approach-in-pytorch-a3f8da756839https://bhuvana-kundumani.medium.com/implementation-of-simcse-for-unsupervised-approach-in-pytorch-a3f8da756839GitHub - bhuvanakundumani/SimCSE_unsupervisedContribute to bhuvanakundumani/SimCSE_unsupervised development by creating an account on GitHub.https://github.com/bhuvanakundumani/SimCSE_unsupervised

simcse损失函数源码解读:SimCSE的loss实现源码解读 - 知乎

文本匹配之SimCSE模型相关推荐

  1. 文本匹配与ESIM模型详解

    ESIM(Enhanced Sequential Inference Model)是一个综合应用了BiLSTM和注意力机制的模型,在文本匹配中效果十分强大,也是目前为止我见过结构最复杂的模型,下面将会 ...

  2. 【文本匹配】ESIM模型

    ESIM实现 ESIM模型训练包含以下模块: 数据处理加载模块 模型实现模型 pytorch_lightning 封装训练模块 模型训练和使用模块 相关源码可以参见我Github上的源码.下面主要说明 ...

  3. nc65语义模型设计_文本匹配方法系列––多维度语义交互匹配模型

    摘要 本文基于接着多语义匹配模型[1]和BERT匹配模型[2]介绍一些多维度语义交互匹配模型,包括2017 BiMPM模型[3]和腾讯出品的2018 MIX[4].这些方法的核心特征都是在多语义网络的 ...

  4. 文本匹配模型ESIM

    ESIM是一个综合应用了BiLSTM和注意力机制的模型,在文本匹配中效果十分强大. 文本匹配说就是分析两个句子是否具有某种关系,比如有一个问题,现在给出一个答案,我们就需要分析这个答案是否匹配这个问题 ...

  5. 【NLP】深度文本匹配综述

    目  录 1.研究背景与意义  2.深度学习在自然语言处理的应用  3.深度文本匹配与传统文本匹配  4.深度文本匹配国内外研究现状  4.1基于单语义表达的文本匹配 4.2基于多语义表达的文本匹配 ...

  6. nmt模型源文本词项序列_「自然语言处理(NLP)」阿里团队--文本匹配模型(含源码)...

    来源:AINLPer微信公众号 编辑: ShuYini 校稿: ShuYini 时间: 2019-8-14 引言 两篇文章与大家分享,第一篇作者对通用文本匹配模型进行探索,研究了构建一个快速优良的文本 ...

  7. antd 文本域超长问题_「自然语言处理(NLP)」阿里团队--文本匹配模型(含源码)...

    来源:AINLPer微信公众号 编辑: ShuYini 校稿: ShuYini 时间: 2019-8-14 引言     两篇文章与大家分享,第一篇作者对通用文本匹配模型进行探索,研究了构建一个快速优 ...

  8. laravel 分词搜索匹配度_DSSM文本匹配模型在苏宁商品语义召回上的应用

    文本匹配是自然语言处理中的一个核心问题,它不同于MT.MRC.QA 等end-to-end型任务,一般是以文本相似度计算的形式在应用系统中起核心支撑作用1.它可以应用于各种类型的自然语言处理任务中,例 ...

  9. 文本匹配开山之作--双塔模型及实战

    作者 | 夜小白 整理 | NewBeeNLP 在前面一篇文章中,总结了Representation-Based文本匹配模型的改进方法, 基于表征(Representation)的文本匹配.信息检索. ...

最新文章

  1. 关于征集2020重大科学问题和工程技术难题的通知
  2. h2 mysql 兼容性_H2内存数据库对sql语句的支持问题 sql放到mysql数据库中能跑
  3. php解决高并发问题
  4. 学python需要记笔记吗_自学python需要做什么笔记
  5. SpringCloud(第 054 篇)简单 Quartz-Cluster 微服务,采用注解配置 Quartz 分布式集群...
  6. Vue学习(slot、axios)-学习笔记
  7. css hack技术整理
  8. 35 MM配置-采购-采购订单-设置价格差异的容差限制
  9. mysql join 联合查询,MySQL连接(join)查询
  10. 微软邮件服务器名称,邮箱服务器角色概述
  11. 描述最常用的5种http方法的用途_RESTful API系列之HTTP基础
  12. c语言获取栈可用大小,[求助]求教各位大神如何获得C语言函数体的大小?
  13. MATLAB绘制微分方程的相图/方向场/向量场
  14. C语言RSA大数运算库,[转载]RSA大数运算库  c++实现
  15. 已知 char w; int x; float y; double z;,则表达是 w*x+z-y 结果的类型是
  16. 免流服务器系统怎么选,免流云服务器选
  17. PPT画图软件,强烈推荐!提升能力的利器。
  18. Ekl去记录nginx的日志
  19. linux pam 使用例子,PAM认证模块使用实例
  20. python手势识别隐马尔可夫模型_手势识别身份认证的连续隐马尔可夫模型

热门文章

  1. mysql 身份证_MySQL--隐藏手机号、身份证号三种方式
  2. 利用ps制作熊猫表情包【无图】
  3. 本机文件怎么拉到服务器,本机文件怎么传到云服务器上
  4. 更新无限火力的服务器,LOL无限火力延长时间公告 2020无限乱斗火力延长到几号?...
  5. 完美解决Error: Running Homebrew as root is extremely dangerous and no longer supported.
  6. 基于opencv-python的签名抠图程序
  7. 计算机操作系统 精选模拟试题及答案
  8. 如何判断你是否有自主创业的条件
  9. C++ 字符串 string 截取 substr
  10. 8种IO口模式的配置(复制粘贴的,供自己学习的)