阈值调优是数据科学中一个重要且必要的步骤。它与应用程序领域密切相关,并且需要一些领域内的知识作为参考。在本文中将演示如何通过阈值调优来提高模型的性能。

用于分类的常用指标

一般情况下我们都会使用准确率accuracy来评价分类的性能,但是有很多情况下accuracy 不足以报告分类模型的性能,所以就出现了很多其他的指标:精确度Precision、召回率Recall、F1 分数F1 score和特异性Specificity。除此以外,还有 ROC 曲线、ROC AUC 和 Precision-Recall 曲线等等。

让我们首先简单解释这些指标和曲线的含义:

精确度Precision:所有正例中真正正例的数量。P=TP/(TP+FP)

召回率Recall:正例数超过真正例数加上假负例数。R=TP/(TP+FN)

F1 分数F1 score:Precision 和 Recall 之间的调和平均值。

特异性Specificity:真负例的数量超过真负例的数量加上假正例的数量。Spec=TN(TN+FP)

(ROC) 曲线:该曲线显示了真正例率和假正例率之间的权衡。代表模型的性能。

ROC曲线下面积(AUC):ROC曲线下面积。如果这个面积等于 1,我们就有了一个完美的分类器。如果它等于 0.5,那么就是一个随机的分类器。

Precision-Recall曲线:这条曲线显示了不同阈值下的精度和召回值。它用于可视化 Precision 和 Recall 之间的权衡。

一般来说,我们必须考虑所有这些指标和曲线。为了将这些内容显示在一起查看,这里定义了一个方法:

 def make_classification_score(y_test, predictions, modelName):tn, fp, fn, tp = confusion_matrix(y_test, predictions).ravel() # ravel() used to convert to a 1-D arrayprec=precision_score(y_test, predictions)rec=recall_score(y_test, predictions)f1=f1_score(y_test, predictions)acc=accuracy_score(y_test, predictions)# specificityspec=tn/(tn+fp)score = {'Model': [modelName], 'Accuracy': [acc], 'f1': [f1], 'Recall': [rec], 'Precision': [prec], \'Specificity': [spec], 'TP': [tp], 'TN': [tn], 'FP': [fp], 'FN': [fn], 'y_test size': [len(y_test)]}df_score = pd.DataFrame(data=score)return df_score

“预测概率”技巧

当我们测试和评估模型时,将预测的 Y 与测试集中的 Y 进行比较。但是这里不建议使用 model.predict(X_test) 方法,直接返回每个实例的标签,而是直接返回每个分类的概率。例如sklearn 提供的 model.predict_proba(X_test) 的方法来预测类概率。然后我们就可以编写一个方法,根据决策阈值参数返回每个实例的最终标签。

 def probs_to_prediction(probs, threshold):pred=[]for x in probs[:,1]:if x>threshold:pred.append(1)else:pred.append(0)return pred

如果设置thresh = 0.5 那么则和调用 model.predict(X_test) 方法得到的结果是相同的,但是使用概率我们可以测试不同的阈值的性能表现。

如果改变阈值则会改变模型的性能。这里可以根据应用程序领域选择一个阈值来最大化重要的度量(通常是精度或召回率),比如在kaggle的比赛中经常会出现thresh = 0.4xx的情况。

选择重要的度量

最大化的重要指标是什么呢?如何确定?

在二元分类任务中,我们的模型会出现两种类型的错误:

第一类错误:预测Y为True,但它实际上是False。也称为假正例错误。

第二类错误:预测Y为False,但它实际上是True。也称为假负例错误。

错误分类实例的数量决定了模型的好坏。但这些错误并不同等重要,对于不用的领域有着不同的要求,比如医学的检测和金融的风控中,需要尽量减小假负例也就是避免第二类错误,需要最小化假负例的数量,那么最大化的重要指标是召回率。

同理,如果要避免第一类错误,我们需要最小化假正例的数量,所以最大化的重要指标是精度。

为了最大化指标,我们可以移动阈值,直到我们在所有指标之间达成良好的平衡,这时就可以使用Precision-Recall曲线,当然也可以使用ROC曲线。

但是要说明的是,我们不能最大化所有指标,因为通过指标的定义就能看到这是不可能的。

阈值优化

假设我们正在处理一个二元分类任务的逻辑回归模型。我们已经进行了训练、超参数调优和测试阶段。该模型已经过交叉验证。也就是说,基本上能做的事情我们都已经做了,但是还是希望能够有一些其他的方式来优化模型,那么则可以试试调整模型的阈值。

对于sklearn来说使用model.predict_proba(X_test)方法来获得类概率,如果使用神经网络的化一般都会输出的是每个类的概率,所以我们这里以sklearn为例,使用这个概率值:

  • 计算ROC AUC,它等于0.9794
  • 计算并绘制ROC曲线
  • 计算并绘制精度-召回率曲线

下面的代码块表示这些步骤:

 def probs_to_prediction(probs, threshold):pred=[]for x in probs[:,1]:if x>threshold:pred.append(1)else:pred.append(0)return pred# getting predicted probability valuesprobability = model.predict_proba(X_test)# calculate ROC AUC score. AUC = 0.9794print("Logit: ROC AUC = %.4f" % roc_auc_score(y_test, probability[:, 1]))# calculate and plot the ROC curvemodel_fpr, model_tpr, _ = roc_curve(y_test, probability[:, 1])plt.plot(model_fpr, model_tpr, marker='.', label='Logit')plt.xlabel('False Positive Rate')plt.ylabel('True Positive Rate (recall)')plt.legend()plt.title("ROC Curve")plt.show()# calculate and plot the Precision-Recall curvemodel_precision, model_recall, thresholds = precision_recall_curve(y_test, probability[:, 1])plt.plot(model_recall, model_precision, marker='.', label='Logit')plt.xlabel('Recall')plt.ylabel('Precision')plt.legend()plt.title('Precision-Recall Curve')plt.show()

下图的曲线。可以看到模型的性能很好。

在本例中,假设在我们的实际应用中FP的成本> FN的成本,所以选择一个阈值在不降低召回率的情况下最大化精度。使用Precision-Recall曲线来对一个可能的阈值进行初始选择。在下面的代码中,绘制了带有候选阈值的Precision-Recall曲线。

 plt.plot(model_recall, model_precision, marker='.', label='Logit')plt.plot(model_recall[43000], model_precision[43000], "ro", label="threshold")plt.xlabel('Recall')plt.ylabel('Precision')plt.legend()plt.title('Precision and Recall values for a chosen Threshold')plt.show()

这样就可以使用选定的阈值来获得最终的分类标签并计算性能指标。并且可以多次进行选择不同阈值进行对比。

 print("Threshold value = %.4f" % thresholds[43000])# results with the chosen thresholdpredictions = probs_to_prediction(probability, thresholds[43000])make_classification_score(y_test, predictions, "logit, custom t")

下图中可以看到,所选的阈值以召回率为代价来最大化精度。根据我们应用的决策阈值,相同的模型可以表现出一些不同的性能。

通过调整阈值并进行结果的对比,一旦对结果满意,模型就可以投入到生产中了。

总结

为分类模型选择最重要的评价指标并不容易。这种选择通常与应用程序领域有关,必须考虑错误分类的代价。在某些情况下,可能有必要咨询领域专家确定哪些错误代表最大的风险。

模型的行为很大程度上受到阈值选择的影响,我们可以应用不同的技术来评估模型并调优阈值以获得预期的结果。

https://avoid.overfit.cn/post/81f1646e48c341358391a9a1d3a2dcfd

作者:Edoardo Bianchi

使用阈值调优改进分类模型性能相关推荐

  1. Java虚拟机这一块 —— JVM 调优和深入了解性能优化

    JVM 调优和深入了解性能优化 JVM 调优的本质 GC 调优原则 调优的原则 目的 GC 调优 调优步骤 日志分析 阅读 GC 日志 -XX:+UseSerialGC -XX:+UseParNewG ...

  2. 分类模型性能评价指标:混淆矩阵、F Score、ROC曲线与AUC面积、PR曲线

    以二分类模型为例:二分类模型最终需要判断样本的结果是1还是0,或者说是positive还是negative. 评价分类模型性能的场景: 采集一个称之为测试集的数据集: 测试集的每一个样本由特征数据及其 ...

  3. 【转载】软件性能测试分析与调优实践之路-性能分析调优思想与调优技术总结

    本文主要阐述软件性能测试中的一些调优思想和技术,节选自作者新书<软件性能测试分析与调优实践之路>部分章节归纳. 一.  性能分析与调优思想 1.性能分析调优模型 性能测试除了为获取性能指标 ...

  4. 软件性能测试分析与调优实践之路-性能分析调优思想与调优技术总结

    来源:https://www.cnblogs.com/laoqing/p/13660768.html 本文主要阐述软件性能测试中的一些调优思想和技术,节选自作者新书<软件性能测试分析与调优实践之 ...

  5. R语言使用randomForest包构建随机森林模型(Random forests)、使用importance函数查看特征重要度、使用table函数计算混淆矩阵评估分类模型性能、包外错误估计OOB

    R语言使用randomForest包中的randomForest函数构建随机森林模型(Random forests).使用importance函数查看特征重要度.使用table函数计算混淆矩阵评估分类 ...

  6. R语言使用rpart包构建决策树模型、使用prune函数进行树的剪枝、交叉验证预防过拟合、plotcp可视化复杂度、rpart.plot包可视化决策树、使用table函数计算混淆矩阵评估分类模型性能

    R语言使用rpart包构建决策树模型.使用prune函数进行树的剪枝.使用10折交叉验证选择预测误差最低的树来预防过拟合.plotcp可视化决策树复杂度.rpart.plot包可视化最终决策树.使用t ...

  7. R语言使用R基础安装中的glm函数构建乳腺癌二分类预测逻辑回归模型、分类预测器(分类变量)被自动替换为一组虚拟编码变量、summary函数查看检查模型、使用table函数计算混淆矩阵评估分类模型性能

    R语言使用R基础安装中的glm函数构建乳腺癌二分类预测逻辑回归模型(Logistic regression).分类预测器(分类变量)被自动替换为一组虚拟编码变量.summary函数查看检查模型.使用t ...

  8. java 内存调优_JVM内存模型以及性能调优

    JVM 内存模型 JVM.png 程序计数器 程序计数器是一块较小的内存空间,可以看作是当前线程所执行的字节码的行号指示器.分支.循环.跳转.异常处理.线程恢复等基础功能都需要依赖这个计数器来完成. ...

  9. 回归和分类模型性能评估指标MSE,MAE,PR,ROC,AUC

    文章目录 0. 模型评估是什么,为什么 1. 不同类型问题的评估指标 1.1 回归问题 1.2 分类问题 1.2.1 准确率和错误率 1.2.2 精确率和召回率 1.2.3 PR曲线图 1.2.4 F ...

最新文章

  1. Memcache 安装与命令 (windows 64bit)
  2. .NET : VS 2008中的一个转换器
  3. ZendStudio10.6.1如何安装最新的集成svn小工具?
  4. 博客目录(python相关)
  5. Bookmarklet
  6. mysql5.7.17主从_mysql5.7.17主从同步配置
  7. 微课|玩转Python轻松过二级:第3章课后习题解答6
  8. 1.12 改善你的模型的表现
  9. pta 习题集5-19 列车厢调度
  10. 关于Visual Studio .NET 2010最近的发布情况
  11. SpringBoot整合Jersey2.x实现文件上传API
  12. 企业级oracle视频教程,企业级Oracle数据库高可用性(OracleDataGuard)DBA培训视频全集...
  13. 精益标准工时软件VIOOVI:没有标准工时,别谈精益改善!
  14. 封装设计 SLC、MLC和TLC
  15. 如何用Python下载百度指数的数据
  16. Aspose.Barcode创建二维码应用代码示例
  17. sudoku_solver :数独解题器
  18. 正则表达式:回车和换行的区别
  19. 个人发展分析:SWOT
  20. Linux安装水星MW150US

热门文章

  1. 推荐一款快速学习的神器
  2. U盘安装CentOS7查看u盘设备名称的命令
  3. 假茅台酒比例你知道吗
  4. 现代 cmake (cmake 3.x) 操作大全
  5. IDEA 设置启动端口号
  6. mysql opkg源_如何修改opkg源
  7. POJ - 1990 MooFest
  8. UFS之Power Mode
  9. 20170918深圳东方博雅笔试
  10. 高通SDX55平台:5G速率问题排查分析方法