在前面决策树的介绍中,我们使用ID3算法来构建决策树;这里我们使用CART算法来构建回归树和模型树。ID3算法是每次选取当前最佳的特征来分割数据,并按照该特征的所有可能取值来区分。比如,如果一个特征有4种取值,那么数据将被切分成4份。很明显,该算法不适用于标签值为连续型的数据。

CART算法使用二元切分法来处理连续型变量,即每次把数据集切分成左右两份。

回归树

回归树使用CART算法来构建树,使用二元切分法来分割数据,其叶节点包含单个值。创建回归树的函数createTree()的伪代码大致如下:

找到最佳的待切分特征:

如果该节点不能再分,将该节点存为叶节点

执行二元切分

在左子树调用createTree()方法

在右子树调用createTree()方法

创建回归树的过程与决策树类似,只是切分方法不同。同时,在计算数据的无序程度时,决策树是使用香农熵的方法,而我们的标签值是连续值,不能用该方法。那么,怎么计算连续型数值的混乱度呢?首先,计算所有数据的均值,然后计算每条数据的值到均值的差值,再求平方和,即我们用总方差的方法来度量连续型数值的混乱度。在回归树中,叶节点包含单个值,所以总方差可以通过均方差乘以数据集中样本点的个数来得到。

下面,将计算连续型数值混乱度的代码提供如下:

#计算分割代价
def spilt_loss(left,right):  #总方差越小,说明数据混乱度越小loss=0.0left_size=len(left)#print 'left_size:',left_sizeleft_label=[row[-1] for row in left]right_size=len(right)right_label=[row[-1] for row in right]loss += var(left_label)*left_size + var(right_label)*right_sizereturn loss

得到叶节点预测值的代码:

#决定输出标签(取出叶节点数据的标签值,计算平均值)
def decide_label(data):output=[row[-1] for row in data]return mean(output)

模型树

模型树与回归树的差别在于:回归树的叶节点是节点数据标签值的平均值,而模型树的节点数据是一个线性模型(可用最简单的最小二乘法来构建线性模型),返回线性模型的系数W,我们只要将测试数据X乘以W便可以得到预测值Y,即Y=X*W。所以该模型是由多个线性片段组成的。

同样,给出叶节点预测值及计算待分割数据集混乱度的代码:

#生成叶节点
def decide_label(dataSet):ws,X,Y = linearModel(dataSet)return ws#计算模型误差
def spilt_loss(dataSet):ws,X,Y = linearModel(dataSet)yat = X * wsreturn sum(power(yat-Y,2))#模型预测数据
def modelTreeForecast(ws,dataRow):data = mat(dataRow)n = shape(data)[1]X = mat(ones((1,n)))X[:,1:n] = data[:,0:n-1]return X*ws

那么,如何比较回归树与模型树那种模型更好呢?一个比较客观的方法是计算预测值与实际值相关系数。该相关系数可以通过调用NumPy库中的命令corrcoef(yHat,y.rowvar=0)来求解,其中yHat是预测值,y是目标变量的实际值。

剪枝

通过降低树的复杂度来避免过拟合的过程称为剪枝。对树的剪枝分为预剪枝和后剪枝。一般地,为了寻求最佳模型可以同时使用这两种剪枝技术。

预剪枝:在选择创建树的过程中,我们限制树的迭代次数(即限制树的深度),以及限制叶节点的样本数不要过小,设定这种提前终止条件的方法实际上就是所谓的预剪枝。周志华的西瓜书中有对预剪枝的方法做具体描述,感兴趣的同学可以了解一下。因为我只是通过提前终止条件的方法来实现预剪枝,这种方法比较简单,不做具体描述。

后剪枝:使用后剪枝方法需要将数据集分为测试集和训练集。用测试集来判断将这些叶节点合并是否能降低测试误差,如果是的话将合并。

直接上代码:

'''后剪枝过程'''
#判断是否为字典
def isTree(obj):return (type(obj).__name__=='dict')def getMean(tree):  #将叶节点的训练数据的标签值的平均值作为该节点的预测值if isTree(tree['right']):tree['right'] = getMean(tree['right'])if isTree(tree['left']):tree['left'] = getMean(tree['left'])return (tree['left']+tree['right'])/2.0#执行后剪枝(具体来说,就是将测试集按照之前生成的树一步步分类到叶节点,计算相应的标签值与叶节点预测值的总方差,如果剪枝后方差变小,则执行剪枝)
def prune(testData,tree):if len(testData)==0:return getMean(tree)if (isTree(tree['left']) or isTree(tree['right'])):    #判断tree['left']和tree['right']是否为字典,如果为字典则进行数据划分lSet,rSet = data_spilt(testData,tree['index'],tree['value'])  #划分数据集if isTree(tree['left']):            #如果tree['left']是字典,则执行prune()函数进行递归,直到tree['left']是叶节点时结束递归,往下继续执行函数tree['left'] = prune(lSet,tree['left'])if isTree(tree['right']):           #在tree['left']执行递归的基础上继续递归,这样可以取到所有左右两边的叶节点的值tree['right'] = prune(lSet,tree['right'])if not isTree(tree['left']) and not isTree(tree['right']):  #如果tree['left']和tree['right']都不是字典,执行下面操作lSet,rSet = data_spilt(testData,tree['index'],tree['value'])  #分割数据集left_value = [row[-1] for row in lSet]    #取出左数据集的节点值right_value = [row[-1] for row in rSet]   #取出右数据集的节点值if tree['left'] is None or tree['right'] is None:   #如果出现tree['left']或tree['right']为None时,返回树,不执行剪枝操作return treeelse:errorNoMerge = sum(power(left_value-tree['left'],2)) + sum(power(right_value-tree['right'],2))   #计算没剪枝时测试集的标签值与叶节点的预测值的总方差treeMean = (tree['left'] + tree['right'])/2.0testSet_value = [row[-1] for row in testData]errorMerge = sum(power(testSet_value-treeMean,2))  #计算剪枝后测试集的标签值与叶节点的预测值的总方差if errorMerge < errorNoMerge:   #如果剪枝后的方差小于剪枝前,则执行剪枝;否则返回,不剪枝。print 'merging'return treeMeanelse:return treeelse :return tree

以上,便是我在学习过程中对回归树,模型树,树剪枝的一些总结。

树模型之回归树,模型树,树剪枝相关推荐

  1. ❤️解决非线性回归问题的机器学习方法总结:多项式线性模型、广义线性(GAM)模型、回归树模型、支持向量回归(SVR)模型

    文章目录 前言 多项式回归模型 概念解释: sklearn实现多项式回归模型: 广义线性可加(GAM)模型 概念解释: pygam实现广义线性可加模型: GAM模型的优点与不足: 回归树模型 概念解释 ...

  2. Python实现Stacking回归模型(随机森林回归、极端随机树回归、AdaBoost回归、GBDT回归、决策树回归)项目实战

    说明:这是一个机器学习实战项目(附带数据+代码+文档+视频讲解),如需数据+代码+文档+视频讲解可以直接到文章最后获取. 1.项目背景 Stacking通常考虑的是异质弱学习器(不同的学习算法被组合在 ...

  3. 伟景行citymaker-----01.javascript打开本地模型CEP,加载目录树,加载要素类

    以下所有代码基于 CityMaker_IE_Plugin_vConnect8.0.171127.exe 版本 该版本只能使用IE打开,建议使用IE11 下载代码案例 1.打开cep模型代码 1.1  ...

  4. CART树分类、回归、剪枝实现

    决策树ID3,C4.5是多叉树,CART树是一个完全二叉树,CART树不仅能完成分类也能实现回归功能,所谓回归指的是目标是一个连续的数值类型,比如体重.身高.收入.价格等,在介绍ID3,C4.5其核心 ...

  5. 决策树之建立一棵树(代码模板)防止过拟合、剪枝参数

    建立一棵树 1.导入需要的算法库和模块 from sklearn import tree from sklearn.datasets import load_wine from sklearn.mod ...

  6. [湖南集训]更为厉害 树上主席树-以树深度为下下标建立主席树

    题意题解: 首先对于树上某个点a来说,假设点b是a的祖先(也就是在a的上面),那么答案很好计算,也就是min(k,dep[a]−1)∗(size[a]−1)min(k,dep[a]-1)*(size[ ...

  7. P1276 校门外的树(增强版)(线段树)(校门三部曲)难度⭐⭐⭐

    校门三部曲,总算完结了!完结散花! 难度呈阶梯状,都可以用线段树解决. 第一部 P1047 校门外的树(线段树优化)难度⭐⭐ 第二部 P1276 校门外的树(增强版)(线段树)校门三部曲难度⭐⭐⭐ 第 ...

  8. c语言孩子兄弟法存储一棵树,数据结构(C语言版)---树

    1.树:n个结点的有限集,n=0时为空树. 1)特点: (1)有且仅有一个特定的称为根的结点. (2)有若干个互不相交的子树,这些子树本身也是一棵树. (3)树的根结点没有前驱结点,除根结点外的所有结 ...

  9. hdu3966 树链剖分点权模板+线段树区间更新/树状数组区间更新单点查询

    点权树的模板题,另外发现树状数组也是可以区间更新的.. 注意在对链进行操作时方向不要搞错 线段树版本 #include<bits/stdc++.h> using namespace std ...

最新文章

  1. MSDN 教程短片 WPF 16(Path路径)
  2. 关于DSP的GPIO的输入输出设置
  3. 秦川团队《科学》刊发研究:新冠感染恒河猴康复后不会再感染
  4. SpringCloud Zuul(四)之工作原理
  5. 将WildFly绑定到其他IP地址或多宿主上的所有地址
  6. 如何利用openssl来计算md4, md5, sha1, sha256, sha384, sha512等常用哈希值?
  7. 数字图像处理--图像颜色
  8. 渗透学习笔记--基础篇--sql注入(数字型)
  9. Zhang-Suen细化算法讲解及实现
  10. Madwifi Mad coding:自底向上分析associated_sta的更新过程 —— RSSI和MACADDR等信息获取的底层原理...
  11. FTP、Telnet、SMTP、POP3等服务的名称及端口号和各种数据库的端口号
  12. Python项目实战化:爬取堆糖网研究所美好生活照
  13. User does not have the ‘LOCK TABLES‘ privilege required to obtain a consistent snapshot by preventin
  14. 没有学历,四步进Google
  15. 双基因突变患者_一例 Kallmann 综合征患者双基因突变分析
  16. js基础系列之函数调用与this
  17. Matlab透视变换
  18. ESP32 ESP-IDF增加自定义components 注意事项
  19. 计算机基础及Java语法
  20. ubuntu 18使用国内版firefox

热门文章

  1. C语言编程的书写规则,关于C语言编程书写规范的规则和建议.doc
  2. 使用Pattern、Matcher类和正则表达式从一段文字中获取其中的手机号码
  3. 相比之下,在美国看病简直是噩梦
  4. Swagger UI文件上传
  5. 操作系统面试问答题大全
  6. 安装pip和设置pycharm
  7. 无法将“webpack”项识别为 cmdlet、函数、脚本文件或可运行程序的名称...
  8. deepin更新启动项_deepin删除多余启动项和添加引导项
  9. 计算机专业刚学应该自学什么,晋中计算机专业主要学什么课?
  10. java 输出字符串变量_java打印字符串变量