记录一下:
环境:谷歌Colab (带GPU)
数据集:爬虫爬取的电商数据集,107个分类,35W条数据
依赖库:huggingface-hub-0.10.1 sklearn-0.0 tokenizers-0.13.1 transformers-4.23.1

!pip install sklearn transformers pandas tensorflow
from sklearn.model_selection import train_test_split
from transformers import TFBertForSequenceClassification
import pandas as pd
import torch
import os
import tensorflow as tf
from transformers import BertTokenizer

采用的是将数据先传到谷歌云盘,再挂载云盘就可以轻松读入数据了~
挂载云盘可以使用命令,也可以直接设置,方便了很多

df_raw = pd.read_csv('/content/drive/MyDrive/data_use_y.csv', encoding='utf-8-sig', index_col=0)
df_raw.head()

数据如下:

用了两个模型进行训练测试,最大长度这里设置的是20,不过看数据,其实还可以再设置更大一些。batch_size设置一次为500,epoch稍微大一些的方式进行训练。

# tokenizer = BertTokenizer.from_pretrained('xlm-roberta-base')
tokenizer = BertTokenizer.from_pretrained('hfl/chinese-bert-wwm-ext')
max_length = 20
batch_size = 500
def split_dataset(df):train_set, x = train_test_split(df,stratify=df['label'],test_size=0.1,random_state=42)val_set, test_set = train_test_split(x,stratify=x['label'],test_size=0.5,random_state=43)return train_set,val_set, test_setdef map_example_to_dict(input_ids, attention_masks, token_type_ids, label):return {"input_ids": input_ids,"token_type_ids": token_type_ids,"attention_mask": attention_masks,}, labeldef encode_examples(ds, limit=-1):# prepare list, so that we can build up final TensorFlow dataset from slices.input_ids_list = []token_type_ids_list = []attention_mask_list = []label_list = []if (limit > 0):ds = ds.take(limit)for index, row in ds.iterrows():review = row["text"]label = row["y"]bert_input = convert_example_to_feature(review)input_ids_list.append(bert_input['input_ids'])token_type_ids_list.append(bert_input['token_type_ids'])attention_mask_list.append(bert_input['attention_mask'])label_list.append([label])return tf.data.Dataset.from_tensor_slices((input_ids_list, attention_mask_list, token_type_ids_list, label_list)).map(map_example_to_dict)

划分训练集、验证集(用于调参)、测试集(用于查验模型最终的效果)

train_data, val_data, test_data = split_dataset(df_raw)
# 对数据集进行编码
# train dataset
ds_train_encoded = encode_examples(train_data).shuffle(10000).batch(batch_size)
# val dataset
ds_val_encoded = encode_examples(val_data).batch(batch_size)
# test dataset
ds_test_encoded = encode_examples(test_data).batch(batch_size)
# 加载模型
# model = TFBertForSequenceClassification.from_pretrained("xlm-roberta-base", num_labels=len(set(df_raw['label'].tolist())))
model = TFBertForSequenceClassification.from_pretrained("hfl/chinese-bert-wwm-ext", num_labels=len(set(df_raw['label'].tolist())))
learning_rate = 2e-5
# learning_rate = 2e-5# 20其实还是有点过拟合了
number_of_epochs = 20
# 设置优化器
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate, epsilon=1e-08, clipnorm=1)
# optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate, epsilon=1e-08, clipnorm=1)
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')
model.compile(optimizer=optimizer, loss=loss, metrics=[metric])
# ----使用Chinese bert 全词掩码的方式
bert_history = model.fit(ds_train_encoded, epochs=number_of_epochs, validation_data=ds_val_encoded)

训练日志如下:

可以看到,其实在9个epoch时模型基本趋于稳定,后面反而略有降低,但最终保持在0.85以上。
查看模型在测试集上的效果:

model.evaluate(ds_test_encoded)


在测试集上准确率也在0.85以上,后面用roberta训练测试准确率也在0.83的水平,可见即便在没有进行精细的数据预处理条件下,bert仍有不错的表现,但如何进一步提升多分类效果值得进一步探究。

【huggingface transformers笔记】基于Bert的中文电商文本分类相关推荐

  1. 【文本分类】基于BERT预训练模型的灾害推文分类方法、基于BERT和RNN的新闻文本分类对比

    ·阅读摘要: 两篇论文,第一篇发表于<图学学报>,<图学学报>是核心期刊:第二篇发表于<北京印刷学院学报>,<北京印刷学院学报>没有任何标签. ·参考文 ...

  2. 基于 BERT 实现的情感分析(文本分类)----概念与应用

    文章目录 基于 BERT 的情感分析(文本分类) 基本概念理解 简便的编码方式: One-Hot 编码 突破: Word2Vec编码方式 新的开始: Attention 与 Transformer 模 ...

  3. 使用Bert预训练模型进行中文文本分类(基于pytorch)

    前言 最近在做一个关于图书系统的项目,需要先对图书进行分类,想到Bert模型是有中文文本分类功能的,于是打算使用Bert模型进行预训练和实现下游文本分类任务 数据预处理 2.1 输入介绍 在选择数据集 ...

  4. 何使用BERT模型实现中文的文本分类

    原文网址:https://blog.csdn.net/Real_Brilliant/article/details/84880528 如何使用BERT模型实现中文的文本分类 前言 Pytorch re ...

  5. 基于协同训练的半监督文本分类算法

    标签: 半监督学习,文本分类 作者:炼己者 --- 本博客所有内容以学习.研究和分享为主,如需转载,请联系本人,标明作者和出处,并且是非商业用途,谢谢! 如果大家觉得格式看着不舒服,也欢迎大家去看我的 ...

  6. 基于Kaggle数据的词袋模型文本分类教程

     基于Kaggle数据的词袋模型文本分类教程 发表于23小时前| 454次阅读| 来源FastML| 0 条评论| 作者Zygmunt Z 词袋模型文本分类word2vecn-gram机器学习 w ...

  7. 基于Keras搭建CNN、TextCNN文本分类模型

    基于Keras搭建CNN.TextCNN文本分类模型 一.CNN 1.1 数据读取分词 1.2.数据编码 1.3 数据序列标准化 1.4 构建模型 1.5 模型验证 二.TextCNN文本分类 2.1 ...

  8. 【文本分类】基于改进TF-IDF特征的中文文本分类系统

    摘要:改进TFIDF,提出相似度因子,提高了文本分类准确率. 参考文献:[1]但唐朋,许天成,张姝涵.基于改进TF-IDF特征的中文文本分类系统[J].计算机与数字工程,2020,48(03):556 ...

  9. AI深度学习入门与实战21 文本分类:用 Bert 做出一个优秀的文本分类模型

    在上一讲,我们一同了解了文本分类(NLP)问题中的词向量表示,以及简单的基于 CNN 的文本分类算法 TextCNN.结合之前咱们学习的 TensorFlow 或者其他框架,相信你已经可以构建出一个属 ...

最新文章

  1. [转]iOS5 ARC学习笔记:strong、weak等详解
  2. fuzzy k means
  3. Spring底层控制反转解耦合(IOC)
  4. CodeForces - 786BLegacy——线段树建图+最短路
  5. 3-1:类与对象入门——类的引入和类的定义以及访问限定符和封装还有对面向对象的理解
  6. 潘维良(帮别人名字作诗)
  7. http 请求中的 referer
  8. 打飞机小游戏,附带源码
  9. 印刷点阵字体_印刷术—类型族,分类和组合字体
  10. 思科模拟器-实验 18 三层交换访问控制列表配置
  11. python如何横向输出_python数据竖着怎么变横的?
  12. 收发器(Transceiver)架构5——发信机2
  13. 复数的幅角Arg与幅角主值arg
  14. apt dpkg 错误制造
  15. ArcBlock ⑦ 月报 | Forge 框架升级更新 开发者社区建设如火如荼
  16. 微软产品经理:你不能不知道的 6 个 Web 开发者工具
  17. Windows Oracle ODBC驱动数据源安装配置
  18. 总结kali中文输入法失败的原因
  19. 什么是电子邮箱地址?好用的电子邮箱注册申请
  20. 【通信】通信相关的一些概念

热门文章

  1. springboot-方法处理4-消息转换器
  2. bzoj乱刷计划2 19/20
  3. 嵌入式系统之nfs挂载-在嵌入式系统和linux之间拷贝文件
  4. C0语言解释执行程序,C0编译器”案例概述.ppt
  5. matlab绘制簇状图,用matplotlib自定义绘制柱形图
  6. HCIP的基础知识点(详细)
  7. JSONArray遍历
  8. 压缩减小图像大小技巧:8个最佳 JPEG 图像压缩软件
  9. Android使用ttf字体库替代替图片
  10. 喜报!《大数据》72篇论文入选中国知网《学术精要数据库》高影响力论文!...