Stacking算法分析与案例调参实例

Stacking方法是一种分层模型集成框架。以两层为例,首先将数据集分成训练集和测试集,利用训练集训练得到多个初级学习器,然后用初级学习器对测试集进行预测,并将输出值作为下一阶段训练的输入值,最终的标签作为输出值,用于训练次级学习器(通常最后一级使用Logistic回归)。由于两次所使用的训练数据不同,因此可以在一定程度上防止过拟合。

由于要进行多次训练,因此这种方法要求训练数据很多,为了防止发生划分训练集和测试集后,测试集比例过小,生成的次级学习器泛化性能不强的问题,通常在Stacking算法中会使用我们上次讲到的交叉验证法或留一法来进行训练。

用下面的图来举例说明Stacking:

1、首先我们会得到两组数据:训练集和测试集。将训练集分成5份:train1,train2,train3,train4,train5。

2、选定基模型。这里假定我们选择了xgboost, lightgbm 和 randomforest 这三种作为基模型。比如xgboost模型部分:依次用train1,train2,train3,train4,train5作为验证集,其余4份作为训练集,进行5折交叉验证进行模型训练;再在测试集上进行预测。这样会得到在训练集上由xgboost模型训练出来的5份predictions,和在测试集上的1份预测值B1。将这五份纵向重叠合并起来得到A1。
lightgbm和randomforest模型部分同理。

3、三个基模型训练完毕后,将三个模型在训练集上的预测值作为分别作为3个"特征"A1,A2,A3,使用LR模型进行训练,建立LR模型。

4、使用训练好的LR模型,在三个基模型之前在测试集上的预测值所构建的三个"特征"的值(B1,B2,B3)上,进行预测,得出最终的预测类别或概率。

在sklearn库中暂时还没有支持Stacking算法的,因此使用mlxtend来实现(pip install mlxtend)

1.简单堆叠3折CV分类

from sklearn import datasets
from sklearn.model_selection import cross_val_score
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.ensemble import RandomForestClassifier
from mlxtend.classifier import StackingCVClassifieriris = datasets.load_iris()
X, y = iris.data[:, 1:3], iris.target
RANDOM_SEED = 42clf1 = KNeighborsClassifier(n_neighbors=1)
clf2 = RandomForestClassifier(random_state=RANDOM_SEED)
clf3 = GaussianNB()
lr = LogisticRegression()# Starting from v0.16.0, StackingCVRegressor supports
# `random_state` to get deterministic result.
sclf = StackingCVClassifier(classifiers=[clf1, clf2, clf3],  # 第一层分类器meta_classifier=lr,   # 第二层分类器random_state=RANDOM_SEED)print('3-fold cross validation:\n')for clf, label in zip([clf1, clf2, clf3, sclf], ['KNN', 'Random Forest', 'Naive Bayes','StackingClassifier']):scores = cross_val_score(clf, X, y, cv=3, scoring='accuracy')print("Accuracy: %0.2f (+/- %0.2f) [%s]" % (scores.mean(), scores.std(), label))

运行结果

3-fold cross validation:Accuracy: 0.91 (+/- 0.01) [KNN]
Accuracy: 0.95 (+/- 0.01) [Random Forest]
Accuracy: 0.91 (+/- 0.02) [Naive Bayes]
Accuracy: 0.93 (+/- 0.02) [StackingClassifier]

画出决策边界

from mlxtend.plotting import plot_decision_regions
import matplotlib.gridspec as gridspec
import itertoolsgs = gridspec.GridSpec(2, 2)
fig = plt.figure(figsize=(10,8))
for clf, lab, grd in zip([clf1, clf2, clf3, sclf], ['KNN', 'Random Forest', 'Naive Bayes','StackingCVClassifier'],itertools.product([0, 1], repeat=2)):clf.fit(X, y)ax = plt.subplot(gs[grd[0], grd[1]])fig = plot_decision_regions(X=X, y=y, clf=clf)plt.title(lab)
plt.show()


2. 使用概率作为元特征
使用第一层所有基分类器所产生的类别概率值作为meta-classfier的输入。需要在StackingClassifier 中增加一个参数设置:use_probas = True。
另外,还有一个参数设置average_probas = True,那么这些基分类器所产出的概率值将按照列被平均,否则会拼接。
例如:
基分类器1:predictions=[0.2,0.2,0.7]
基分类器2:predictions=[0.4,0.3,0.8]
基分类器3:predictions=[0.1,0.4,0.6]

1)若use_probas = True,average_probas = True,
则产生的meta-feature 为:[0.233, 0.3, 0.7]
2)若use_probas = True,average_probas = False,
则产生的meta-feature 为:[0.2,0.2,0.7,0.4,0.3,0.8,0.1,0.4,0.6]

from sklearn import datasets
from sklearn.model_selection import cross_val_score
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.ensemble import RandomForestClassifier
from mlxtend.classifier import StackingCVClassifieriris = datasets.load_iris()
X, y = iris.data[:, 1:3], iris.targetclf1 = KNeighborsClassifier(n_neighbors=1)
clf2 = RandomForestClassifier(random_state=1)
clf3 = GaussianNB()
lr = LogisticRegression()sclf = StackingCVClassifier(classifiers=[clf1, clf2, clf3],use_probas=True,meta_classifier=lr,random_state=42)print('3-fold cross validation:\n')for clf, label in zip([clf1, clf2, clf3, sclf],['KNN','Random Forest','Naive Bayes','StackingClassifier']):scores = cross_val_score(clf, X, y, cv=3, scoring='accuracy')print("Accuracy: %0.2f (+/- %0.2f) [%s]"% (scores.mean(), scores.std(), label))

运行结果

3-fold cross validation:Accuracy: 0.91 (+/- 0.01) [KNN]
Accuracy: 0.95 (+/- 0.01) [Random Forest]
Accuracy: 0.91 (+/- 0.02) [Naive Bayes]
Accuracy: 0.95 (+/- 0.02) [StackingClassifier]

集成学习笔记12-Stacking算法分析与案例调参实例相关推荐

  1. [go学习笔记.第十一章.项目案例] 2.客户信息管理系统

    一.基本介绍 1.需求说明 项目需求分析 1.模拟实现基于文本界面的 < 客户信息管理软件 > 2.该软件实现对客户对象的插入.修改和删除(用切片实现),并能够打印客户明细表 2.界面设计 ...

  2. [go学习笔记.第十一章.项目案例] 1.家庭收支记账软件项目

    一.基本介绍 1.项目开发流程说明 2.项目需求说明 目标: 模拟实现一个基于文本界面的<<家庭记账软件>> 掌握初步的编程技巧和调试技巧 主要涉及以下知识点 : (1).局部 ...

  3. Linux学习笔记12——配置ftp、squid、Tomcat、Samba、MySQL主从

    Linux学习笔记12 Linux学习笔记12 配置FTP服务 配置pure-ftpd 开机启动 上传下载文件 配置vsftpd CentOS 70安装配置Vsftp服务器 搭好vsftp之后出现55 ...

  4. 数据结构学习笔记:变位词侦测案例

    数据结构学习笔记:变位词侦测案例 通过字符串变位词侦测问题可以很好地了解具有不同数量级的算法.变位词,就是两个字符串构成要素完全相同,但是要素的排列顺序不同.比如,heart与earth.python ...

  5. golang学习笔记12 beego table name `xxx` repeat register, must be unique 错误问题

    golang学习笔记12 beego table name `xxx` repeat register, must be unique 错误问题 今天测试了重新建一个项目生成新的表,然后复制到旧的项目 ...

  6. HALCON 20.11:深度学习笔记(12)---语义分割

    HALCON 20.11:深度学习笔记(12)--- 语义分割 HALCON 20.11.0.0中,实现了深度学习方法. 本章解释了如何使用基于深度学习的语义分割,包括训练和推理阶段. 通过语义分割, ...

  7. 台大李宏毅Machine Learning 2017Fall学习笔记 (12)Why Deep?

    台大李宏毅Machine Learning 2017Fall学习笔记 (12)Why Deep? 本博客整理自: http://blog.csdn.net/xzy_thu/article/detail ...

  8. Kotlin学习笔记12——数据类和密封类

    Kotlin学习笔记12--数据类和密封类 前言 数据类 在类体中声明的属性 复制 componentN 解构声明 密封类 尾巴 前言 上一篇,我们学习了Kotlin中的拓展,今天继续来学习Kotli ...

  9. R语言小白学习笔记12—概率分布

    R语言小白学习笔记12-概率分布 笔记链接 学习笔记12-概率分布 12.1 正态分布 12.2 二项分布 12.3 泊松分布 12.4 其他分布 笔记链接 学习笔记1-R语言基础. 学习笔记2-高级 ...

最新文章

  1. MacOS 下使用 intellij IDEA 将git上传项目到 Github
  2. mysql常见问题处理-插入数据error code:1206
  3. 数据结构(数据库)设计规范
  4. Python四道面试题
  5. loadrunner接口性能测试分享
  6. 经典SQL回顾之晋级篇
  7. vue3 @/cli脚手架搭建项目
  8. 计算机科学基础内容摘抄,科学网-上计算机课,不接触计算机----日记摘抄(161)-武夷山的博文...
  9. Linux的学习之路grep命令
  10. 陆琪:年薪十万凭什么不能开…
  11. 苹果手机性能测试用是么软件,怎么检测iPhone手机性能
  12. 云服务器什么配置才够用?
  13. 龙族幻想最新东京机器人位置_龙族幻想藤原智坐标位置一览 藤原智任务攻略...
  14. GB/T 36624-2018《可鉴别的加密机制》笔记——5. 机制4:加密+MAC
  15. 《七周七并发模型》笔记
  16. 以太坊和Hyperledger Fabric之间的差异
  17. 小猪的C语言快速入门系列(三)
  18. Python解析wireshark所捕获的数据报
  19. 用文件记录游戏最高分【C语言】
  20. position的值, relative和absolute分别是相对于谁进行定位,你真的知道吗?

热门文章

  1. WPF 项目开发入门(五)ListView列表组件 与 Expander组件
  2. 【小程序】获取数组长度
  3. linux查看硬盘插槽_linux查看硬盘接口方法有哪些
  4. 移动硬盘或U盘提示:文件或目录损坏且无法读取的解决方法
  5. 倍福FTP文件下载方式
  6. 【Android笔记104】Android之壁纸管理器(WallpaperManager)的使用
  7. Sate210-F 金手指核心板引脚说明
  8. 华为技能鉴定java_华为Java笔试题
  9. 在 ML2 中 enable local network - 每天5分钟玩转 OpenStack(79)
  10. Pycharm调试错误:连接 Python 调试器失败 Socket closed