k_fold_cv函数——bartMachine包内函数详解

  • R bartMachine包下载
  • 一、 函数包内原文
    • 1、k_fold_cv函数底层代码
    • 2、 k_fold_cv参数解释原文
    • 3、 k_fold_cv返回结果参数解释原文
  • 二、k_fold_cv函数中文解析
    • 1、k_fold_cv函数作用
    • 2、 k_fold_cv参数中文解析
    • 3、不同模型k_fold_cv函数返回参数中文解析
      • (1) 回归模型
      • (2)分类模型
  • 三、k_fold_cv应用实例
    • 1、回归模型
      • (1)回归模型k_fold_cv的使用实例
      • (2)回归模型k_fold_cv返回结果展示(结果较长的展示部分)
    • 2、分类模型
      • (1)分类模型k_fold_cv的使用实例
      • (2)分类模型k_fold_cv返回结果展示(结果较长的展示部分)
  • 特别声明

R bartMachine包下载

R 所有包下载地址 :
https://cran.r-project.org/web/packages/available_packages_by_name.html

R bartMachine包下载地址 :
https://cran.r-project.org/web/packages/bartMachine/index.html

一、 函数包内原文

1、k_fold_cv函数底层代码

注:直接运行 k_fold_cv 即可得到

function (X, y, k_folds = 5, folds_vec = NULL, verbose = FALSE, ...)
{args = list(...)args$serialize = FALSEif (class(X) != "data.frame") {stop("The training data X must be a data frame.")}if (!(class(y) %in% c("numeric", "integer", "factor"))) {stop("Your response must be either numeric, an integer or a factor with two levels.\n")}if (!is.null(folds_vec) & class(folds_vec) != "integer") stop("folds_vec must be an a vector of integers specifying the indexes of each folds.")y_levels = levels(y)if (class(y) == "numeric" || class(y) == "integer") {pred_type = "regression"}else if (class(y) == "factor" & length(y_levels) == 2) {pred_type = "classification"}n = nrow(X)Xpreprocess = pre_process_training_data(X)$datap = ncol(Xpreprocess)if (is.null(folds_vec)) {if (k_folds == Inf) {k_folds = n}if (k_folds <= 1 || k_folds > n) {stop("The number of folds must be at least 2 and less than or equal to n, use \"Inf\" for leave one out")}temp = rnorm(n)folds_vec = cut(temp, breaks = quantile(temp, seq(0, 1, length.out = k_folds + 1)), include.lowest = T, labels = F)}else {k_folds = length(unique(folds_vec))}if (pred_type == "regression") {L1_err = 0L2_err = 0yhat_cv = numeric(n)}else {phat_cv = numeric(n)yhat_cv = factor(n, levels = y_levels)confusion_matrix = matrix(0, nrow = 3, ncol = 3)rownames(confusion_matrix) = c(paste("actual", y_levels), "use errors")colnames(confusion_matrix) = c(paste("predicted", y_levels), "model errors")}Xy = data.frame(Xpreprocess, y)for (k in 1:k_folds) {cat(".")train_idx = which(folds_vec != k)test_idx = setdiff(1:n, train_idx)test_data_k = Xy[test_idx, ]training_data_k = Xy[train_idx, ]bart_machine_cv = do.call(build_bart_machine, c(list(X = training_data_k[, 1:p, drop = FALSE], y = training_data_k[, (p + 1)], run_in_sample = FALSE, verbose = verbose), args))predict_obj = bart_predict_for_test_data(bart_machine_cv, test_data_k[, 1:p, drop = FALSE], test_data_k[, (p + 1)])if (pred_type == "regression") {L1_err = L1_err + predict_obj$L1_errL2_err = L2_err + predict_obj$L2_erryhat_cv[test_idx] = predict_obj$y_hat}else {phat_cv[test_idx] = predict_obj$p_hatyhat_cv[test_idx] = predict_obj$y_hattab = table(factor(test_data_k$y, levels = y_levels), factor(predict_obj$y_hat, levels = y_levels))confusion_matrix[1:2, 1:2] = confusion_matrix[1:2, 1:2] + tab}}cat("\n")if (pred_type == "regression") {list(y_hat = yhat_cv, L1_err = L1_err, L2_err = L2_err, rmse = sqrt(L2_err/n), PseudoRsq = 1 - L2_err/sum((y - mean(y))^2), folds = folds_vec)}else {confusion_matrix[3, 1] = round(confusion_matrix[2, 1]/(confusion_matrix[1, 1] + confusion_matrix[2, 1]), 3)confusion_matrix[3, 2] = round(confusion_matrix[1, 2]/(confusion_matrix[1, 2] + confusion_matrix[2, 2]), 3)confusion_matrix[1, 3] = round(confusion_matrix[1, 2]/(confusion_matrix[1, 1] + confusion_matrix[1, 2]), 3)confusion_matrix[2, 3] = round(confusion_matrix[2, 1]/(confusion_matrix[2, 1] + confusion_matrix[2, 2]), 3)confusion_matrix[3, 3] = round((confusion_matrix[1, 2] + confusion_matrix[2, 1])/sum(confusion_matrix[1:2, 1:2]), 3)list(y_hat = yhat_cv, phat = phat_cv, confusion_matrix = confusion_matrix, misclassification_error = confusion_matrix[3, 3], folds = folds_vec)}
}

2、 k_fold_cv参数解释原文

函数:k_fold_cv(X, y, k_folds = 5, folds_vec = NULL, verbose = FALSE, …)

参数 内容
X Data frame of predictors. Factors are automatically converted to dummies interally.
y Vector of response variable. If y is numeric or integer, a BART model for regression is built. If y is a factor with two levels, a BART model for classification is built.
k_folds Number of folds to cross-validate over. This argument is ignored if folds_vec is non-null.
folds_vec An integer vector of indices specifying which fold each observation belongs to.
verbose Prints information about progress of the algorithm to the screen.
Additional arguments to be passed to build_bart_machine.

3、 k_fold_cv返回结果参数解释原文

对于回归模型,将返回包含以下信息的列表:

参数 内容
y_hat Predictions for the observations computed on the fold for which the observation was omitted from the training set.
L1_err Aggregate L1 error across the folds.
L2_err Aggregate L1 error across the folds.
rmse Aggregate RMSE across the folds.
PseudoRsq Calculated as 1 - SSE / SST where SSE is the sum of square errors in the training data and SST is the sample variance of the response times n-1.
folds Vector of indices specifying which fold each observation belonged to.

对于分类模型,将返回包含以下信息的列表:

参数 内容
y_hat Class predictions for the observations computed on the fold for which the observation was omitted from the training set.
p_hat Probability estimates for the observations computed on the fold for which the observation was omitted from the training set.
confusion_matrix Aggregate confusion matrix across the folds.
misclassification_error Total misclassification error across the folds.
folds Vector of indices specifying which fold each observation belonged to.

二、k_fold_cv函数中文解析

1、k_fold_cv函数作用

简单来说,就是利用k折交叉验证分别建模,计算出不同的返回值例如(回归:L2_err 、rmse、PseudoRsq值;分类:confusion_matrix、misclassification_error值)。根据返回值,判断当前所指定的建模模型是否存在过拟合现象。

2、 k_fold_cv参数中文解析

参数 内容
X 自变量数据集,可以是连续型自变量,也可以是离散型自变量。
y 因变量或响应变量数据集。如果数据类型是numeric(数字类型)和integer(整型),则建立回归模型;如果数据类型(factor) 因子,且只能是两分类,则会建立分类模型。
k_folds 交叉验证的折数。也就是几折交叉验证。
folds_vec 对训练集设定每一个样本属于哪一个分类(几折就几个分类),如果不设定就会随机设定。
verbose 打印相关建模进度及其相关信息到屏幕上。
设定建立Bart模型的其他参数,例如先验参数(alpha、beta、k、q),建模过程参数(num_burn_in、num_iterations_after_burn_in)等。参数具体含义可参见R bartMachine包内bartMachine函数参数详解

3、不同模型k_fold_cv函数返回参数中文解析

(1) 回归模型

对于回归模型,将返回包含以下信息的列表:

参数 内容
y_hat 对因变量的预测值的集合。对n折来说就是每一次(n-1/n)的训练集建模对另外(1/n)的预测集的预测值,然后汇总得到的预测集的集合。在代码第70、83行代码体现。
L1_err L1范数,也就是差值绝对值之和,计算公式为(∑k=0n∣yi−f^(xi)∣\displaystyle\sum_{k=0}^n\vert y_i-\hat{f}(x_i) \rvertk=0∑n​∣yi​−f^​(xi​)∣)。对n折来说就是每(1/n)次建模的L1范数的加和。在代码第69、83行代码体现。
L2_err L2范数,也就是差值平方之和,计算公式为(∑k=0n(yi−f^(xi))2\displaystyle\sum_{k=0}^n (y_i-\hat{f}(x_i) )^2k=0∑n​(yi​−f^​(xi​))2)。对n折来说就是每(1/n)次建模的L2范数的加和。在代码第68、83行代码体现。
rmse 均方根误差值,计算公式为(∑k=0n(yi−f^(xi))2/(n−1)\sqrt{\displaystyle\sum_{k=0}^n (y_i-\hat{f}(x_i) )^2/(n-1)}k=0∑n​(yi​−f^​(xi​))2/(n−1)​)。 在k_fold_cv中的算法为{ rmse = sqrt(L2_err/n) },在代码第84行代码体现。
PseudoRsq 伪R^2值,计算为1-SSE/SST,其中SSE是因变量的平方误差之和,SST是响应变量n-1的样本方差。在k_fold_cv中的算法为{ PseudoRsq = 1 - L2_err/sum((y - mean(y))^2) },在代码第84、85行代码体现。
folds 返回每一个训练集样本所属的组的编号。

(2)分类模型

对于分类模型,将返回包含以下信息的列表:

参数 内容
y_hat 对因变量的分类结果预测值的集合。对n折来说就是每一次(n-1/n)的训练集建模对另外(1/n)的预测集的分类结果预测值,然后汇总得到的预测集的集合(是在得到了模型估计的概率根据判定阈值得到的分类结果预测值)。在代码第74、99行代码体现。
p_hat 对因变量的模型估计的概率的集合。对n折来说就是每一次(n-1/n)的训练集建模对另外(1/n)的预测集的概率值估计,然后汇总得到的预测集的估计概率值集合。即在得到分类阈值后可以得到y_hat的分类结果。在代码第73、99行代码体现。
confusion_matrix 对n个模型预测集预测的混淆矩阵的加和。在代码第75、76、77、99行代码体现。
misclassification_error 是用上面得到的混淆矩阵,计算得到的预测误差值。在代码第96、97、100行代码体现。
folds 返回每一个训练集样本所属的组的编号。

三、k_fold_cv应用实例

1、回归模型

(1)回归模型k_fold_cv的使用实例

options(java.parameters = "-Xmx10g")library(ggplot2)
library(bartMachine)
library(reshape2)
library(knitr)
library(ggplot2)
library(GGally)
library(scales)##读取数据
data<-read.csv(file="C:/Users/LHW/Desktop/boston_housing_data.csv",header=T,sep=",")
head(data)
n=dim(data)
n
data1<-data[0,] #MEDV值不为NA的样本
data2<-data[0,] #MEDV值为NA的样本,后面用模型来填补缺失值#循环分离MEDV值为NA的样本到data2,其他的样本到data1
i=1
for (i in 1:n[1]) {if(is.na(data[i,14])) {data2 <- rbind(data.frame(data2),data.frame(data[i,]))}else{data1 <- rbind(data.frame(data1),data.frame(data[i,]))}print(i)
}#随机种子
set.seed(100)
#按照90%和10%比例划分训练集和测试集
index2=sample(x=2,size=nrow(data1),replace=TRUE,prob=c(0.9,0.1))#训练集
train2=data1[index2==1,]
head(train2)
x=train2[,-c(14)]
y=train2[,14] #预测集
data2=data1[index2==2,]
x.test_data=data2[,-c(14)]
head(data2)
xp=x.test_data
yp=data2[,14]   #建立Bart模型
res = bartMachine(x,y,num_trees = 50,k=2,nu=3,q=0.9,num_burn_in = 50,num_iterations_after_burn_in = 1000,flush_indices_to_save_RAM = FALSE,seed = 1313, verbose = T)
print(res)#使用 k 倍交叉验证,评估过度拟合的水平。
kk<-k_fold_cv(x, y, k_folds = 10)
kk

(2)回归模型k_fold_cv返回结果展示(结果较长的展示部分)


从结果中我们可以看出L2_err、PseudoRsp的值都没有很小,即没有与原模型的值差距很大,说明所建立的模型过拟合的现象不明显,建模泛化效果比较好。

2、分类模型

(1)分类模型k_fold_cv的使用实例

options(java.parameters = "-Xmx10g")library(ggplot2)
library(bartMachine)
library(reshape2)
library(knitr)
library(ggplot2)
library(GGally)
library(scales)##读取数据
data<-read.csv(file="C:/Users/LHW/Desktop/tt.csv",header=T,sep=",")
head(data)
n=dim(data)
n#随机种子
set.seed(1000)
#按照80%和20%比例划分训练集和测试集
index2=sample(x=2,size=nrow(data),replace=TRUE,prob=c(0.8,0.2))#训练集
train2=data[index2==1,]
head(train2)
x=train2[,-c(1)]
y=train2[,1]
y = factor(y)#预测集
data2=data[index2==2,]
x.test_data=data2[,-c(1)]
head(data2)
xp=x.test_data
yp=data2[,1]
yp = factor(yp)#建立Bart模型
res = bartMachine(x,y,prob_rule_class = 0.5)
print(res)#使用 k 倍交叉验证,评估过度拟合的水平。
kk<-k_fold_cv(x, y, k_folds = 10)
kk

(2)分类模型k_fold_cv返回结果展示(结果较长的展示部分)


同样,从结果中我们也可以看出misclassification的值都没有很大,即没有与原模型的值差距很大,说明所建立的模型过拟合的现象不明显,建模泛化效果比较好。

特别声明

作者也是初学者,水平有限,文章中会存在一定的缺点和谬误,恳请读者多多批评、指正和交流!

k_fold_cv函数——bartMachine包内函数详解相关推荐

  1. python哪里下载import包-【Python实战】模块和包导入详解(import)

    1.模块(module) 1.1 模块定义 通常模块为一个.py文件,其他可作为module的文件类型还有".pyo".".pyc".".pyd&qu ...

  2. 四.卡尔曼滤波器(EKF)开发实践之四: ROS系统位姿估计包robot_pose_ekf详解

    本系列文章主要介绍如何在工程实践中使用卡尔曼滤波器,分七个小节介绍: 一.卡尔曼滤波器开发实践之一: 五大公式 二.卡尔曼滤波器开发实践之二:  一个简单的位置估计卡尔曼滤波器 三.卡尔曼滤波器(EK ...

  3. python解包什么意思_python解包用法详解

    对于一堆资料,我们可以把它分给不同的人使用,这个分散的过程,我们可以看成是解包方法是实现.当然实际python解包的使用会相对复杂一点,我们会对常见的列表.函数等进行操作.下面我们就Python解包的 ...

  4. Oracle之DBMS_SQL包用法详解

    Oracle之DBMS_SQL包用法详解 原文  http://zhangzhongjie.iteye.com/blog/1948093 通常运用 DBMS_SQL 包一般分为 如下 几步: 1. o ...

  5. 在python中使用关键字define定义函数_python自定义函数def的应用详解

    这里是三岁,来和大家唠唠自定义函数,这一个神奇的东西,带大家白话玩转自定义函数 自定义函数,编程里面的精髓! def 自定义函数的必要函数:def 使用方法:def 函数名(参数1,参数2,参数-): ...

  6. 函数assert()详解

    函数assert()详解: 断言assert是一个宏,该宏在<assert>中,,当使用assert时候,给他个参数,即一个判读为真的表达式.预处理器产生测试该断言的代码,如果断言不为真, ...

  7. php。defined,PHP defined()函数的使用图文详解

    PHP defined()函数的使用图文详解 PHP defined() 函数 例子 定义和用法 defined() 函数检查某常量是否存在. 若常量存在,则返回 true,否则返回 false. 语 ...

  8. python中tile的用法_python3中numpy函数tile的用法详解

    tile函数位于python模块 numpy.lib.shape_base中,他的功能是重复某个数组.比如tile(A,n),功能是将数组A重复n次,构成一个新的数组,我们还是使用具体的例子来说明问题 ...

  9. Delphi Format函数功能及用法详解

    DELPHI中Format函数功能及用法详解 DELPHI中Format函数功能及用法详解function Format(const Format: string; const Args: array ...

  10. python中的json函数_python中装饰器、内置函数、json的详解

    装饰器 装饰器本质上是一个Python函数,它可以让其他函数在不需要做任何代码变动的前提下增加额外功能,装饰器的返回值也是一个函数对象. 先看简单例子: def run(): time.sleep(1 ...

最新文章

  1. windows下安装awstats来分析apache的访问日志
  2. 使用libjpeg进行图片压缩(哈夫曼算法,无损压缩)
  3. CentOs 6.0 下安装cacti的syslog插件
  4. 微信电脑网页二维码扫描登录简单实现
  5. UVA11549计算器谜题
  6. Leet Code OJ 136. Single Number [Difficulty: Medium]
  7. DeFi 中的 De 是什么意思?这对区块链行业意味着什么?
  8. ijkplayer播放器h265解码能力调研
  9. 「硬见小百科」14个常用的电路基础公式换算
  10. SQL server 期末复习
  11. Python中的切片(Slice)操作详解
  12. vue访问子组件实例或子元素
  13. PYTHON——自然间断点分级法
  14. 【Flink基础】-- 高效学习 flink kubernetes operator 的一些建议
  15. java的格式控制符_C语言的格式控制符
  16. WPS表格 下拉列表 两级下拉列表联动 多级下拉列表联动
  17. 最全电商分类信息(02)
  18. Linux SCP跨服务器传输文件
  19. Safari 无法播放视频
  20. Environment variable ORACLE_UNQNAME not defined. Please set ORACLE_UNQNAME to da tabase unique name.

热门文章

  1. php展厅控制系统,展厅中控系统
  2. 题解 P1894 【[USACO4.2]完美的牛栏The Perfect Stall】
  3. 控制面板设置java_win10系统打开java控制面板的具体技巧
  4. linux 内核---------董昊 ( Robin Dong ) and OenHan
  5. Git简介之部分易混淆命令的简单介绍
  6. python中hist的用法总结
  7. poj 1900 Game
  8. 详解wait/waitpid的参数:status
  9. 刘宇凡:罗永浩的锤子情怀只能拿去喂狗
  10. U8C报表模板已设置,任务已分配仍无法查看报表数据