单句文本分类是最常见的自然语言处理任务,需要将输入文本分成不同类别。例如:在情感分类任务SST-2中,需要将影评文本输入文本分类模型中,并将其分成褒义或贬义。

1. 建模方法

应用BERT处理单句文本分类任务的模型由输入层、BERT编码层和分类输出层构成。处理过程如下图所示(图源李宏毅老师课件):

  1. 首先在句子的开头加一个代表分类的符号[CLS]
  2. 然后将该位置的output输入到Linear Classifier,进行predict,输出一个分类。

注:整个过程中 Linear Classifier 的参数是需要从头开始学习的,而 BERT 中的参数微调就可以了。

为什么要用[CLS]来进行分类?

因为 BERT 内部是 Transformer,而 Transformer 内部又是 Self-Attention,所以[CLS]的output肯定含有整句话的完整信息。但是Self-Attention计算的向量,自己本身和自己的值肯定是最相关的。现在假设使用w1w_1w1​的output做分类,这那么这个output实际上会更加看重w1w_1w1​,而w1w_1w1​又是一个有实际含义的字或者词,这样难免会影响到最终的结果。但是[CLS]是没有任何意义的占位符,所以就算[CLS]的 output 中自己的值占大头也无所谓.

2. 代码实现

接下来结合实际代码,介绍BERT在单句文本分类任务中的训练方法。这里以英文情感二分类数据集SST-2为例介绍。

这里主要应用了由HuggingFace开发的transformers包和datasets库进行建模,可以极大地简化数据处理和模型建模过程。

  1. 导入包和加载训练数据、分词器、预训练模型和评价方法
import numpy as np
from datasets import load_dataset, load_metric
from transformers import BertTokenizerFast, BertForSequenceClassification,TrainingArguments,Trainerdataset = load_dataset('glue', 'sst2')
tokenizer = BertTokenizerFast.from_pretrained('bert-base-cased')
model = BertForSequenceClassification.from_pretrained('bert-base-cased', return_dict = True)
metric = load_metric('glue', 'sst2')
  1. 对训练集分词
def tokenize(examples):return tokenizer(examples['sentence'], truncation=True, padding='max_length')dataset = dataset.map(tokenize, batched=True)
encoded_dataset = dataset.map(lambda examples:{'labels':examples['label']}, batched=True)
  1. 将数据集转化为torch.Tensor类型以训练PyTorch模型
columns = ['input_ids', 'token_type_ids', 'attention_mask', 'labels']
encoded_dataset.set_format(type='torch', columns=columns)
  1. 定义评价指标
def compute_metrics(eval_pred):predictions, labels = eval_predreturn metric.compute(predictions=np.argmax(predictions, axis=1), references=labels)
  1. 定义训练参数TrainingArguments,默认使用AdamW优化器
args = TrainingArguments('ft-sst2',evaluation_strategy='epoch',learning_rate=2e-5,per_device_train_batch_size=4,per_device_eval_batch_size=4,num_train_epochs=2
)
  1. 定义Trainer,指定模型和训练参数,输入训练集、验证集、分词器和评价函数
trainer = Trainer(model,args,train_dataset =encoded_dataset["train"],eval_dataset = encoded_dataset["validation"],tokenizer = tokenizer, compute_metrics = compute_metrics
  1. 进行训练
trainer.train()
  1. 训练完毕后,开始测试
trainer.evaluate()
结果:
{'eval_loss': 0.4584292471408844,'eval_accuracy': 0.9162844036697247,'eval_runtime': 25.5729,'eval_samples_per_second': 34.099,'epoch': 2.0,'eval_mem_cpu_alloc_delta': 215077,'eval_mem_gpu_alloc_delta': 0,'eval_mem_cpu_peaked_delta': 270242,'eval_mem_gpu_peaked_delta': 144781312}

参考资料

  1. 自然语言处理:基于预训练模型的方法
  2. 李宏毅-ELMO, BERT, GPT讲解

BERT微调之单句文本分类相关推荐

  1. BERT微调做中文文本分类

    BERT模型在NLP各项任务中大杀四方,那么我们如何使用这一利器来为我们日常的NLP任务来服务呢?我们首先介绍使用BERT做文本分类任务. 重写读取数据的类 需要根据文件格式重写读取数据的类,只要能够 ...

  2. bert 是单标签还是多标签 的分类_搞定NLP领域的“变形金刚”!教你用BERT进行多标签文本分类...

    大数据文摘出品 来源:medium 编译:李雷.睡不着的iris.Aileen 过去的一年,深度神经网络的应用开启了自然语言处理的新时代.预训练模型在研究领域的应用已经令许多NLP项目的最新成果产生了 ...

  3. bert 是单标签还是多标签 的分类_搞定NLP领域的“变形金刚”!手把手教你用BERT进行多标签文本分类...

    大数据文摘出品 来源:medium 编译:李雷.睡不着的iris.Aileen 过去的一年,深度神经网络的应用开启了自然语言处理的新时代.预训练模型在研究领域的应用已经令许多NLP项目的最新成果产生了 ...

  4. Pytorch——BERT 预训练模型及文本分类(情感分类)

    BERT 预训练模型及文本分类 介绍 如果你关注自然语言处理技术的发展,那你一定听说过 BERT,它的诞生对自然语言处理领域具有着里程碑式的意义.本次试验将介绍 BERT 的模型结构,以及将其应用于文 ...

  5. Bert模型做多标签文本分类

    Bert模型做多标签文本分类 参考链接 BERT模型的详细介绍 图解BERT模型:从零开始构建BERT (强推)李宏毅2021春机器学习课程 我们现在来说,怎么把Bert应用到多标签文本分类的问题上. ...

  6. 【NLP】BERT 模型与中文文本分类实践

    简介 2018年10月11日,Google发布的论文<Pre-training of Deep Bidirectional Transformers for Language Understan ...

  7. 如何使用BERT实现中文的文本分类(附代码)

    如何使用BERT模型实现中文的文本分类 前言 Pytorch readme 参数表 算法流程 1. 概述 2. 读取数据 3. 特征转换 4. 模型训练 5. 模型测试 6. 测试结果 7. 总结 前 ...

  8. 基于Bert+对抗训练的文本分类实现

    由于Bert的强大,它文本分类取得了非常好的效果,而通过对抗训练提升模型的鲁棒性是一个非常有研究意义的方向,下面将通过代码实战与大家一起探讨交流对抗训练在Bert文本分类领域的应用. 目录 一.Ber ...

  9. 毕业设计-基于 BERT 的中文长文本分类系统

    目录 前言 课题背景和意义 实现技术思路 一.文本分类的相关技术 二.文本表示模型 三.文本分类模型 实现效果图样例 最后 前言

最新文章

  1. linux下远程传输文件命令scp使用注解
  2. java jni 方法描述,五、JNI提供的函数介绍(一):类和对象操作
  3. ehviewer苹果版下载_苹果用户:支持ios的云手机有没有?在哪里下载云手机ios版?...
  4. Java IO - Reader
  5. mysql 不会联想字段_你有没有被MySQL的这个bug坑过?
  6. 大四课程设计之基于RFID技术的考勤管理系统(四)Qt界面设计
  7. 通过SD卡来安装Linux系统
  8. jquery颜色选择器
  9. 【springmvc】传值的几种方式postman接口测试
  10. 如何快速把英语单词导入有道词典
  11. 静态测试和动态测试有何区别
  12. 微信手机号授权解密失败问题现象和解决方法: getPhoneNumber
  13. gmail谷歌邮箱开启SMTP
  14. Bootstrap中tooltip插件使用 | 爱骇客
  15. Google Play的崩溃与ANR
  16. citespace:Your version‘s status cannot be verified due to network issue. Check your network conne
  17. 4567: [Scoi2016]背单词 trie+贪心
  18. 把单元格一分为二_怎么将一个单元格一分为二
  19. 对合成大西瓜修改图片的实践
  20. python 爬虫http2

热门文章

  1. 背景大小比率css,css – 如何计算背景大小百分比?
  2. css实现文字竖向排版
  3. 64位操作系统——(二)kernel
  4. 前端jpg和png的选择
  5. java写dnf_用java模拟dnf武器强化的过程
  6. SQL Server 2012连接不上服务器问题
  7. 智慧安防解决方案-最新全套文件
  8. 亿级数据量系统数据库性能优化方案
  9. Head First Design pattern Observer
  10. awtk开发实践——学习篇26: guage(表盘控件)