论文地址:https://arxiv.org/abs/1908.10084
论文中文翻译:https://www.cnblogs.com/gczr/p/12874409.html
源码下载:https://github.com/UKPLab/sentence-transformers
相关网站:https://www.sbert.net/

“论文中文翻译”已相当清楚,故本篇不再翻译,只简单介绍SBERT的原理,以及训练和使用中文相似度模型的方法和效果。

原理

挛生网络Siamese network(后简称SBERT),其中Siamese意为“连体人”,即两人共用部分器官。SBERT模型的子网络都使用BERT模型,且两个BERT模型共享参数。当对比A,B两个句子相似度时,它们分别输入BERT网络,输出是两组表征句子的向量,然后计算二者的相似度;利用该原理还可以使用向量聚类,实现无监督学习任务。

挛生网络有很多应用,比如使用图片搜索时,输入照片将其转换成一组向量,和库中的其它图片对比,找到相似度最高(距离最近)的图片;在问答场景中,找到与用户输入文字最相近的标准问题,然后给出相应解答;对各种文本标准化等等。

衡量语义相似度是自然语言处理中的一个重要应用,BERT源码中并未给出相应例程(run_glue.py只是在其示例框架内的简单示例),真实场景使用时需要做大量修改;而SBERT提供了现成的方法解决了相似度问题,并在速度上更有优势,直接使用更方便。

SBERT对Pytorch进行了封装,简单使用该工具时,不仅不需要了解太多BERT API的细节, Pytorch相关方法也不多,下面来看看其具体用法。

配置环境

需要注意的是机器需要能正常配置BERT运行环境,如GPU+CUDA+Pytorch+Transformer匹配版本。

$ pip install sentence_transformers

下载源码

$ git clone https://github.com/UKPLab/sentence-transformers.git

模型预测

在未进行调优(fine-tune)前,使用预训练的通用中文BERT模型也可以达到一定效果,下例是从几个选项中找到与目标最相近的字符串。

from sentence_transformers import SentenceTransformer
import scipy.spatialembedder = SentenceTransformer('bert-base-chinese')
corpus = ['这是一支铅笔','关节置换术','我爱北京天安门',
]
corpus_embeddings = embedder.encode(corpus)
# 待查询的句子
queries = ['心脏手术','中国首都在哪里']
query_embeddings = embedder.encode(queries)
# 对于每个句子,使用余弦相似度查询最接近的n个句子
closest_n = 2
for query, query_embedding in zip(queries, query_embeddings):distances = scipy.spatial.distance.cdist([query_embedding], corpus_embeddings, "cosine")[0]# 按照距离逆序results = zip(range(len(distances)), distances)results = sorted(results, key=lambda x: x[1])print("======================")print("Query:", query)print("Result:Top 5 most similar sentences in corpus:")for idx, distance in results[0:closest_n]:print(corpus[idx].strip(), "(Score: %.4f)" % (1-distance))

训练中文模型

模型训练方法

训练原理:https://www.sbert.net/docs/training/overview.html
训练示例说明:https://www.sbert.net/examples/training/sts/README.html
训练示例代码:examples/training/sts/training_stsbenchmark.py

训练中文模型

把示例中的bert-base-cased换成bert-base-chinese,即可下载和使用中文模型。需要注意的是:中文和英文词库不同,不能将中文模型用于英文数据训练。

下载中文训练数据

下载信贷相关数据,csv数据7M多,约10W条训练数据,可在下例中使用

$ git clone https://github.com/lixuanhng/NLP_related_projects.git
$ ls NLP_related_projects/BERT/Bert_sim/data

代码

from torch.utils.data import DataLoader
import math
from sentence_transformers import SentenceTransformer, LoggingHandler, losses, models, util
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
from sentence_transformers.readers import InputExample
import logging
from datetime import datetime
import sys
import os
import pandas as pdmodel_name = 'bert-base-chinese'
train_batch_size = 16
num_epochs = 4
model_save_path = 'test_output'
logging.basicConfig(format='%(asctime)s - %(message)s',datefmt='%Y-%m-%d %H:%M:%S',level=logging.INFO,handlers=[LoggingHandler()])# Use Huggingface/transformers model (like BERT, RoBERTa, XLNet, XLM-R) for mapping tokens to embeddings
word_embedding_model = models.Transformer(model_name)# Apply mean pooling to get one fixed sized sentence vector
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(),pooling_mode_mean_tokens=True,pooling_mode_cls_token=False,pooling_mode_max_tokens=False)model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
train_samples = []
dev_samples = []
test_samples = []def load(path):df = pd.read_csv(path)samples = []for idx,item in df.iterrows():samples.append(InputExample(texts=[item['sentence1'], item['sentence2']], label=float(item['label'])))return samplestrain_samples = load('/workspace/exports/git/NLP_related_projects/BERT/Bert_sim/data/train.csv')
test_samples = load('/workspace/exports/git/NLP_related_projects/BERT/Bert_sim/data/test.csv')
dev_samples = load('/workspace/exports/git/NLP_related_projects/BERT/Bert_sim/data/dev.csv')train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=train_batch_size)
train_loss = losses.CosineSimilarityLoss(model=model)
evaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_samples, name='sts-dev')
warmup_steps = math.ceil(len(train_dataloader) * num_epochs * 0.1) #10% of train data for warm-up# Train the model
model.fit(train_objectives=[(train_dataloader, train_loss)],evaluator=evaluator,epochs=num_epochs,evaluation_steps=1000,warmup_steps=warmup_steps,output_path=model_save_path)model = SentenceTransformer(model_save_path)
test_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(test_samples, name='sts-test')
test_evaluator(model, output_path=model_save_path)

测试结果

  • 直接使用预训练的英文模型,测试集正确率21%
  • 直接使用预训练的中文模型,测试集正确率30%
  • 使用1000个用例的训练集,4次迭代,测试集正确率51%
  • 使用10000个用例的训练集,4次迭代,测试集正确率68%
  • 使用100000个用例的训练集,4次迭代,测试集正确率71%

一些技巧

除了设置超参数以外,也可通过构造训练数据来优化SBERT网络,比如:构造正例时,把知识“喂”给模型,如将英文缩写与对应中文作为正例对训练模型;构造反例时用容易混淆的句子对训练模型(文字相似但含义不同的句子;之前预测出错的实例,分析其原因,从而构造反例;使用知识构造容易出错的句子对),以替代之前的随机抽取反例。

参考

  • BERT中文实战(文本相似度) https://blog.csdn.net/weixin_37947156/article/details/84877254
  • Bert 文本相似度实战(使用详解) https://zhuanlan.zhihu.com/p/367726571
  • Sentence-BERT: 一种能快速计算句子相似度的孪生网络 https://www.cnblogs.com/gczr/p/12874409.html
  • Sentence-Bert论文笔记 https://zhuanlan.zhihu.com/p/113133510?from_voters_page=true

语义相似度模型SBERT ——一个挛生网络的优美范例相关推荐

  1. DSSM、CNN-DSSM、LSTM-DSSM等深度学习模型在计算语义相似度上的应用+距离运算

    在NLP领域,语义相似度的计算一直是个难题:搜索场景下query和Doc的语义相似度.feeds场景下Doc和Doc的语义相似度.机器翻译场景下A句子和B句子的语义相似度等等.本文通过介绍DSSM.C ...

  2. 文本匹配(语义相似度/行为相关性)技术综述

    NLP 中,文本匹配技术,不像 MT.MRC.QA 等属于 end-to-end 型任务,通常以文本相似度计算.文本相关性计算的形式,在某应用系统中起核心支撑作用,比如搜索引擎.智能问答.知识检索.信 ...

  3. 一文详解文本语义相似度的研究脉络和最新进展

    每天给你送来NLP技术干货! ©作者 | 崔文谦 单位 | 北京邮电大学 研究方向 | 医学自然语言处理 编辑 | PaperWeekly 本文旨在帮大家快速了解文本语义相似度领域的研究脉络和进展,其 ...

  4. 深度学习解决NLP问题:语义相似度计算——DSSM

    tongzhou 转载请注明出处: http://blog.csdn.net/u013074302/article/details/76422551 导语 在NLP领域,语义相似度的计算一直是个难题: ...

  5. FISSA:融合项目相似度模型和自注意网络的时序推荐

    论文学习: FISSA: Fusing Item Similarity Models with Self-Attention Networks for Sequential Recommendatio ...

  6. 概率图模型中的贝叶斯网络

    目录 一.概率图 二.贝叶斯网络 什么是贝叶斯网络? 贝叶斯网络结构怎么构建? 三.概率知识 四.贝叶斯网络知识 网络 条件独立性 结构 六.概率推断 七.案例分析 八.贝叶斯学习 九.Netica ...

  7. AIGC周报|30秒定制一个文生图模型;60美元让AI玩转《我的世界》;手机版“文生图”模型:2秒不到出一张图

    AIGC(AI Generated Content)即人工智能生成内容.近期爆火的 AI 聊天机器人 ChatGPT,以及 Dall·E 2.Stable Diffusion 等文生图模型,都属于 A ...

  8. 基于神经网络模型的文本语义通顺度计算研究-全文复现(还没弄完)

    该硕士学位论文分为两个部分: ①基于依存句法分析的语义通顺度计算方法 ②基于神经网络模型的语义通顺度计算方法 本篇记录摘抄了该论文的核心内容以及实验复现的详细步骤. 在N-gram模型下进行智能批改场 ...

  9. NLP-文本匹配-2016:MaLSTM(ManhaĴan LSTM,孪生神经网络模型)【语句相似度计算:用于文本对比,内容推荐,重复内容判断】【将原本的计算余弦相似度改为一个线性层来计算相似度】

    <MaLSTM原始论文:Siamese Recurrent Architectures for Learning Sentence Similarity> MaLSTM模型(ManhaĴa ...

最新文章

  1. svn command line tag
  2. Python-语句执行
  3. 跨域调用WebApi
  4. Java中的复合设计模式
  5. jdeveloper_适用于JDeveloper 11gR2的Glassfish插件
  6. 【渝粤教育】电大中专建筑材料作业 题库
  7. 挂“洋头”卖奶粉,澳优还要欺骗好久
  8. 网站链接自动化测试原理及工具介绍
  9. FineReport:任意时刻只允许在一个客户端登陆账号的插件
  10. 关于MultiActionController异步请求Ajax,pc端正常,手机端报error错误;此问题一般是通过setInterval,seTimeout,做Ajax轮询时会产生此问题;
  11. Zmodem安装,拖拽的方式通过shell命令界面实现windows和linux之间的文件互传
  12. 【评分卡开发】信用评分模型构建流程
  13. 从哈密尔顿路径谈NP问题
  14. Dedecms QQ一键登录插件
  15. java生成csr_使用Keytool工具生成CSR
  16. 什么是期货、现货?//2021-2-1
  17. 星际文件系统优点和原理
  18. 我的世界模拟大都市里java_我的世界1.7.10模拟大都市整合包
  19. Best Fitting Hyperplanes for Classification(用于分类的最佳拟合超平面)
  20. linux top VIRT RES SHR SWAP DATA内存参数详解

热门文章

  1. mysql parquet_Spark与Apache Parquet
  2. 【DP】探索数字迷塔
  3. 华为:交换机端口汇集
  4. hmailserver怎么搭建php,hMailServer安装使用教程
  5. 火狐 安装 RESTED 插件
  6. 初识Linux操作系统及常用的Linux命令
  7. 收集优质的中文前端博客(不定期更新中)
  8. 51Nod 算法马拉松23
  9. 在springboot项目中配置hive-jdbc的maven依赖时遇到:Could not find artifact org.glassfish:javax.el:pom:3.0.1-b06-S
  10. WSREP has not yet prepared node for application use