自动寻找Prompt

实验版本好多参数可调

import os
import torch
import logging
import datasets
import transformers
import numpy as np
import torch.nn as nn
from sklearn import metrics
from datasets import Dataset
from torch.nn import CrossEntropyLoss
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from transformers import Trainer, TrainingArguments, BertTokenizer, BertForMaskedLM
from transformers.modeling_outputs import MaskedLMOutputos.environ['CUDA_VISIBLE_DEVICES'] = '1'
transformers.set_seed(1)
logging.basicConfig(level=logging.INFO)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
prp_len = 2 #prompt token长度# 通过LSTM寻找prompt的embedding
class MyModel(BertForMaskedLM):def __init__(self, config):super().__init__(config)self.dim = 384self.emb = nn.Embedding(prp_len+1, self.dim)self.bi_lstm = nn.LSTM(self.dim, self.dim, 2, bidirectional=True)self.b_emb = self.get_input_embeddings()self.line1 = nn.Linear(768, 768)self.line2 = nn.Linear(768, 768)self.line3 = nn.Linear(768, 768)self.relu = nn.ReLU()def forward(self,input_ids=None,  # [CLS] e(p) e(p) [MASK] e(input_ids)attention_mask=None,token_type_ids=None,position_ids=None,head_mask=None,inputs_embeds=None,encoder_hidden_states=None,encoder_attention_mask=None,labels=None,   # [CLS] -100 -100 label e(input_ids)output_attentions=None,output_hidden_states=None,return_dict=None,):p = self.emb(torch.LongTensor([range(1, prp_len+1)]*input_ids.shape[0]).to(device))  # 若用GPU则要注意将数据导入cudap = self.bi_lstm(p)[0]p = self.relu(self.line1(p))p = self.relu(self.line2(p))p = self.relu(self.line3(p))inputs_embeds = self.b_emb(input_ids)inputs_embeds[:, 1:prp_len+1, :] = p return_dict = return_dict if return_dict is not None else self.config.use_return_dictoutputs = self.bert(None,attention_mask=attention_mask,token_type_ids=token_type_ids,position_ids=position_ids,head_mask=head_mask,inputs_embeds=inputs_embeds,encoder_hidden_states=encoder_hidden_states,encoder_attention_mask=encoder_attention_mask,output_attentions=output_attentions,output_hidden_states=output_hidden_states,return_dict=return_dict,)sequence_output = outputs[0]prediction_scores = self.cls(sequence_output)masked_lm_loss = Noneif labels is not None:loss_fct = CrossEntropyLoss()  # -100 index = padding tokenmasked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))if not return_dict:output = (prediction_scores,) + outputs[2:]return ((masked_lm_loss,) + output) if masked_lm_loss is not None else outputreturn MaskedLMOutput(loss=masked_lm_loss,logits=prediction_scores,hidden_states=outputs.hidden_states,attentions=outputs.attentions,)class LecCallTag():# 原始样本统计def data_show(self, data_file):with open(data_file, 'r', encoding='utf-8') as f:data = f.readlines()logging.info("获取数据:%s" % len(data))tags_data_dict = {}for line in data:text_label = line.strip().split('\t')if text_label[1] in tags_data_dict:tags_data_dict[text_label[1]].append(text_label[0])else:tags_data_dict[text_label[1]] = [text_label[0]]logging.info("其中,各分类数量:")for k, v in tags_data_dict.items():logging.info("%s: %s" % (k, len(v)))return tags_data_dict# 数据处理def data_process(self, data_file):with open(data_file, 'r', encoding='utf-8') as f:data = [line.strip().split('\t') for line in f.readlines()]self.lable2idx1 = {'天气好': '好', '天气良': '良', '天气差': '差', '其他': '无'}text = ['[MASK]'*(prp_len+1)+_[0] for _ in data]label = [self.lable2idx1[_[1]]*(prp_len+1)+_[0] for _ in data]return text, label# model, tokenizerdef create_model_tokenizer(self, model_name, n_label=0):tokenizer = BertTokenizer.from_pretrained(model_name)model = MyModel.from_pretrained(model_name)return tokenizer, model# 构建datasetdef create_dataset(self, text, label, tokenizer, max_len):X_train, X_test, Y_train, Y_test = train_test_split(text, label, test_size=0.2, random_state=1)logging.info('训练集:%s条,\n测试集:%s条' %(len(X_train), len(X_test)))train_dict = {'text': X_train, 'label_text': Y_train}test_dict = {'text': X_test, 'label_text': Y_test}train_dataset = Dataset.from_dict(train_dict)test_dataset = Dataset.from_dict(test_dict)def preprocess_function(examples):text_token = tokenizer(examples['text'], padding=True,truncation=True, max_length=max_len)text_token['labels'] = np.array(tokenizer(examples['label_text'], padding=True,truncation=True, max_length=max_len)["input_ids"])text_token['labels'][:, 1:prp_len+1] = -100  # 占位,计算loss时忽略-100# print('text_token', text_token)return text_tokentrain_dataset = train_dataset.map(preprocess_function, batched=True)test_dataset = test_dataset.map(preprocess_function, batched=True)return train_dataset, test_dataset# 构建trainerdef create_trainer(self, model, train_dataset, test_dataset, checkpoint_dir, batch_size):args = TrainingArguments(checkpoint_dir,evaluation_strategy = "epoch",learning_rate=2e-5,per_device_train_batch_size=batch_size,per_device_eval_batch_size=batch_size,num_train_epochs=20,weight_decay=0.01,load_best_model_at_end=True,metric_for_best_model='accuracy',)def compute_metrics(pred):# labels = pred.label_ids# preds = pred.predictions.argmax(-1)labels = pred.label_ids[:, prp_len+1]preds = pred.predictions[:, prp_len+1].argmax(-1)precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='weighted')acc = accuracy_score(labels, preds)return {'accuracy': acc, 'f1': f1, 'precision': precision, 'recall': recall}trainer = Trainer(model,args,train_dataset=train_dataset,eval_dataset=test_dataset,# tokenizer=tokenizer,compute_metrics=compute_metrics)return trainerdef main():lct = LecCallTag()data_file = '/data.txt'checkpoint_dir = "/checkpoint/"batch_size = 16max_len = 64n_label = 3tags_data = lct.data_show(data_file)text, label = lct.data_process(data_file)tokenizer, model = lct.create_model_tokenizer("bert-base-chinese")train_dataset, test_dataset = lct.create_dataset(text, label, tokenizer, max_len)trainer = lct.create_trainer(model, train_dataset, test_dataset, checkpoint_dir, batch_size)trainer.train()pred = trainer.predict(test_dataset)pred_label = np.argmax(pred[0][:, prp_len+1], axis=1).tolist()true_label = pred[1][:, prp_len+1].tolist()print(metrics.classification_report(true_label, pred_label))print(metrics.confusion_matrix(true_label, pred_label))if __name__ == '__main__':main()

结果记录

baseline: prompt长度4、LSTM层数2、LSTM输出维度768、前馈层数1、mask词:优良差无、learning_rate:2e-5


最好的情况是手写的prompt,但都要比直接用bert做分类任务效果好一些,小样本置信度不高,增加一种文本分类的手段。

基于Prompt的MLM文本分类-v2相关推荐

  1. 基于Prompt的MLM文本分类

    简介 常规NLP做文本分类时常用Transfer Learning的方式,在预训练bert上加一个分类层,哪个输出节点概率最大则划分到哪一类别.而基于Prompt的MLM文本分类是将文本分类任务转化为 ...

  2. 基于Prompt的MLM文本分类 bert4keras实现

    本文主要介绍使用Prompt的MLM文本分类 bert4keras的代码实现,用以丰富bert4keras框架的例子 关于prompt的原理的文章网上有很多优秀的文章可以自行百度. github地址 ...

  3. 【调研】基于Prompt的小样本文本分类调研:PET,LM-BFF,KPT,PTR

    本篇博客一共要分享四篇prompt论文,它们分别提出了四个模型. 目录

  4. 基于深度学习的文本分类 3

    基于深度学习的文本分类 Transformer Transformer是一种完全基于Attention机制来加速深度学习训练过程的算法模型,其最大的优势在于其在并行化处理上做出的贡献.换句话说,Tra ...

  5. 文本基线怎样去掉_ICML 2020 | 基于类别描述的文本分类模型

    论文标题: Description Based Text Classification with Reinforcement Learning 论文作者: Duo Chai, Wei Wu, Qing ...

  6. 【项目实战课】NLP入门第1课,人人免费可学,基于TextCNN的新闻文本分类实战...

    欢迎大家来到我们的项目实战课,本期内容是<基于TextCNN的新闻文本分类实战>. 所谓项目课,就是以简单的原理回顾+详细的项目实战的模式,针对具体的某一个主题,进行代码级的实战讲解,可以 ...

  7. ICML 2020 | 基于类别描述的文本分类模型

    论文标题: Description Based Text Classification with Reinforcement Learning 论文作者: Duo Chai, Wei Wu, Qing ...

  8. 基于深度学习的文本分类1

    基于深度学习的文本分类 与传统机器学习不同,深度学习既提供特征提取功能,也可以完成分类的功能.从本章开始我们将学习如何使用深度学习来完成文本表示. 现有文本表示方法的缺陷 在上一章节,我们介绍几种文本 ...

  9. Datawhale NLP入门:Task5 基于深度学习的文本分类2

    Task5 基于深度学习的文本分类2 在上一章节,我们通过FastText快速实现了基于深度学习的文本分类模型,但是这个模型并不是最优的.在本章我们将继续深入. 基于深度学习的文本分类 本章将继续学习 ...

最新文章

  1. 查看 mysql 占用的内存大小_mysql查看数据库和表的占用空间大小
  2. 阿拉德之怒显示服务器错误,阿拉德之怒网络异常怎么办 安装失败怎么办
  3. linux ftp服务器搭建及用户的分配,Linux搭建FTP服务器
  4. python爬取岗位数据并分析_区块链岗位薪资高,Python爬取300个区块链岗位分析,龙虎榜出炉...
  5. Getting the right Exception Context from a Memory dump Fixed
  6. jvm内存模型_JVM内存模型的相关概念
  7. java利用htmlparser得到网页html内容
  8. Oracle分页模板
  9. android日历信息获取错误,android – 从日历中获取事件
  10. 30 校准_机会难得校准实验室认可培训别再错过
  11. 网络工程师linux题,历年软考网络工程师Linux真题详解
  12. 智慧旅游系统总体设计方案
  13. Linux应用开发自学之路
  14. php gd2扩展_PHP如何打开gd2扩展库
  15. B样条曲线与曲面相关知识点汇总
  16. 移动开发者如何获取免费流量
  17. oracle index alter,Oracle alter index rebuild 一系列问题
  18. 重庆大学计算机学院课题组,【计算机】计算机学院关于智能计算的大规模优化学术报告圆满结束...
  19. Python常用配置文件ini、json、yaml及python字典读写总结
  20. 用1、3、5、7 这4 个数字,能组成的互不相同且无重复数字的三位数有哪些?共有多少个?这些数的和为多少?

热门文章

  1. 服务器网络打印总是自动删除,打印机无法打印打印文件时会自动删除,怎么回事啊?...
  2. java web网站 js 简体繁体切换_求繁简转换的js代码,可以设置打开网站时候整站默认显示繁体或简体,然后可以手动切换繁简。...
  3. Java基础篇--Java 数组
  4. springboot tomcat 调优
  5. Unity3d背包系统(三)—— 设计物品类的JSON文件
  6. 路由器二次开发一步一步把工业路由器变成一个高端的可指定出网、节点和链路的路由器,包含详细过程及快捷脚本(四)
  7. 【2023/05/08】雅卡尔织布机
  8. iOS 判断 iPhoneXS Max、iPhoneXS、iPhoneXR、iPhoneX
  9. PA认证必看:考试说明及注意事项
  10. python具体能做什么_python都能干嘛