Data Source

This chapter introduces another classification algorithm: the decision tree. Its biggest advantage over other algorithms is that the decision process is readable by both humans and machines, so we can use the learned model directly for prediction. Another advantage is that it can handle features of several different types.
The data for this chapter is available under python数据挖掘/Chapter4 in the data source linked at the top of the article.

The data is the NBA 2013-2014 season game results, stored as a CSV file. Let's read it into pandas and take a look.

In [1]: import numpy as np

In [2]: import pandas as pd

In [3]: dataset = pd.read_csv('leagues_NBA_2014_games_games.csv')

In [4]: dataset.head()
Out[4]:
              Date        NaN       Visitor/Neutral  ...  PTS.1 NaN.1  Notes
0  Tue Oct 29 2013  Box Score         Orlando Magic  ...     97   NaN    NaN
1  Tue Oct 29 2013  Box Score  Los Angeles Clippers  ...    116   NaN    NaN
2  Tue Oct 29 2013  Box Score         Chicago Bulls  ...    107   NaN    NaN
3  Wed Oct 30 2013  Box Score         Brooklyn Nets  ...     98   NaN    NaN
4  Wed Oct 30 2013  Box Score         Atlanta Hawks  ...    118   NaN    NaN

This data has a couple of problems:

  1. The Date column contains dates as plain strings
  2. The column headers need cleaning up

So let's redo the import.
Fortunately, pandas can parse many string date formats into proper datetime objects.
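As a quick illustration of that flexibility (a standalone sketch, not part of the session below; the filename in the comment is a placeholder), pd.to_datetime accepts the same kind of strings this CSV contains:

import pandas as pd

# Both of these parse to Timestamp('2013-10-29 00:00:00')
pd.to_datetime("Tue Oct 29 2013")
pd.to_datetime("2013-10-29")

# Equivalently, parse a column up front while reading the CSV:
# dataset = pd.read_csv("games.csv", parse_dates=["Date"])  # placeholder filename

The parse_dates argument in the next cell does exactly this during the read.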

In [5]: dataset = pd.read_csv('/Users/gn/scikit--learn/data/leagues_NBA_2014_games_games.csv',
   ...:                       parse_dates=['Date'], skiprows=[0,])

In [6]: dataset.columns = ['Date', 'Score Type', 'Visitor Team', 'VisitorPts',
   ...:                    'Home Team', 'HomePts', 'OT', 'Notes']

In [7]: dataset.head()
Out[7]:
        Date Score Type          Visitor Team  VisitorPts            Home Team  HomePts   OT Notes
0 2013-10-29  Box Score         Orlando Magic          87       Indiana Pacers       97  NaN   NaN
1 2013-10-29  Box Score  Los Angeles Clippers         103   Los Angeles Lakers      116  NaN   NaN
2 2013-10-29  Box Score         Chicago Bulls          95           Miami Heat      107  NaN   NaN
3 2013-10-30  Box Score         Brooklyn Nets          94  Cleveland Cavaliers       98  NaN   NaN
4 2013-10-30  Box Score         Atlanta Hawks         109     Dallas Mavericks      118  NaN   NaN

That looks much better.
Since the data doesn't include explicit win/loss results, we need to derive them from the scores.
Now we need to create some features for the mining step. We'll use whether the home team won its previous game and whether the visiting team won its previous game as the first two.

In [11]: dataset['HomeWin'] = dataset['VisitorPts'] < dataset['HomePts']

In [12]: from collections import defaultdict

In [13]: won_last = defaultdict(int)

In [14]: dataset["HomeLastWin"] = False

In [15]: dataset["VisitorLastWin"] = False

In [16]: for index, row in dataset.iterrows():
    ...:     home_team = row["Home Team"]
    ...:     visitor_team = row["Visitor Team"]
    ...:     row["HomeLastWin"] = won_last[home_team]
    ...:     row["VisitorLastWin"] = won_last[visitor_team]
    ...:     dataset.loc[index] = row
    ...:     won_last[home_team] = row["HomeWin"]
    ...:     won_last[visitor_team] = not row["HomeWin"]
    ...:

In [17]: dataset.loc[20:25]
Out[17]:
         Date Score Type            Visitor Team  VisitorPts           Home Team  HomePts   OT Notes  HomeWin HomeLastWin VisitorLastWin
20 2013-11-01  Box Score         Milwaukee Bucks         105      Boston Celtics       98  NaN   NaN    False       False          False
21 2013-11-01  Box Score              Miami Heat         100       Brooklyn Nets      101  NaN   NaN     True       False          False
22 2013-11-01  Box Score     Cleveland Cavaliers          84   Charlotte Bobcats       90  NaN   NaN     True       False           True
23 2013-11-01  Box Score  Portland Trail Blazers         113      Denver Nuggets       98  NaN   NaN    False       False          False
24 2013-11-01  Box Score        Dallas Mavericks         105     Houston Rockets      113  NaN   NaN     True        True           True
25 2013-11-01  Box Score       San Antonio Spurs          91  Los Angeles Lakers       85  NaN   NaN    False       False           True

The code above first derives an explicit home-team win/loss column from the scores.
It then creates the HomeLastWin and VisitorLastWin columns, initialized to False.
We iterate over the DataFrame, using a dictionary that records each team's most recent result to fill in those two columns for every game (a note on a faster variant follows below).
Since a team has no previous record at its first game of the season, we look at the rows from index 20 onward.
These last-game results are what we'll train on.
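One note on performance: writing the whole row back with dataset.loc[index] = row on every iteration is slow on larger frames. A minimal alternative sketch, assuming the same column names as above, writes only the two affected cells with DataFrame.at:

from collections import defaultdict

won_last = defaultdict(int)
for index, row in dataset.iterrows():
    home_team = row["Home Team"]
    visitor_team = row["Visitor Team"]
    # Write just the two cells instead of reassigning the whole row
    dataset.at[index, "HomeLastWin"] = bool(won_last[home_team])
    dataset.at[index, "VisitorLastWin"] = bool(won_last[visitor_team])
    won_last[home_team] = row["HomeWin"]
    won_last[visitor_team] = not row["HomeWin"]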

Decision Trees

A decision tree is a supervised machine learning algorithm. It looks like a flowchart made up of a series of nodes, where the value at an upper node determines which node you move to next.
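You can actually print that flowchart. Newer scikit-learn versions (0.21+) ship sklearn.tree.export_text; a minimal sketch on the bundled iris toy dataset (not our NBA data) shows the human-readable rules a trained tree contains:

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

iris = load_iris()
clf = DecisionTreeClassifier(max_depth=2, random_state=14)
clf.fit(iris.data, iris.target)
# One indented if/else line per node of the learned tree
print(export_text(clf, feature_names=list(iris.feature_names)))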
scikit-learn implements the CART (Classification and Regression Trees) algorithm as its default decision tree; it supports both continuous and categorical feature values.
Let's look at the code.

In [18]: from sklearn.tree import DecisionTreeClassifier

In [19]: clf = DecisionTreeClassifier(random_state=14)

In [20]: from sklearn.model_selection import cross_val_score

In [21]: X_previouswins = dataset[["HomeLastWin", "VisitorLastWin"]].values
    ...: y_true = dataset["HomeWin"].values
    ...: scores = cross_val_score(clf, X_previouswins, y_true, scoring='accuracy')
    ...: print("Using each team's last-game result")
    ...: print("Accuracy: {0:.1f}%".format(np.mean(scores) * 100))
/Users/gn/anaconda3/lib/python3.7/site-packages/sklearn/model_selection/_split.py:2053: FutureWarning: You should specify a value for 'cv' instead of relying on the default value. The default value will change from 3 to 5 in version 0.22.
  warnings.warn(CV_WARNING, FutureWarning)
Using each team's last-game result
Accuracy: 57.8%

Cross-validation gives us an accuracy of 57.8%.
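The FutureWarning above appears because we relied on the default number of folds. A small sketch that pins it explicitly (cv=3 matches the old default used in this run) keeps the result stable across scikit-learn releases:

# Pinning cv silences the warning and makes the result reproducible
scores = cross_val_score(clf, X_previouswins, y_true, scoring='accuracy', cv=3)
print("Accuracy: {0:.1f}%".format(np.mean(scores) * 100))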

But a team's first game of the season has no previous result to assign, so those entries stay at the default. Let's track win streaks instead: the default is 0, a single win is 1, two straight wins is 2, and N straight wins is N. Let's take a look.

In [22]: results = dataset

In [23]: results["HomeWinStreak"] = 0
    ...: results["VisitorWinStreak"] = 0
    ...: win_streak = defaultdict(int)
    ...: for index, row in results.iterrows():
    ...:     home_team = row["Home Team"]
    ...:     visitor_team = row["Visitor Team"]
    ...:     row["HomeWinStreak"] = win_streak[home_team]
    ...:     row["VisitorWinStreak"] = win_streak[visitor_team]
    ...:     results.loc[index] = row
    ...:     if row["HomeWin"]:
    ...:         win_streak[home_team] += 1
    ...:         win_streak[visitor_team] = 0
    ...:     else:
    ...:         win_streak[home_team] = 0
    ...:         win_streak[visitor_team] += 1
    ...:

In [24]: clf = DecisionTreeClassifier(random_state=14)
    ...: X_winstreak = results[["HomeLastWin", "VisitorLastWin", "HomeWinStreak", "VisitorWinStreak"]].values
    ...: scores = cross_val_score(clf, X_winstreak, y_true, scoring='accuracy')
    ...: print("Using last-game results plus win streaks")
    ...: print("Accuracy: {0:.1f}%".format(np.mean(scores) * 100))
Using last-game results plus win streaks
Accuracy: 56.1%

About the same accuracy as before.

Next let's add a new feature: whether the home team is generally stronger than its opponent.
We'll use each team's standing from the 2013 season to build it.

In [34]: ladder = pd.read_csv('/Users/gn/scikit--learn/data/leagues_NBA_2013_standings_expanded-standings.csv', encoding='gbk')

In [35]: ladder.head()
Out[35]:
   Rk                   Team Overall  Home   Road      E      W     A     C    SE  ...   Post    ≤3    ≥10  Oct   Nov   Dec   Jan   Feb   Mar  Apr
0   1             Miami Heat   66-16  37-4  29-12  41-11   25-5  14-4  12-6  15-1  ...   30-2   9-3   39-8  1-0  10-3  10-5   8-5  12-1  17-1  8-1
1   2  Oklahoma City Thunder   60-22  34-7  26-15   21-9  39-13   7-3   8-2   6-4  ...   21-8   3-6   44-6  NaN  13-4  11-2  11-5   7-4  12-5  6-2
2   3      San Antonio Spurs   58-24  35-6  23-18   25-5  33-19   8-2   9-1   8-2  ...  16-12   9-5  31-10  1-0  12-4  12-4  12-3   8-3  10-4  3-6
3   4         Denver Nuggets   57-25  38-3  19-22  19-11  38-14   5-5  10-0   4-6  ...   24-4  11-7   28-8  0-1   8-8   9-6  12-3   8-4  13-2  7-1
4   5   Los Angeles Clippers   56-26  32-9  24-17   21-9  35-17   7-3   8-2   6-4  ...   17-9   3-5  38-12  1-0   8-6  16-0   9-7   8-5   7-7  7-1

[5 rows x 24 columns]

In [36]: results["HomeTeamRanksHigher"] = 0
    ...: for index, row in results.iterrows():
    ...:     home_team = row["Home Team"]
    ...:     visitor_team = row["Visitor Team"]
    ...:     if home_team == "New Orleans Pelicans":
    ...:         home_team = "New Orleans Hornets"
    ...:     elif visitor_team == "New Orleans Pelicans":
    ...:         visitor_team = "New Orleans Hornets"
    ...:     home_rank = ladder[ladder["Team"] == home_team]["Rk"].values[0]
    ...:     visitor_rank = ladder[ladder["Team"] == visitor_team]["Rk"].values[0]
    ...:     row["HomeTeamRanksHigher"] = int(home_rank > visitor_rank)
    ...:     results.loc[index] = row
    ...:

In [37]: X_homehigher = results[["HomeLastWin", "VisitorLastWin", "HomeTeamRanksHigher"]].values

In [38]: clf = DecisionTreeClassifier(random_state=14)

In [39]: scores = cross_val_score(clf, X_homehigher, y_true, scoring='accuracy')

In [40]: print("Using last-game results plus the home-team ranking feature")
    ...: print("Accuracy: {0:.1f}%".format(np.mean(scores) * 100))
Using last-game results plus the home-team ranking feature
Accuracy: 60.3%

In the code above we have to restore the 2013 name because one franchise was renamed between the seasons: the New Orleans Pelicans were still the New Orleans Hornets in 2013. With the new feature added, accuracy improves again.
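If more franchises had been renamed, a lookup dictionary would keep the mapping in one place. A sketch of that variant (only the Pelicans/Hornets entry actually occurs in this dataset), replacing the if/elif branches inside the iterrows() loop above:

# Map 2014 franchise names back to their 2013 equivalents
NAME_2013 = {"New Orleans Pelicans": "New Orleans Hornets"}

home_team = NAME_2013.get(row["Home Team"], row["Home Team"])
visitor_team = NAME_2013.get(row["Visitor Team"], row["Visitor Team"])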

Next, as another feature, let's record how the two teams' previous meeting went. Although team ranking helps with prediction, some lower-ranked teams consistently beat certain higher-ranked ones. The new feature is simply which of the two teams won the last time they played each other.

In [41]: last_match_winner = defaultdict(int)

In [42]: results["HomeTeamWonLast"] = 0

In [43]: for index, row in results.iterrows():
    ...:     home_team = row["Home Team"]
    ...:     visitor_team = row["Visitor Team"]
    ...:     teams = tuple(sorted([home_team, visitor_team]))
    ...:     row["HomeTeamWonLast"] = 1 if last_match_winner[teams] == row["Home Team"] else 0
    ...:     results.loc[index] = row
    ...:     winner = row["Home Team"] if row["HomeWin"] else row["Visitor Team"]
    ...:     last_match_winner[teams] = winner
    ...:

In [44]: X_home_higher = results[["HomeTeamRanksHigher", "HomeTeamWonLast"]].values

In [45]: clf = DecisionTreeClassifier(random_state=14)

In [46]: scores = cross_val_score(clf, X_home_higher, y_true, scoring='accuracy')

In [47]: print("Using the ranking feature plus the last head-to-head winner")
    ...: print("Accuracy: {0:.1f}%".format(np.mean(scores) * 100))
Using the ranking feature plus the last head-to-head winner
Accuracy: 60.6%

This code again starts with a dictionary, keyed by the sorted pair of home and visitor team names. Using team ranking plus the last head-to-head result as the two features gives 60.6% accuracy, another slight improvement.
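Sorting the pair before building the tuple is what makes the key independent of who is at home, so both meetings of the same two teams land on the same dictionary entry:

# Same key regardless of which team hosts the game
tuple(sorted(["Miami Heat", "Brooklyn Nets"]))  # ('Brooklyn Nets', 'Miami Heat')
tuple(sorted(["Brooklyn Nets", "Miami Heat"]))  # ('Brooklyn Nets', 'Miami Heat')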

Finally, let's see whether a decision tree can still produce an effective model when given much more input: we'll feed in the team identities themselves to see whether it can integrate the extra information.
Although decision trees can handle categorical features in principle, scikit-learn's implementation requires them to be encoded numerically. We use LabelEncoder to convert the team-name strings to integers, then OneHotEncoder to convert those integers to binary one-hot features.

In [48]: from sklearn.preprocessing import LabelEncoder, OneHotEncoder

In [49]: encoding = LabelEncoder()

In [50]: encoding.fit(results["Home Team"].values)
Out[50]: LabelEncoder()

In [51]: home_teams = encoding.transform(results["Home Team"].values)

In [52]: visitor_teams = encoding.transform(results["Visitor Team"].values)

In [53]: X_teams = np.vstack([home_teams, visitor_teams]).T

In [54]: onehot = OneHotEncoder()

In [55]: X_teams = onehot.fit_transform(X_teams).todense()

In [56]: clf = DecisionTreeClassifier(random_state=14)

In [57]: scores = cross_val_score(clf, X_teams, y_true, scoring='accuracy')

In [58]: print("Accuracy: {0:.1f}%".format(np.mean(scores) * 100))
Accuracy: 60.3%

About the same, and not quite as good as before.
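As an aside, pandas can build the one-hot matrix straight from the string columns, skipping the LabelEncoder step. A sketch, assuming the same results frame; note that unlike the shared LabelEncoder vocabulary above, get_dummies gives the home and visitor columns separate indicator sets:

# One binary indicator column per (column, team) combination
X_teams_alt = pd.get_dummies(results[["Home Team", "Visitor Team"]]).values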

Random Forests

A single decision tree can learn very complex rules, but that invites overfitting. One remedy is to constrain the algorithm, limiting how many rules it learns: for example, cap the depth at three levels and let it learn only the rules that best split the dataset globally, ignoring rules with very narrow applicability. This trade-off yields a tree that generalizes well but performs slightly worse overall.
To compensate, we can build many decision trees, have each one make a prediction, and take a majority vote among them as the final prediction. That is how a random forest works, as the toy sketch below illustrates.
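A toy sketch of the voting principle only, with hypothetical hard 0/1 predictions (scikit-learn's internals differ, e.g. its forest averages predicted probabilities rather than counting hard votes):

import numpy as np

# Hypothetical predictions of three trees for five games (1 = home win)
tree_votes = np.array([[1, 0, 1, 1, 0],
                       [1, 1, 0, 1, 0],
                       [0, 1, 1, 1, 1]])
# Predict a home win whenever more than half the trees say so
majority = (tree_votes.sum(axis=0) > tree_votes.shape[0] / 2).astype(int)
print(majority)  # [1 1 1 1 0]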
The code is much the same as before.

In [59]: from sklearn.ensemble import RandomForestClassifier

In [60]: clf = RandomForestClassifier(random_state=14)

In [61]: scores = cross_val_score(clf, X_teams, y_true, scoring='accuracy')
/Users/gn/anaconda3/lib/python3.7/site-packages/sklearn/model_selection/_split.py:2053: FutureWarning: You should specify a value for 'cv' instead of relying on the default value. The default value will change from 3 to 5 in version 0.22.
  warnings.warn(CV_WARNING, FutureWarning)
/Users/gn/anaconda3/lib/python3.7/site-packages/sklearn/ensemble/forest.py:246: FutureWarning: The default value of n_estimators will change from 10 in version 0.20 to 100 in 0.22.
  "10 in version 0.20 to 100 in 0.22.", FutureWarning)

In [62]: print("Accuracy: {0:.1f}%".format(np.mean(scores) * 100))
Accuracy: 60.9%

Just by swapping in a different classifier, our accuracy went up again. Next, let's add the team-ranking and last head-to-head winner features in alongside the team identities.

In [63]: X_all = np.hstack([X_home_higher, X_teams])

In [64]: clf = RandomForestClassifier(random_state=14)

In [65]: scores = cross_val_score(clf, X_all, y_true, scoring='accuracy')
/Users/gn/anaconda3/lib/python3.7/site-packages/sklearn/model_selection/_split.py:2053: FutureWarning: You should specify a value for 'cv' instead of relying on the default value. The default value will change from 3 to 5 in version 0.22.
  warnings.warn(CV_WARNING, FutureWarning)
/Users/gn/anaconda3/lib/python3.7/site-packages/sklearn/ensemble/forest.py:246: FutureWarning: The default value of n_estimators will change from 10 in version 0.20 to 100 in 0.22.
  "10 in version 0.20 to 100 in 0.22.", FutureWarning)

In [66]: print("Accuracy: {0:.1f}%".format(np.mean(scores) * 100))
Accuracy: 61.1%

Adding the new features improved the accuracy once more.
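To see which inputs the forest actually leans on, you can inspect feature_importances_ after fitting; a quick sketch (in X_all the first two columns are our engineered features, the rest are the one-hot team columns):

clf = RandomForestClassifier(random_state=14, n_estimators=100)
clf.fit(np.asarray(X_all), y_true)
# Weight of HomeTeamRanksHigher and HomeTeamWonLast...
print(clf.feature_importances_[:2])
# ...versus the combined weight of all team-identity columns
print(clf.feature_importances_[2:].sum())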

Now let's use the GridSearchCV class to search for the best parameters.

In [67]: from sklearn.model_selection import GridSearchCV

In [68]: parameter_space = {
    ...:     "max_features": [2, 10, 'auto'],
    ...:     "n_estimators": [100,],
    ...:     "criterion": ["gini", "entropy"],
    ...:     "min_samples_leaf": [2, 4, 6],
    ...: }

In [69]: clf = RandomForestClassifier(random_state=14)

In [70]: grid = GridSearchCV(clf, parameter_space)

In [71]: grid.fit(X_all, y_true)
/Users/gn/anaconda3/lib/python3.7/site-packages/sklearn/model_selection/_split.py:2053: FutureWarning: You should specify a value for 'cv' instead of relying on the default value. The default value will change from 3 to 5 in version 0.22.
  warnings.warn(CV_WARNING, FutureWarning)
Out[71]:
GridSearchCV(cv='warn', error_score='raise-deprecating',
       estimator=RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',
            max_depth=None, max_features='auto', max_leaf_nodes=None,
            min_impurity_decrease=0.0, min_impurity_split=None,
            min_samples_leaf=1, min_samples_split=2,
            min_weight_fraction_leaf=0.0, n_estimators='warn', n_jobs=None,
            oob_score=False, random_state=14, verbose=0, warm_start=False),
       fit_params=None, iid='warn', n_jobs=None,
       param_grid={'max_features': [2, 10, 'auto'], 'n_estimators': [100], 'criterion': ['gini', 'entropy'], 'min_samples_leaf': [2, 4, 6]},
       pre_dispatch='2*n_jobs', refit=True, return_train_score='warn',
       scoring=None, verbose=0)

In [72]: print("Accuracy: {0:.1f}%".format(grid.best_score_ * 100))
Accuracy: 64.2%

In [73]: print(grid.best_estimator_)
RandomForestClassifier(bootstrap=True, class_weight=None, criterion='entropy',
            max_depth=None, max_features=2, max_leaf_nodes=None,
            min_impurity_decrease=0.0, min_impurity_split=None,
            min_samples_leaf=6, min_samples_split=2,
            min_weight_fraction_leaf=0.0, n_estimators=100, n_jobs=None,
            oob_score=False, random_state=14, verbose=0, warm_start=False)

This shows the parameter settings that achieved the highest accuracy: 64.2%, our best result so far.
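Because refit=True by default, GridSearchCV has already retrained the best model on all the data, so it can be queried and used directly; a short sketch:

# The winning parameter combination as a dict
print(grid.best_params_)
# e.g. {'criterion': 'entropy', 'max_features': 2,
#       'min_samples_leaf': 6, 'n_estimators': 100}

# Predictions delegate to the refit best estimator
predictions = grid.predict(X_all[:5])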
