实现RandomForest 随机森林

基于python的sklearn机器学习 类实现

平台
python3.7 Anaconda sklearn库及配套库
# -*- coding: utf-8 -*-
import pandas as pd
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.ensemble import BaggingClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score
from sklearn.metrics import cohen_kappa_score
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix    # 生成混淆矩阵函数
import matplotlib.pyplot as plt
import matplotlib as mpl
from sklearn.preprocessing import StandardScaler
from sklearn.externals import joblib#保存模型
import itertools
class Ctrain_forest:'''调用sklearn 实现Random Forest功能:画混淆矩阵输入数据实现训练保存模型到指定位置调用模型实现预测'''def plot_confusion_matrix(self,cm, classes,normalize=False,title='Confusion matrix',cmap=plt.cm.Blues,path="maxtix"):"""画混淆矩阵This function prints and plots the confusion matrix.Normalization can be applied by setting `normalize=True`.画图函数 输入:cm 矩阵 classes 输入str类型title 名字cmap [图的颜色设置](https://matplotlib.org/examples/color/colormaps_reference.html)"""if normalize:cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]print("Normalized confusion matrix")else:print('Confusion matrix, without normalization')plt.figure(figsize=(11,8))plt.imshow(cm, interpolation='nearest', cmap=cmap)plt.title(title)plt.colorbar()tick_marks = np.arange(len(classes))plt.xticks(tick_marks, classes, rotation=45)plt.yticks(tick_marks, classes)fmt = '.2f' if normalize else 'd'thresh = cm.max() / 2.for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):plt.text(j, i, format(cm[i, j], fmt),horizontalalignment="center",color="white" if cm[i, j] > thresh else "black")# plt.gca().set_xticks(tick_marks, minor=True)# plt.gca().set_yticks(tick_marks, minor=True)# plt.gca().xaxis.set_ticks_position('none')# plt.gca().yaxis.set_ticks_position('none')#plt.grid()# plt.gcf().subplots_adjust(bottom=0.1)# plt.tight_layout()plt.ylabel('True label')plt.xlabel('Predicted label')#解决中文显示plt.rcParams['font.sans-serif']=['SimHei']plt.rcParams['axes.unicode_minus'] = False    plt.savefig(path,dpi=500)  # plt.show()def train_forest(self,x,y,path):"""Random Foeset类输入:x、y以实现训练,path是保存训练过程的路径输出:clf 模型matrix 混淆矩阵dd classifi_reportkappa kappa系数acc_1 模型精度"""X_train,data1x,y_train,data1y = train_test_split(x,y,test_size=0.9,random_state=0)#寻找最优参数depth = np.arange(1,25,4)acc_list = []for d in depth:clf =RandomForestClassifier(bootstrap=True, class_weight="balanced", criterion='gini',max_depth=d*10+1, max_features='auto', max_leaf_nodes=None,min_impurity_decrease=0.0, min_impurity_split=None,min_samples_leaf=3, min_samples_split=3,min_weight_fraction_leaf=0.0, n_estimators=140*2+1, n_jobs=-1,oob_score=False, verbose=0, warm_start=False)clf.fit(X_train, y_train)y_pred_rf = clf.predict(data1x)acc=accuracy_score(data1y, y_pred_rf)acc_list.append(acc)print(accuracy_score(data1y, y_pred_rf))  #整体精度print(cohen_kappa_score(data1y, y_pred_rf))  #Kappa系数#画图mpl.rcParams['font.sans-serif'] = ['SimHei']plt.figure(facecolor='w')plt.plot(depth, acc_list, 'ro-', lw=1)plt.xlabel('随机森林决策树数量', fontsize=15)plt.ylabel('预测精度', fontsize=15)plt.title('随机森林决策树数量和过拟合', fontsize=18)plt.grid(True)plt.savefig(path,dpi=300)#plt.show()y_pred_rf = clf.predict(data1x)print(accuracy_score(data1y, y_pred_rf))  #整体精度#dist=data1y-y_pred_rfprint(cohen_kappa_score(data1y, y_pred_rf))  #Kappa系数matrix=confusion_matrix(data1y, y_pred_rf)kappa=cohen_kappa_score(data1y, y_pred_rf)dd=classification_report(data1y, y_pred_rf)acc_1=accuracy_score(data1y, y_pred_rf)"""# 特征重要性评定rnd_clf = RandomForestClassifier(n_estimators=500, n_jobs=-1)rnd_clf.fit(x, y)for name, score in zip(x, rnd_clf.feature_importances_):print(name, score)""" return clf,matrix,dd,kappa,acc_1def save_model(self,clf,src):"""保存模型到某处clf 模型src 路径"""joblib.dump(clf, src)def get_model_predit(self,data,src):"""调用模型实现预测输入原始数据src 模型路径返回预测值"""getsavemodel=joblib.load(src)predity=getsavemodel.predict(pd.DataFrame(data))return predity

运行结果:


基于python sklearn的 RandomForest随机森林 类实现相关推荐

  1. Spark 和 Python.sklearn:使用随机森林计算 feature_importance 特征重要性

    前言 在使用GBDT.RF.Xgboost等树类模型建模时,往往可以通过feature_importance 来返回特征重要性,本文以随机森林为例介绍其原理与实现.[ 链接:机器学习的特征重要性究竟是 ...

  2. matlab 随机森林算法_(六)如何利用Python从头开始实现随机森林算法

    博客地址:https://blog.csdn.net/CoderPai/article/details/96499505 点击阅读原文,更好的阅读体验 CoderPai 是一个专注于人工智能在量化交易 ...

  3. sklearn实战之随机森林

    sklearn实战系列: (1) sklearn实战之决策树 (2) sklearn实战之随机森林 (3) sklearn实战之数据预处理与特征工程 (4) sklearn实战之降维算法PCA与SVD ...

  4. 【详细代码注释】基于CNN卷积神经网络实现随机森林算法

    随机森林算法简介: 随机森林(Random Forest)是一种灵活性很高的机器学习算法. 它的底层是利用多棵树对样本进行训练并预测的一种分类器.在机器学习的许多领域都有广泛地应用. 例如构建医学疾病 ...

  5. 《菜菜的机器学习sklearn课堂》随机森林应用泛化误差调参实例

    随机森林 随机森林 - 概述 集成算法概述 sklearn中的集成算法 随机森林分类器 RandomForestClassifier 重要参数 控制基评估器的参数 n_estimators:基评估器的 ...

  6. Python进行决策树和随机森林

    Python进行决策树和随机森林 一.决策树 第一步,导入库: 第二步,导入数据: 第三步,数据预处理: 第四步,决策树: 第五步,决策树评价: 第六步,生成决策树图. 二.随机森林 第一步,随机森林 ...

  7. 基于蜣螂算法改进的随机森林回归算法 - 附代码

    基于蜣螂算法改进的随机森林回归算法 - 附代码 文章目录 基于蜣螂算法改进的随机森林回归算法 - 附代码 1.数据集 2.RF模型 3.基于蜣螂算法优化的RF 4.测试结果 5.Matlab代码 6. ...

  8. Python 利用SVM,KNN,随机森林进行预测

    Python 利用SVM,KNN,随机森林进行预测 工具:Pycharm,Win10,Python3.6.4 上图是我们的数据文件,最后一列是附近有无超市的标签,1代表有,-1代表没有.可以发现数据维 ...

  9. python椭圆形骨料_一种基于python再生混凝土三维随机球形骨料模型的构建方法与流程...

    本发明涉及建筑技术领域,尤其涉一种基于python再生混凝土三维随机球形骨料模型的构建方法. 背景技术: 再生混凝土是指利用再生粗骨料部分或者全部代替天然骨料配置而成的混凝土,再生混凝土技术的开发和利 ...

  10. 基于 Python 的横版 2D 动作类小游戏

    基于 Python 的横版 2D 动作类小游戏 游戏代码 游戏代码 游戏整体代码(基于 pygame 模块开发) // An highlighted block import pygame impor ...

最新文章

  1. 利用GetPrivateProfileString读取配置文件(.ini)
  2. Jmeter之ForEach控制器(配合正则表达式使用)
  3. 牛客多校2 - All with Pairs(字符串哈希+next数组)
  4. 腾讯测试鸿蒙系统,爆料:荣耀 30 Pro已开始测试华为鸿蒙系统
  5. IIS7 MVC网站生成、发布
  6. 极大似然估计求解多项式分布参数
  7. 图片保存到数据库以及从数据库中Load图片
  8. 2017.3.14 软件包管理器 思考记录
  9. 如何提高WEB程序的效率
  10. Arthas结合Spring容器 线上排查Tips
  11. 首次曝光 唯一全域最高等级背后的阿里云云原生安全全景图
  12. 05、Flutter FFI 结构体
  13. 【FBI WARNING】好东西!!!
  14. 牛客网--15894--WWX的520
  15. 智商黑洞(门萨Mensa测试)3
  16. 统计名著中汉字出现频率
  17. 腾讯 Robotics X 轮腿式机器人
  18. bow键盘 android,一拖三还能秒切换 BOW航世蓝牙键盘体验
  19. 小技巧--获取当前前台显示Activity
  20. BGP专线 解决南北互联互通

热门文章

  1. pythonforin替换字符_Python:用一个字符串替换数组中的数字(Python: Replace a number in array with a string)...
  2. python输出命令_Python中的命令输出解析
  3. Linux环境下安装Hadoop(完全分布式)
  4. python aes加密对于长字符数据丢失_Python 3中AES加密和解密的字符串字节数
  5. 谷粒商城:17.商城业务 — Nginx搭建域名访问
  6. Javascript特效:左侧二维码的显示和隐藏
  7. php 图片水印删除,PHP图片水印
  8. Java直接遍历并读取zip压缩文件的内容以及错误处理
  9. SLAM基础_什么是ORB特征,怎么计算的?
  10. GIS笔记_GDAL c# VS2015 环境配置