Content

  • 1. Introduction
  • 2. Data
  • 3. 数据预处理
  • 4. Random Forest
  • 5. 模型评估
  • 6. Feature Importance Analysis
    • 6.1 决策树可视化
    • 6.2 Permutation importance
    • 6.3 Partial Dependence Plots
  • 7. 后记

记得有一次去面试,那个公司的HR聊天说,她感觉程序员面试那是面真功夫,会就会,不会装也没用。从这里想开来,还真是,码农学再多理论,终究是要去码砖的。我呢就是原来机器学习和深度学习的理论学的多,实践反而少,所以感觉有时候做事情就慢了些。现在趁着还有些闲工夫,就找一些项目做做,由简单到复杂,慢慢来吧。

欢迎大家收藏我的个人博客 KK’s Notes,2019.4.5 刚搞成功,接下来CSDN和 KK’s Notes 同时更新,各位看官大佬多多指教。

1. Introduction

这个项目来自于kaggle。项目主要是利用患者的个人信息和检查数据,利用机器学习方法来诊断该患者收否患疾病,并且尝试对识别结果作出解释。这个项目虽然简单但将机器学习的全流程和常用预处理和分析方法都涉及到了,我做完一遍还是有很多收获。以下操作皆在 Jubyter notebook 下以 Python 进行的。

主要使用的技术:

  • Random Forest
  • Feature Importance Analysis: Permutation importance
  • Feature Importance Analysis: Partial Dependence Plots

2. Data

Data from:https://www.kaggle.com/ronitf/heart-disease-uci/downloads/heart.csv/
About Data:下载好数据之后直接打开看一看。

import pandas as pd
import numpy as np
data = pd.read_csv('data/heart.csv')
data.info()

Output:

可以看到总共有303条数据以及13个特征和1个标签,数据没有缺失项。接下看下前十个数据。

data.head(10)

Output:

这13个特征的含义分别是:

age: 年龄
sex:该人的性别(1=男性,0=女性)
cp:胸痛经历(值1:典型心绞痛,值2:非典型心绞痛,值3:非心绞痛,值4:无症状)
trestbps:该人的静息血压(入院时为mm Hg)
chol:人体胆固醇测量单位为mg/dl
fbs:该人的空腹血糖(> 120mg/dl,1=true; 0= f=alse)
restecg:静息心电图测量(0=正常,1=有ST-T波异常,2=按Estes标准显示可能或明确的左心室肥厚)
thalach:达到了该人的最大心率
exang:运动诱发心绞痛(1=是; 0=否)
oldpeak:运动相对于休息引起的ST段压低('ST’与ECG图上的位置有关)
slope:峰值运动ST段的斜率(值1:上升,值2:平坦,值3:下降)
ca:主要血管数量(0-3)
thal:称为地中海贫血的血液疾病(1=正常; 2=固定缺陷; 3=可逆缺陷)
target:心脏病(0=不,1=是)

为了更好的理解数据,我们应该提前查一下每个特征的含义,以及医学上该特征和心脏病的关系。具体这里不再赘述。

3. 数据预处理

这里为了方便后续做心脏病诊断中影响因素分析即Feature Importance Analysis(还是觉得用英文更能表达意思),将部分数值型特征进行转换:

data.loc[data.sex == 1, 'sex'] = 'male'
data.loc[data['sex'] == 0, 'sex'] = 'female'data.loc[data['cp'] == 1, 'cp'] = 'typical'
data.loc[data['cp'] == 2, 'cp'] = 'atypical'
data.loc[data['cp'] == 3, 'cp'] = 'no_pain'
data.loc[data['cp'] == 4, 'cp'] = 'no_feel'data.loc[data['fbs'] == 1, 'fbs'] = 'higher than 120 mg/dl'
data.loc[data['fbs'] == 0, 'fbs'] = 'lower than 120 mg/dl'data.loc[data['restecg'] == 0, 'restecg'] = 'normal'
data.loc[data['restecg'] == 1, 'restecg'] = 'ST-T wave abnormality'
data.loc[data['restecg'] == 2, 'restecg'] = 'left ventricular hypertrophy'data.loc[data['exang'] == 1, 'exang'] = 'true'
data.loc[data['exang'] == 0, 'exang'] = 'false'data.loc[data['slope'] == 1, 'slope'] = 'up'
data.loc[data['slope'] == 2, 'slope'] = 'flat'
data.loc[data['slope'] == 3, 'slope'] = 'down'data.loc[data['thal'] == 1, 'thal'] = 'normal'
data.loc[data['thal'] == 2, 'thal'] = 'fixed defect'
data.loc[data['thal'] == 3, 'thal'] = 'reversable defect'

检查下数据情况:

data.describe(include=[np.object])

Output:

可以看到特征thal有4个值,而我们在转换时只转换了3个。实际上thal存在2个缺失值用0补齐的。为了防止数据类型错误,这里做一下类型转换。

data['thal'] = data['thal'].astype('object')

再看下数据:

data.head()

Output:

模型的训练肯定需要数值型特征。这里对特征进行Onehot编码。

data = pd.get_dummies(data, drop_first=True)
data.head()

Output:

(由于我还不知道在用markdown编辑时怎么显示运行结果,这里用的是截图,只能截取一部分,还有特征没有截取出来)
数据预处理部分就到此为止,接下来上模型。

4. Random Forest

对于 Random Forest 的原理这里就不介绍了,网上介绍的文章也很多。废话不多说,直接import package.

from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifierimport matplotlib.pyplot as plt

将数据分成 train_data 和 test_data 2个集合,二者比例为8:2。

train_x, test_x, train_y, test_y = train_test_split(data.drop(columns='target'),data['target'],test_size=0.2,random_state=10)

简单的画个图调个参。这里 Random Forest 主要的参数有基学习器决策树的最大深度(这里依据经验选5)、基学习器个数 n_estimators。这里基学习器选用CART。

train_score = []
test_score = []for n in range(1, 100):model = RandomForestClassifier(max_depth=5,n_estimators=n,criterion='gini')model.fit(train_x, train_y)train_score.append(model.score(train_x, train_y))test_score.append(model.score(test_x, test_y))

训练完,把train和test上的accuracy随基学习器个数的变化画成图。

x_axis = [i for i in range(1, 100)]fig, ax = plt.subplots()
ax.plot(x_axis, train_score[:99])
ax.plot(x_axis, test_score[:99], c="r")
plt.xlim([0, 100])
plt.ylim([0.0, 1.0])
plt.rcParams['font.size'] = 12
plt.xlabel('n_estimators')
plt.ylabel('accuracy')
plt.grid(True)

Output:

可以看到大概是n_estimators=14的时候效果最好,train和test上的accuracy分别是0.9463,0.8361。看上去没有那么差。

5. 模型评估

训练完模型,用ROC曲线来评估下模型的效果。ROC曲线事宜FPR和TPR分别为横纵轴作出的曲线,其和坐标轴围成的面积越大,说明模型效果越好。具体评判标准见下文。说一下几个概念:

  • TPR: 真正例率,表示所有真正为正例的样本被正确预测出来的比例,等同于Recall
  • FNR: 假负例率,FNR = 1 - TPR
  • FPR: 假正例率,表示所有负例中被预测为正例的比例。
  • TNR: 真负例率,TNR = 1 - FPR

好吧,我也快晕了。
接下来计算一下正例和负例的recall

from sklearn.metrics import confusion_matrix
from sklearn.metrics import auc, roc_curve# 混淆矩阵
confusion_m = confusion_matrix(test_y, pred_y)
print confusion_m

Output:

[[29  6][ 4 22]]
total = confusion_m.sum()
tpr = float(confusion_m[0][0]) / (confusion_m[0][0] + confusion_m[1][0])
tnr = float(confusion_m[1][1]) / (confusion_m[1][1] + confusion_m[0][1])
print tpr, tnr

Output:

0.878787878788 0.785714285714

Just so so!!

画ROC曲线图:

pred_y = model.predict(test_x)  # 预测结果
pred_prob_y = model.predict_proba(test_x)[:, 1]  # 为正例的概率
fpr_list, tpr_list, throsholds = roc_curve(test_y, pred_prob_y)# 画图
fig, ax = plt.subplots()
ax.plot(fpr_list, tpr_list)
ax.plot([0, 1], [0, 1], transform=ax.transAxes, ls="--", c="r")
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.0])
plt.rcParams['font.size'] = 12
plt.title('roc curve')
plt.xlabel('fpr')
plt.ylabel('tpr')
plt.grid(True)

Output:

前文说了,ROC曲线和坐标轴围成的面积越大,说明模型效果越好。这个面积就叫 AUC .根据AUC的值,可参考下面的规则评估模型:

  • 0.90 - 1.00 = excellent
  • 0.80 - 0.90 = good
  • 0.70 - 0.80 = fair
  • 0.60 - 0.70 = poor
  • 0.50 - 0.60 = fail

看看我们训练模型的AUC

auc(fpr_list, tpr_list)

Output:

0.9032967032967033

OK, working well!

6. Feature Importance Analysis

训练完模型,我们希望能从模型里得到点什么, 比如说哪些特征对模型结果贡献率比较大,是不是意味着这些影响因素在实际心脏病诊断中也是很重要对参考,或者说还能发现一些现有医学没有发现的发现。所有接下来我们做的是一件很有意思的事。

6.1 决策树可视化

如果我没记错的话, 根据决策树的原理,越先分裂的特征越重要。那么下面对决策树进行可视化,看看它到底做了什么。

from sklearn.tree import export_graphviz# 输出 feature_name
estimator = model.estimators_[1]
features = [i for i in train_x.columns]# 0 —> no disease,1 —> disease
train_y_str = train_y.astype('str')
train_y_str[train_y_str == '0'] = 'no disease'
train_y_str[train_y_str == '1'] = 'disease'
train_y_str = train_y_str.values

sklearn 真是个好东西,你能想到对功能他都有。下面用 sklearn 的 export_graphviz 对决策树进行可视化。

export_graphviz(estimator, out_file='tree.dot', feature_names = features,class_names = train_y_str,rounded = True, proportion = True, label='root',precision = 2, filled = True)

生成对这个 tree.dot 文件还不能直接看,网上查了一下,把它输出来看看。

import pydotplus
from IPython.display import Image
img = pydotplus.graph_from_dot_file('tree.dot')
#img.write_pdf('tree.pdf') #输出成PDF
Image(img.create_png())

Output:

实际上这张图就解释来决策树的生成过程。一般我们认为最先分裂的特征越重要,但是从这张图我们并不能很直观的看出特征的重要性。

6.2 Permutation importance

我们换一个工具—Permutation importance. 其原理是依次打乱test_data中其中一个特征数值的顺序,其实就是做shuffle,然后观察模型的效果,下降的多的说明这个特征对模型比较重要。

import eli5
from eli5.sklearn import PermutationImportanceperm = PermutationImportance(model, random_state=20).fit(test_x, test_y)
eli5.show_weights(perm, feature_names=test_x.columns.tolist())

Output:

一目了然,一切尽在不言中。还是说俩句吧,绿色越深表示正相关越强,红色越深表示负相关越强。
实际上我发现改变 PermutationImportance 的参数 random_state 的值结果变化挺大的,不过还是有几个特征位次变化不大,结果还是具有参考意义。

6.3 Partial Dependence Plots

我们试试另一个工具—Partial Dependence Plots. 其原理和 Permutation importance 有点类似,当它判断一个特征对模型的影响时,对于所有样本,将该特征依次取该特征的所有取值,观察模型结果的变化。先画图,再根据图解释一下。

from pdpbox import pdp, info_plotstotal_features = train_x.columns.values.tolist()
feature_name = 'oldpeak'
pdp_dist = pdp.pdp_isolate(model=model, dataset=test_x, model_features=total_features, feature=feature_name)pdp.pdp_plot(pdp_dist, feature_name)
plt.show()

Output:

上图的纵坐标是模型相对于base model 的变化,横坐标是该特征的所有取值,实线表示相对于base model 的变化的平均值,蓝色阴影表示置信度。oldpeak表示运动相对于休息引起的ST段压低,可以看到其取值越大,患心脏病的可能性越低。不知道这个结果可不可信,我觉得需要医学知识作支撑。

又试了几个特征:

Sex:

上图说明男性比女性患心脏病的概率要低些,网上查了一下,还真是这样。

Age:

上图表示60岁以上老人心脏病高发,这个和现有理论相符。

接下来看一下 2D Partial Dependence Plots.

inter = pdp.pdp_interact(model=model, dataset=test_x, model_features=total_features, features=['oldpeak', 'age'])pdp.pdp_interact_plot(pdp_interact_out=inter, feature_names=['oldpeak', 'age'], plot_type='contour')
plt.show()

Output:

这个图一开始没看到,后来仔细看了Partial Dependence Plots 的说明文档才搞明白。图中颜色从浅到深表示患心脏病概率降低,以最深的那个紫色为例,oldpeak > 3.0 && 45 < age < 65 时,患病概率最低,图中黄色部分表示,oldpeak < 0.25 && ( age < 45 || age > 65 ) 时,患病概率最高。

7. 后记

实际上本项目的数据是非常小的,其结果的可靠性也是值得怀疑的。但是通过这个项目,去经历机器学习项目的完整过程,却能学到很多东西。重要的是过程,更重要的是举一反三。该项目还引入了2个很有趣的Feature Importance Analysis的方法,对于我来说是新知识,也算是学到了。

这一篇到这里结束了,期待下一篇。

kaggle实战——What Causes Heart Disease?相关推荐

  1. Kaggle实战:点击率预估

    版权声明:本文出自程世东的知乎,原创文章,转载请注明出处:Kaggle实战--点击率预估. 请安装TensorFlow1.0,Python3.5 项目地址: chengstone/kaggle_cri ...

  2. kaggle实战—泰坦尼克(五、模型搭建-模型评估)

    kaggle实战-泰坦尼克(一.数据分析) kaggle实战-泰坦尼克(二.数据清洗及特征处理) kaggle实战-泰坦尼克(三.数据重构) kaggle实战-泰坦尼克(四.数据可视化) kaggle ...

  3. kaggle实战—泰坦尼克(三、数据重构)

    kaggle实战-泰坦尼克(一.数据分析) kaggle实战-泰坦尼克(二.数据清洗及特征处理) kaggle实战-泰坦尼克(三.数据重构) kaggle实战-泰坦尼克(四.数据可视化) kaggle ...

  4. kaggle实战—泰坦尼克(二、数据清洗及特征处理)

    kaggle实战-泰坦尼克(一.数据分析) kaggle实战-泰坦尼克(二.数据清洗及特征处理) kaggle实战-泰坦尼克(三.数据重构) kaggle实战-泰坦尼克(四.数据可视化) kaggle ...

  5. kaggle实战笔记_1.数据处理

    kaggle实战笔记_1.数据处理 数据处理的重要性比模型更重要 如果正负样本是1:100的话,直接拿去做建模,问题是非常大的,如果其评判标准为accuracy的话,如果把任何一个样本都判定为负样本的 ...

  6. kaggle实战—泰坦尼克(四、数据可视化)

    kaggle实战-泰坦尼克(一.数据分析) kaggle实战-泰坦尼克(二.数据清洗及特征处理) kaggle实战-泰坦尼克(三.数据重构) kaggle实战-泰坦尼克(四.数据可视化) kaggle ...

  7. Kaggle实战——点击率预估

    <深度学习私房菜:跟着案例学Tensorflow>作者 版权声明:本文出自程世东的知乎,原创文章,转载请注明出处:Kaggle实战--点击率预估. 请安装TensorFlow1.0,Pyt ...

  8. Kaggle实战入门:泰坦尼克号生还预测(基础版)

    Kaggle实战入门:泰坦尼克号生还预测 1. 加载数据 2. 特征工程 3. 模型训练 4. 模型部署 泰坦尼克号(Titanic),又称铁达尼号,是当时世界上体积最庞大.内部设施最豪华的客运轮船, ...

  9. Kaggle实战入门:泰坦尼克号生还预测(进阶版)

    Kaggle实战入门:泰坦尼克号生还预测 1. 加载数据 2. 特征工程 3. 模型训练 4. 模型部署 Kaggle实战入门:泰坦尼克号生还预测(基础版)对机器学习的全流程进行了总体介绍.本文继续以 ...

最新文章

  1. StackOverflow 上面最流行的 7 个 Java 问题!
  2. how to find the tomcat version info on linux
  3. ACL 2018论文解读 | 基于排序思想的弱监督关系抽取选种与降噪算法
  4. C言语实现midpoint euler中点欧拉法解常微分方程(附完整源码)
  5. .cpp 编译成.a或是 .so
  6. Linux 从头学 01:CPU 是如何执行一条指令的?
  7. 7张图讲透Java垃圾回收算法!学妹直呼666!!!
  8. Linux(四):虚拟机Ubuntu 卸载
  9. 少年,这有套《街霸2》AI速成心法,想传授于你……
  10. python配置文件注释_python操作配置文件yaml
  11. dom影像图形成数字地形图_数字地形图等高线怎么生成(海地软件地形图数字化高层数据层怎么选择,在哪里...)...
  12. Ajax技术体系的组成部分
  13. 自动计数报警器c语言程序,自动计数报警器.ppt
  14. 如何设置电脑桌面动态壁纸
  15. 从酷狗的网络红歌说起
  16. 如何实现扫描二维码自动跳转到网页
  17. 解决Invalid HTTP_HOST header: 'xxx.xx.xxx.xxx:8000'. You may need to add 'xxx.xx' to ALLOWED_HOSTS!
  18. 基于OpenCV的二维码和条形码识别
  19. vue mysql 电商,Vue电商项目
  20. 印度社交市场:谁能挑战Facebook们的霸主地位?

热门文章

  1. sql sa 账号被锁定的解决办法
  2. ESP分区和MSR分区是干嘛的?
  3. JavaWeb 简易留言系统
  4. Java Obiect类--------11
  5. 【MTU】Windows/Linux下修改MTU
  6. Android 7.1.2 默认输入法的设置流程分析与修改
  7. 什么是企业oa办公系统登录入口?oa办公系统哪家好?
  8. 光猫灯显示正常但是报651错误解决办法
  9. 移动机器人设计与实践-基础概念汇总
  10. 【SAP消息号KI344】