主要内容:

使用torch和huggingface写二分类demo。

1.类别定义,将文本存放在list中,将label存放在另一个list中,这里举个二分类的的例子,输入类别用[0, 1]。如果是多分类,那么输入类别[0,1,2,.....n]。这里要求文本在'text'的位置跟类别在'target'中的位置对应。

2.对输入数据编码。汉字肯定不可直接作为模型的输入,将其根据词典进行编码,最后一堆数字输入到了模型中。

3.将文本和标签绑定,序列化到Dataloader中,开始训练模型。

4.优化器使用adam,然后开始更新参数,不断迭代,直到训练停止。

# -*- encoding:utf-8 -*-
import random
import torch
from torch.utils.data import TensorDataset, DataLoader, random_split
from transformers import BertTokenizer, BertConfig
from transformers import BertForSequenceClassification, AdamW
from transformers import get_linear_schedule_with_warmup
from sklearn.metrics import f1_score, accuracy_score
import numpy as np# tokenizer用来对文本进行编码
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
# 训练数据
train = {'text': [' 测试good','美团 学习',' 测试good','美团 学习',' 测试good','美团 学习',' 测试good','美团 学习'],'target': [0, 1, 0, 1, 0, 1, 0, 1],
}# Get text values and labels
text_values = train['text']
labels = train['target']print('Original Text : ', text_values[0])
print('Tokenized Ids: ', tokenizer.encode(text_values[0], add_special_tokens = True))
print('Tokenized Text: ', tokenizer.decode(tokenizer.encode(text_values[0], add_special_tokens = True)))
print('Token IDs     : ', tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text_values[0])))# Function to get token ids for a list of texts
def encode_fn(text_list):all_input_ids = []for text in text_list:input_ids = tokenizer.encode(text,add_special_tokens = True,  # 添加special tokens, 也就是CLS和SEPmax_length = 160,           # 设定最大文本长度pad_to_max_length = True,   # pad到最大的长度return_tensors = 'pt'       # 返回的类型为pytorch tensor)all_input_ids.append(input_ids)all_input_ids = torch.cat(all_input_ids, dim=0)return all_input_ids# 对训练数据进行编码
all_input_ids = encode_fn(text_values)
labels = torch.tensor(labels)# 训练参数定义
epochs = 1
batch_size = 1# Split data into train and validation
dataset = TensorDataset(all_input_ids, labels)
train_size = int(0.75 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])# Create train and validation dataloaders
train_dataloader = DataLoader(train_dataset, batch_size = batch_size, shuffle = True)
val_dataloader = DataLoader(val_dataset, batch_size = batch_size, shuffle = False)# Load the pretrained BERT model, num_labels=2表示类别是2
model = BertForSequenceClassification.from_pretrained('bert-base-chinese', num_labels=2, output_attentions=False, output_hidden_states=True)
print(model)
# model.cuda()# create optimizer and learning rate schedule
optimizer = AdamW(model.parameters(), lr=2e-5)
total_steps = len(train_dataloader) * epochs
# 表示学习率预热num_warmup_steps步后,再按照指定的学习率去更新参数
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)flag = False
total_batch,  last_improve = 0, 0
require_improvement = 1000
for epoch in range(epochs):model.train()total_loss, total_val_loss = 0, 0# 开始训练for step, batch in enumerate(train_dataloader):# 梯度清零model.zero_grad()# 计算lossloss, logits, hidden_states = model(batch[0], token_type_ids=None, attention_mask=(batch[0] > 0),labels=batch[1])total_loss += loss.item()# 梯度回传loss.backward()torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)# 梯度更新optimizer.step()scheduler.step()# model.eval()表示模型切换到eval模式,表示不会更新参数,只有在train模式下,才会更新梯度参数model.eval()for i, batch in enumerate(val_dataloader):with torch.no_grad():loss, logits, hidden_states = model(batch[0], token_type_ids=None, attention_mask=(batch[0] > 0),labels=batch[1])print(loss, logits)total_val_loss += loss.item()logits = logits.detach().cpu().numpy()label_ids = batch[1].to('cpu').numpy()avg_val_loss = total_val_loss / len(val_dataloader)if avg_val_loss < dev_best_loss:    last_improve = total_batchif total_batch - last_improve > require_improvement:# 验证集loss超过1000batch没下降,结束训练print("No optimization for a long time, auto-stopping...")flag = Truebreakif flag:breaktotal_batch = total_batch + 1

Bert系列:BERT模型二分类demo以及讲解相关推荐

  1. 【实战】使用Bert微调完成文本二分类

    使用Bert微调完成文本二分类 1.训练前准备 指定训练和预测的gpu 读取数据.分析数据 构造训练数据 2.模型定义.训练和测试代码 定义模型 测试代码 训练代码 3.微调 4.预测.批量预测 实验 ...

  2. 【回答问题】ChatGPT上线了!推荐30个以上比较好的中文bert系列的模型/压缩模型

    推荐30个以上比较好的中文bert系列的模型 以下是一些中文 BERT 系列的模型: BERT-Base, Chinese: 中文 BERT 基础版 BERT-Large, Chinese: 中文 B ...

  3. bert中文情感分析二分类任务详解

    查看GPU版本和使用情况 import torch if torch.cuda.is_available():device = torch.device("cuda")print( ...

  4. 神经网络系列之五 -- 线性二分类的方法与原理

    https://www.cnblogs.com/woodyh5/p/12101581.html 系列博客,原文在笔者所维护的github上:https://aka.ms/beginnerAI, 点击s ...

  5. Bert系列(二)——源码解读之模型主体

    本篇文章主要是解读模型主体代码modeling.py.在阅读这篇文章之前希望读者们对bert的相关理论有一定的了解,尤其是transformer的结构原理,网上的资料很多,本文内容对原理部分就不做过多 ...

  6. 二分类问题:基于BERT的文本分类实践!附完整代码

    Datawhale 作者:高宝丽,Datawhale优秀学习者 寄语:Bert天生适合做分类任务.文本分类有fasttext.textcnn等多种方法,但在Bert面前,就是小巫见大巫了. 推荐评论展 ...

  7. NLP之PTM:自然语言处理领域—预训练大模型时代的各种吊炸天大模型算法概述(Word2Vec→ELMO→Attention→Transfo→GPT系列/BERT系列等)、关系梳理、模型对比之详细攻略

    NLP之PTM:自然语言处理领域-预训练大模型时代的各种吊炸天大模型算法概述(Word2Vec→ELMO→Attention→Transformer→GPT系列/BERT系列等).关系梳理.模型对比之 ...

  8. 小白Bert系列-生成pb模型,tfserving加载,flask进行预测

    bert分类模型使用tfserving部署. bert模型服务化现在已经有对应开源库部署. 例如:1.https://github.com/macanv/BERT-BiLSTM-CRF-NER 该项目 ...

  9. Bert系列:如何用bert模型输出文本的embedding

    问题: 分类模型可以输出其文本的embedding吗?LM模型可以输出其文本的embedding吗?答案:可以. 假设你已经用自己的数据fine-tuing好模型. 主要工具设备型号: python3 ...

最新文章

  1. ASMSupport教程4.2
  2. 计算机视觉与深度学习 | 使用MATLAB实现图像SURF特征的提取与匹配以及目标定位(代码类)
  3. Mybatis Generator 逆向生成器
  4. wince -- 线程中SetEvent及WaitForSingleObject用法
  5. Nginx 架构——【核心流程+模块介绍】
  6. OpenCV精进之路(二十三):实例——Bag of Features(BoF)图像分类实践
  7. RemObjects
  8. 2020-08-21 第一次面试小结
  9. matlab一个m文件定义多个函数,matlab怎么在一个m文件中写多个函数?
  10. LVS-NAT基于NFS存储部署Discuz
  11. 有return的情况下try catch finally的执行顺序(最有说服力的总结) 后面的神评论
  12. 弘扬时代新风建设网络文明,小趣带你揭秘肾透明细胞癌致瘤机制
  13. 如何做好性能压测(一) | 压测环境的设计和搭建
  14. Tilemap瓦片地图
  15. 仿写携程旅游手机浏览器页面
  16. Python 爬虫js加密破解(四) 360云盘登录password加密
  17. C1认证学习笔记(第四章)
  18. 从苹果、SpaceX等高科技企业的产品发布会看企业产品战略和敏捷开发的关系
  19. 苹果cms模板_首涂第三套苹果CMSv10自适应视频站模板
  20. 大地坐标和高斯平面坐标转换

热门文章

  1. 马云的创业故事及他人生中的摆渡人-第一个双十一(九)
  2. 从100个男人里面挑选37个, 问这里面存在最优秀的男人的概率是多少?
  3. 无用的“数据”?有用的“大数据”
  4. Photoshop中画斜线的方法
  5. 问卷星自动填写JavaScrip脚本使用教程
  6. java相机开发_控制相机  |  Android 开发者  |  Android Developers
  7. Maven打包速率优化
  8. 红米k30pro网速测试软件,红米K30Pro和一加手机8Pro性能对比:依然有差距比想象的要大...
  9. idea 代码提示区分大小写
  10. 微信小程序searchbar搜索功能的使用