simpletransformers
simple-transformers-configuration

1. 导入相关模块

import warnings
warnings.simplefilter('ignore')import gc
import osimport numpy as np
import pandas as pdfrom sklearn.model_selection import StratifiedKFoldfrom simpletransformers.classification import ClassificationModel, ClassificationArgsos.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES'] = '3'

2. 读取数据,并处理空值


train = pd.read_csv('data/train.csv')
test = pd.read_csv('data/test.csv')train['content'].fillna('', inplace=True)
test['content'].fillna('', inplace=True)

3. 设置模型的参数

TransformerModel具有dict参数,其中包含许多属性,这些属性提供对超参数的控制。

def get_model_args():model_args = ClassificationArgs()model_args.max_seq_length = 32 # 截取文本长度为128model_args.train_batch_size = 16model_args.num_train_epochs = 1 # 跑1epochmodel_args.sliding_window=True     # 使用滑动窗口model_args.evaluate_during_training = True # 训练过程中做评估model_args.evaluate_during_training_verbose = Truemodel_args.fp16 = Falsemodel_args.no_save = True # 不保存模型model_args.save_steps = -1 # 不根据step保存检查点model_args.overwrite_output_dir = True # 覆盖输出路径model_args.output_dir = dir    # 模型输出路径,默认为/outputsreturn model_args

4. single sentence classification 交叉验证训练模型

4.1 load标准预训练模型:huggingface标准预训练模型

model = ClassificationModel("roberta", "roberta-base"
)

4.2 load社区预训练模型 社区预训练模型

model = ClassificationModel("bert", "KB/bert-base-swedish-cased"
)

4.3 load本地预训练模型

outputs/best_model为本地保存模型的路径。

model = ClassificationModel("bert", "outputs/best_model"
)

4.4 完整交叉验证代码

oof = []
prediction = test[['id']]
prediction['bert_pred'] = 0n_folds = 3
kfold = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=2021)
for fold_id, (trn_idx, val_idx) in enumerate(kfold.split(train, train['label'])):train_df = train.iloc[trn_idx][['content', 'label']]valid_df = train.iloc[val_idx][['content', 'label']]train_df.columns = ['text', 'label']valid_df.columns = ['text', 'label']model_args = get_model_args()model = ClassificationModel('bert','hfl/chinese-roberta-wwm-ext',# 中文文本train的社区预训练模型args=model_args)model.train_model(train_df, eval_df=valid_df)#result, vaild_outputs, wrong_predictions = model.eval_model(valid_df)# 这里的result输出一些acc,f1之类的指标# vaild_outputs 输出的是softmax之前的那个权重# wrong_predictions 输出的错误的predict_, vaild_outputs, _  = model.eval_model(valid_df)df_oof = train.iloc[val_idx][['id', 'label']].copy()df_oof['bert_pred'] = vaild_outputs[:,1]oof.append(df_oof)print('predict')_, test_outputs = model.predict([text for text in test['content']])prediction['bert_pred'] += test_outputs[:, 1] / kfold.n_splitsdel model, train_df, valid_df, vaild_outputs, test_outputsgc.collect()

不同任务所对应的模型

Task Model
Binary and multi-class text classification ClassificationModel
Conversational AI (chatbot training) ConvAIModel
Language generation LanguageGenerationModel
Language model training/fine-tuning LanguageModelingModel
Multi-label text classification MultiLabelClassificationModel
Multi-modal classification (text and image data combined) MultiModalClassificationModel
Named entity recognition NERModel
Question answering QuestionAnsweringModel
Regression ClassificationModel
Sentence-pair classification ClassificationModel
Text Representation Generation RepresentationModel
Document Retrieval RetrievalModel

4.5 输出

df_oof = pd.concat(oof)
df_oof = df_oof.sort_values(by='id')
df_oof.head(10)
df_oof[['id', 'bert_pred']].to_csv('roberta_pred_oof.csv', index=False)
prediction[['id', 'bert_pred']].to_csv('roberta_pred_test.csv', index=False)

5. sentence pair classification 交叉验证训练模型

def get_model_args():model_args = ClassificationArgs()model_args.max_seq_length = 32 # 截取文本长度为128model_args.train_batch_size = 16model_args.num_train_epochs = 1 # 跑1epochmodel_args.sliding_window=True     # 使用滑动窗口model_args.evaluate_during_training = True # 训练过程中做评估model_args.evaluate_during_training_verbose = Truemodel_args.fp16 = Falsemodel_args.no_save = True # 不保存模型model_args.save_steps = -1 # 不根据step保存检查点model_args.overwrite_output_dir = True # 覆盖输出路径model_args.output_dir = dir    # 模型输出路径,默认为/outputsreturn model_args
oof = []
prediction = test[['id']]
prediction['bert_pred'] = 0n_folds = 3
kfold = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=2021)
for fold_id, (trn_idx, val_idx) in enumerate(kfold.split(train, train['label'])):train_df = train.iloc[trn_idx][['level_4', 'content', 'label']]valid_df = train.iloc[val_idx][['level_4', 'content', 'label']]train_df.columns = ['text_a', 'text_b', 'label']valid_df.columns = ['text_a', 'text_b', 'label']model_args = get_model_args()model = ClassificationModel('bert','hfl/chinese-roberta-wwm-ext',# 中文文本train的社区预训练模型num_labels=2,args=model_args)model.train_model(train_df, eval_df=valid_df)#result, vaild_outputs, wrong_predictions = model.eval_model(valid_df)# 这里的result输出一些acc,f1之类的指标# vaild_outputs 输出的是softmax之前的那个权重# wrong_predictions 输出的错误的predict_, vaild_outputs, _  = model.eval_model(valid_df)df_oof = train.iloc[val_idx][['id', 'label']].copy()df_oof['bert_pred'] = vaild_outputs[:,1]oof.append(df_oof)print('predict')_, test_outputs = model.predict([list(text) for text in test[['level_4', 'content']].values])prediction['bert_pred'] += test_outputs[:, 1] / kfold.n_splitsdel model, train_df, valid_df, vaild_outputs, test_outputsgc.collect()
df_oof = pd.concat(oof)
df_oof = df_oof.sort_values(by='id')
df_oof.head(10)
df_oof[['id', 'bert_pred']].to_csv('roberta_pred_oof.csv', index=False)
prediction[['id', 'bert_pred']].to_csv('roberta_pred_test.csv', index=False)

6. sentence-transformers

获取文本相关性

  • 直接使用预训练模型,获取文本相关性
  • 使用训练样本微调之后,获取文本相关性
import numpy as np
import torch
from sentence_transformers import SentenceTransformer, util

simpletransformers的 single sentence classification和sentence pair classification相关推荐

  1. Convolutional Neural Networks for Sentence Classification(卷积神经网络句子分类)

    目录 摘要 原文 翻译 单词解释 技术解读 引言 原文 翻译 单词解释 技术解读 原文 翻译 单词解释 技术解读 原文 翻译 单词解释 技术解读 原文 翻译 单词解释 技术解读. Model 原文 单 ...

  2. 文献阅读笔记 # Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks

    <Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks> 用于快速搭建NLP任务的demo的开源项目sbert的原始 ...

  3. Chinese Relation Extraction by BiGRU with Character and Sentence Attentions之代码理解

      代码链接为 https://github.com/crownpku/Information-Extraction-Chinese/tree/master/RE_BGRU_2ATT . 1. ini ...

  4. java sentence_Java Sentence類代碼示例

    本文整理匯總了Java中aima.core.logic.propositional.parsing.ast.Sentence類的典型用法代碼示例.如果您正苦於以下問題:Java Sentence類的具 ...

  5. Tokenisation word segmentation sentence segmentation

    David D. Palmer Chapter 2: Tokenisation and SentenceSegmentation.2000 https://scholar.google.com/cit ...

  6. CGMH: Constrained Sentence Generation by Metropolis-Hastings Sampling

    Abstract 在自然语言生成的实际应用中,除了流畅度和自然度的要求外,通常还有一些其他的约束. 已有的一些语言生成技术基于 RNN 实现,对于这类方法,不容易在维持生成质量的同时对其添加约束. 文 ...

  7. 专题-句向量(Sentence Embedding)

    原始地址:https://github.com/imhuay/Algorithm_Interview_Notes-Chinese/blob/master/B-%E8%87%AA%E7%84%B6%E8 ...

  8. simple sentence to complex

    目录 terms to note 1 loose sentence 2 periodic sentence圆周句.掉尾句 writing concise sentence how to write c ...

  9. Hierarchical Attention Networks for Document Classification(HAN)

    HAN历史意义: 1.基于Attention的文本分类模型得到了很多关注 2.通过层次处理长文档的方式逐渐流行 3.推动了注意力机制在非Seqseq模型上的应用 前人主要忽视的问题: 1.文档中不同句 ...

最新文章

  1. 小红帽怎样装图形化界面_linux安装图形化界面
  2. 【译】Using Machine Learning to Understand the Ethereum Blockchain
  3. Oracle10.2.0.1.0升级Oracle10.2.0.2.0补丁安装指南(转载)
  4. 大话数据结构15 : 线索二叉树
  5. python canvas画移动物体_Python GUI编程入门(25)-移动Canvas对象
  6. 深度ip转换器手机版app_手机大师智能管家app下载-手机大师智能管家app官网版 v1.0.0...
  7. Ubuntu扩展触摸屏触控错位修复
  8. 华为、三星都崴了脚:石墨烯充电还有戏吗
  9. Hadoop HIVE 安装配置(单机集群)
  10. 中国水下充气袋行业市场供需与战略研究报告
  11. php爬虫大数据抓取_数据分析|爬虫抓取东方财富网股吧帖子
  12. ImageButton
  13. python 菜鸟教程 xml-【读书】Django教程(菜鸟教程)
  14. PIL IOError: cannot identify image file './temp.jpg'
  15. Java获取外网ip地址
  16. P1179 数字统计
  17. 一起学爬虫(Python) — 19 年轻人,进来学自动化
  18. pxe网络安装服务器的部署
  19. 手把手教你如何连接阿里云RDS云数据库
  20. 妙啊!巧用 SSH 突破限制穿透内网

热门文章

  1. RTOS面试常问题目
  2. 学习html/css基础的重点笔记
  3. 关于机械臂仿真软件的简介
  4. mysql 取前几分钟和几秒,mysql 数据库取前后几秒 几分钟 几小时 几天的语句
  5. matlab 眼图 值,Matlab通信仿真——带限系统下的基带信号
  6. Problem:机器翻译
  7. 怎样将计算机硬盘的资料彻底删除吗,3种技巧|如何从USB永久删除/清除文件
  8. 到底什么是ERP系统
  9. Php Adodb 初探
  10. 信息安全快讯丨生日快乐,我的国