·请参考本系列目录:【BERT-多标签文本分类实战】之一——实战项目总览
·下载本实战项目资源:>=点击此处=<

[1] 损失函数与评价指标

  多标签文本分类任务,用的损失函数是BCEWithLogitsLoss,不是交叉熵损失函数cross_entropy!!

BCEWithLogitsLosscross_entropy有什么区别?
+
1)cross_entropy它就是算单标签的损失的,大家去看一下它的公式,它对一个文本只取概率最大的那个标签;
+
2)BCEWithLogitsLoss对模型输出取的是sigmoid,而cross_entropy对模型的输出取的是softmaxsigmoidsoftmax虽然都是把一组数据放缩到[0,1]区间,但是softmax具有排斥性,放缩后的一组数据之和为1,所以这样一组标签概率只会有一个较大值;而sigmoid也是把一组数据放缩到[0,1]区间,但它更类似于等比例缩放,原来大的数现在还大,可以有多个较大的概率存在,所以sigmoid更适合在多标签文本分类任务中。所以要使用BCEWithLogitsLoss

  本次实战项目中使用的评价指标有:准确率accuracy、精确率precision、汉明损失hamming_loss。是基于sklearn库实现的。

# 计算多标签准确率、精确率、hm
def APH(y_true, y_pred):return metrics.accuracy_score(y_true, y_pred), \metrics.precision_score(y_true, y_pred, average='samples'), \metrics.hamming_loss(y_true, y_pred)

还有其他评价指标,召回率、F1等等,评价指标还分可为micro和macro,种类较多,可以参考地址:https://scikit-learn.org/stable/modules/classes.html#module-sklearn.metrics。

[2] 采样

  采样是指:把模型输出出来的概率,转化成独热数组,通常使用阈值为0.5的阈值函数,即概率大于0.5的标签采样为1,否则为0。本项目设置阈值为0.4、且只取2个标签。

# 预测多标签的输出,把概率值转化为独热数组
def Predict(outputs, alpha=0.4):predic = torch.sigmoid(outputs)zero = torch.zeros_like(predic)topk = torch.topk(predic, k=2, dim=1, largest=True)[1]for i, x in enumerate(topk):for y in x:if predic[i][y] > alpha:zero[i][y] = 1return zero.cpu()

[3] 训练

  训练代码如下:

def train(config, model, train_iter, dev_iter, test_iter, is_write):start_time = time.time()model.train()# 普通算法# optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)# bert算法param_optimizer = list(model.named_parameters())no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']optimizer_grouped_parameters = [{'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}]# BertAdam implements weight decay fix,# BertAdam doesn't compensate for bias as in the regular Adam optimizer.optimizer = AdamW(optimizer_grouped_parameters,lr=config.learning_rate,eps=1e-8)# 学习率指数衰减,每次epoch:学习率 = gamma * 学习率scheduler = get_linear_schedule_with_warmup(optimizer,num_warmup_steps = 0,num_training_steps = len(train_iter) * config.num_epochs)total_batch = 0  # 记录进行到多少batchdev_best_loss = float('inf')last_improve = 0  # 记录上次验证集loss下降的batch数flag = False  # 记录是否很久没有效果提升if is_write:writer = SummaryWriter(log_dir="{0}/{1}__{2}__{3}__{4}".format(config.log_path, config.batch_size, config.pad_size,config.learning_rate, time.strftime('%m-%d_%H.%M', time.localtime())))for epoch in range(config.num_epochs):print('Epoch [{}/{}]'.format(epoch + 1, config.num_epochs))for i, (trains, labels) in enumerate(train_iter):outputs = model(trains)model.zero_grad()loss = Loss(outputs, labels)loss.backward()optimizer.step()if total_batch % 100 == 0:# 每多少轮输出在训练集和验证集上的效果true = labelspredic = Predict(outputs)train_oe = OneError(outputs, true)train_acc, train_pre, train_hl = APH(true.data.cpu().numpy(), predic.data.cpu().numpy())dev_acc, dev_pre, dev_hl, dev_oe, dev_loss = evaluate(config, model, dev_iter)if dev_loss < dev_best_loss:dev_best_loss = dev_losstorch.save(model.state_dict(), config.save_path)improve = '*'last_improve = total_batchelse:improve = ''time_dif = get_time_dif(start_time)msg = 'Iter: {0:>6}, Train=== Loss: {1:>6.2}, Acc: {2:>6.2%}, Pre: {3:>6.2%}, HL: {4:>5.2} OE: {' \'5:>6.2%}, Val=== Loss: {6:>5.2}, Acc: {7:>6.2%}, Pre: {8:>6.2%}, HL: {9:>5.2}, ' \'OE: {10:>6.2%}, Time: {11} {12} 'print(msg.format(total_batch, loss.item(), train_acc, train_pre, train_hl, train_oe,dev_loss, dev_acc, dev_pre, dev_hl, dev_oe, time_dif, improve))if is_write:writer.add_scalar('loss/train', loss.item(), total_batch)writer.add_scalar("acc/train", train_acc, total_batch)writer.add_scalar("pre/train", train_pre, total_batch)writer.add_scalar("oe/train", train_oe, total_batch)writer.add_scalar("hamming loss/train", train_hl, total_batch)writer.add_scalar("loss/dev", dev_loss, total_batch)writer.add_scalar("acc/dev", dev_acc, total_batch)writer.add_scalar("pre/dev", dev_pre, total_batch)writer.add_scalar("oe/dev", dev_oe, total_batch)writer.add_scalar("hamming loss/dev", dev_hl, total_batch)model.train()total_batch += 1if total_batch - last_improve > config.require_improvement:# 验证集loss超过1000batch没下降,结束训练print("No optimization for a long time, auto-stopping...")flag = Truebreakscheduler.step()  # 学习率衰减if flag:breakif is_write:writer.close()return test(config, model, test_iter)

  需要解释的几点:

  1、bert模型采用AdamW做优化,不同层要设置不同的权重衰减值;

  2、writer这个变量主要是做数据可视化的,参考博客:【深度学习】pytorch使用tensorboard可视化实验数据。

[4] 评估与测试

def test(config, model, test_iter):# testmodel.load_state_dict(torch.load(config.save_path))model.eval()start_time = time.time()test_acc, test_pre, test_rec, test_hl, test_loss, test_report = evaluate(config, model, test_iter,test=True)msg = 'Test Loss: {0:>5.2},  Test Acc: {1:>6.2%}, Test Pre: {2:>6.2%}, Test HL: {3:>5.2}, Test OE: {4:>6.2%}'print(msg.format(test_loss, test_acc, test_pre, test_rec, test_hl))print("Precision, Recall and F1-Score...")print(test_report)time_dif = get_time_dif(start_time)print("Time usage:", time_dif)return test_loss, test_acc, test_pre, test_rec, test_hldef evaluate(config, model, data_iter, test=False):model.eval()loss_total = 0predict_all = []labels_all = []with torch.no_grad():for texts, labels in data_iter:outputs = model(texts)oe = OneError(outputs.data.cpu(), labels.data.cpu())loss = Loss(outputs, labels)loss_total += losslabels = labels.data.cpu().numpy()predic = Predict(outputs.data)labels_all = np.append(labels_all, labels)predict_all = np.append(predict_all, predic.numpy())labels_all = labels_all.reshape(-1, config.num_classes)predict_all = predict_all.reshape(-1, config.num_classes)acc, pre, hl = APH(labels_all, predict_all)if test:report = metrics.classification_report(labels_all, predict_all, target_names=config.class_list, digits=3)return acc, pre, hl, oe, loss_total / len(data_iter), reportreturn acc, pre, hl, oe, loss_total / len(data_iter)

[5] 运行主程序run.py

if __name__ == '__main__':"""配置参数dataSet     : 数据集名称. required.model_name  : 模型名称. required. 可选值['bert']is_write    : 是否开启tensorboard的记录绘图模式. 可选值[False, True]"""M = ['bert','bert_RNN','bert_RCNN','bert_DPCNN']I = [False, True]dataSet = 'Reuters-21578'is_write = I[0]for model_name in M:x = import_module('models.' + model_name)config = x.Config(dataSet)# 设置numpy的随机种子,以使得结果是确定的np.random.seed(1)# 为CPU设置种子用于生成随机数,以使得结果是确定的torch.manual_seed(1)# 为当前GPU设置随机种子,以使得结果是确定的torch.cuda.manual_seed_all(1)# 保证每次结果一样torch.backends.cudnn.deterministic = Truestart_time = time.time()print("Loading data...")train_data, dev_data, test_data = build_dataset(config)train_iter = build_iterator(train_data, config)dev_iter = build_iterator(dev_data, config)test_iter = build_iterator(test_data, config)time_dif = get_time_dif(start_time)print("Time usage:", time_dif)# trainmodel = x.Model(config).to(config.device)print(model.parameters)print(f'The model has {sum(p.numel() for p in model.parameters() if p.requires_grad):,} trainable parameters')train(config, model, train_iter, dev_iter, test_iter, is_write)

  代码还是比较好懂的,但是还是有一个整体能运行起来的项目体验更佳。

【BERT-多标签文本分类实战】之七——训练-评估-测试与运行主程序相关推荐

  1. 【BERT-多标签文本分类实战】之五——BERT模型库的挑选与Transformers

    ·请参考本系列目录:[BERT-多标签文本分类实战]之一--实战项目总览 ·下载本实战项目资源:>=点击此处=< [1] BERT模型库   从BERT模型一经Google出世,到tens ...

  2. 【BERT-多标签文本分类实战】之二——BERT的地位与名词术语解释

    ·请参考本系列目录:[BERT-多标签文本分类实战]之一--实战项目总览 ·下载本实战项目资源:>=点击此处=< [注]本篇将从宏观上介绍bert的产生和在众多模型中的地位,以及与bert ...

  3. 7个Bert变种模型baseline在7个文本分类数据集上训练和测试

    引入和代码项目简介 https://github.com/songyingxin/Bert-TextClassification 模型有哪些? 使用的模型有下面七个 BertOrigin, BertC ...

  4. 文本分类实战(十)—— BERT 预训练模型

    1 大纲概述 文本分类这个系列将会有十篇左右,包括基于word2vec预训练的文本分类,与及基于最新的预训练模型(ELMo,BERT等)的文本分类.总共有以下系列: word2vec预训练词向量 te ...

  5. Bert模型做多标签文本分类

    Bert模型做多标签文本分类 参考链接 BERT模型的详细介绍 图解BERT模型:从零开始构建BERT (强推)李宏毅2021春机器学习课程 我们现在来说,怎么把Bert应用到多标签文本分类的问题上. ...

  6. bert 是单标签还是多标签 的分类_搞定NLP领域的“变形金刚”!教你用BERT进行多标签文本分类...

    大数据文摘出品 来源:medium 编译:李雷.睡不着的iris.Aileen 过去的一年,深度神经网络的应用开启了自然语言处理的新时代.预训练模型在研究领域的应用已经令许多NLP项目的最新成果产生了 ...

  7. bert 是单标签还是多标签 的分类_搞定NLP领域的“变形金刚”!手把手教你用BERT进行多标签文本分类...

    大数据文摘出品 来源:medium 编译:李雷.睡不着的iris.Aileen 过去的一年,深度神经网络的应用开启了自然语言处理的新时代.预训练模型在研究领域的应用已经令许多NLP项目的最新成果产生了 ...

  8. bert 是单标签还是多标签 的分类_标签感知的文档表示用于多标签文本分类(EMNLP 2019)...

    原文: Label-Specific Document Representation for Multi-Label Text Classification(EMNLP 2019) 多标签文本分类 摘要: ...

  9. 文本分类实战(三)—— charCNN模型

    1 大纲概述 文本分类这个系列将会有十篇左右,包括基于word2vec预训练的文本分类,与及基于最新的预训练模型(ELMo,BERT等)的文本分类.总共有以下系列: word2vec预训练词向量 te ...

最新文章

  1. docker设置国内镜像源
  2. 分散mysql的写入压力_缓解MySQL写入压力和主从延迟的尝试
  3. mysql utf8mb4 php_MySQL设置utf8mb4编码_MySQL
  4. HDU2023 求平均成绩【入门】
  5. 【Java】函数使用
  6. Qt控件背景图片自适应
  7. C语言 ASCII码字符表
  8. 视频格式转换库--libyuv的简介与编译
  9. Android中@GuardedBy
  10. 测试小故事48:想当然
  11. Exp8 web基础 20164323段钊阳
  12. 用计算机的坏处反方,电脑的坏处辩论会
  13. 基于SVM算法的人脸表情识别
  14. android开发关机代码,android代码实现关机
  15. spoon无法初始化至少一个步骤_通俗易懂:8大步骤图解注意力机制
  16. 1.1 Introduction中 Consumers官网剖析(博主推荐)
  17. Redis集群原理与容器化部署集群
  18. 服务通信:自定义srv文件以及服务端的编写
  19. 新增FacesetEnhancer(脸图增强器) DeepFaceLab更新至2019.12.26
  20. 用地图说话 在商业分析与演示中运用Excel数据地图 全彩

热门文章

  1. Jetson Xavier NX 学习(二)更换源
  2. 51单片机利用液晶制作一个时钟
  3. cmd命令操作MySQL数据库
  4. win11系统之win11亮点
  5. 墨水屏可视化超高频电子标签技术优势与应用解决方案
  6. *p++,(*p)++,*++p,++*p辨析
  7. [python123]元音字母逆序
  8. jsp基本语法表单提交方式
  9. Windows11专业版KMS命令激活(不需要激活工具)
  10. 剑网三哪个网站服务器人多,《剑网3缘起》这么火?服务器快挤爆了!