作者:顾全,浙江大学软件工程硕士,现任桃树科技算法工程师

地址:

https://github.com/ZJUguquan/OnlineRandomForest

参与:Cynthia

翻译:本文为天善智能编译,未经容许,禁止转载


介绍

Online Random Forest(ORF) 是由Amir Saffari等人最先提出。之后,Arthur Lui使用Python实现了算法。非常感谢他们的工作。在论文内容和Lui的算法的基础上,我通过R和R6包重构了代码。此外,ORF在此包中的实现,与randomForest结合,使它同时支持增量学习和批量学习,例如:在ORF的基础上构建树,然后通过ORF进行更新。通过这种方法,它将比以前快得多。

安装

if(!require(devtools)) install.packages("devtools")
devtools::install_github("ZJUguquan/OnlineRandomForest")

快速启动

  • 最小举例:增量学习

library(OnlineRandomForest)
param <- list('minSamples'= 1, 'minGain'= 0.1, 'numClasses'= 3, 'x.rng'= dataRange(iris[1:4]))
orf <- ORF$new(param, numTrees = 10)
for (i in 1:150) orf$update(iris[i, 1:4], as.integer(iris[i, 5]))
cat("Mean depth of trees in the forest is:", orf$meanTreeDepth(), "\n")
orf$forest[[2]]$draw()

## Mean depth of trees in the forest is: 3

## Root X4 < 1.21
## |----L: X3 < 2.38
##      |----L: Leaf 1
##      |----R: Leaf 2
## |----R: X4 < 2.15
##      |----L: X1 < 4.92
##           |----L: Leaf 3
##           |----R: Leaf 3
##      |----R: Leaf 3

  • 分类举例

library(OnlineRandomForest)

# data preparation
dat <- iris; dat[,5] <- as.integer(dat[,5])
x.rng <- dataRange(dat[1:4])
param <- list('minSamples'= 2, 'minGain'= 0.2, 'numClasses'= 3, 'x.rng'= x.rng)
ind.gen <- sample(1:150,30) # for generate ORF
ind.updt <- sample(setdiff(1:150, ind.gen), 100) # for uodate ORF
ind.test <- setdiff(setdiff(1:150, ind.gen), ind.updt) # for test

# construct ORF and update
rf <- randomForest::randomForest(factor(Species) ~ ., data = dat[ind.gen, ], maxnodes = 2, ntree = 100)
orf <- ORF$new(param)
orf$generateForest(rf, df.train = dat[ind.gen, ], y.col = "Species")
cat("Mean size of trees in the forest is:", orf$meanTreeSize(), "\n")

## Mean size of trees in the forest is: 3

for (i in ind.updt) {
 orf$update(dat[i, 1:4], dat[i, 5])
}
cat("After update, mean size of trees in the forest is:", orf$meanTreeSize(), "\n")

## After update, mean size of trees in the forest is: 11.9

# predict
orf$confusionMatrix(dat[ind.test, 1:4], dat[ind.test, 5], pretty = T)

##
##  
##    Cell Contents
## |-------------------------|
## |                       N |
## |           N / Row Total |
## |           N / Col Total |
## |-------------------------|
##
##  
## Total Observations in Table:  20
##
##  
##              | actual
##   prediction |         1 |         2 |         3 | Row Total |
## -------------|-----------|-----------|-----------|-----------|
##            1 |         4 |         0 |         0 |         4 |
##              |     1.000 |     0.000 |     0.000 |     0.200 |
##              |     1.000 |     0.000 |     0.000 |           |
## -------------|-----------|-----------|-----------|-----------|
##            2 |         0 |         9 |         2 |        11 |
##              |     0.000 |     0.818 |     0.182 |     0.550 |
##              |     0.000 |     1.000 |     0.286 |           |
## -------------|-----------|-----------|-----------|-----------|
##            3 |         0 |         0 |         5 |         5 |
##              |     0.000 |     0.000 |     1.000 |     0.250 |
##              |     0.000 |     0.000 |     0.714 |           |
## -------------|-----------|-----------|-----------|-----------|
## Column Total |         4 |         9 |         7 |        20 |
##              |     0.200 |     0.450 |     0.350 |           |
## -------------|-----------|-----------|-----------|-----------|
##
##

# compare
table(predict(rf, newdata = dat[ind.test,]) == dat[ind.test, 5])

## FALSE  TRUE
##     9    11

table(orf$predicts(X = dat[ind.test,]) == dat[ind.test, 5])

## FALSE  TRUE
##     2    18

  • 回归举例

# data preparation
if(!require(ggplot2)) install.packages("ggplot2")
data("diamonds", package = "ggplot2")
dat <- as.data.frame(diamonds[sample(1:53000,1000), c(1:6,8:10,7)])
for (col in c("cut","color","clarity")) dat[[col]] <- as.integer(dat[[col]]) # Don't forget this
x.rng <- dataRange(dat[1:9])
param <- list('minSamples'= 10, 'minGain'= 1, 'maxDepth' = 10, 'x.rng'= x.rng)
ind.gen <- sample(1:1000, 800)
ind.updt <- sample(setdiff(1:1000, ind.gen), 100)
ind.test <- setdiff(setdiff(1:1000, ind.gen), ind.updt)

# construct ORF
rf <- randomForest::randomForest(price ~ ., data = dat[ind.gen, ], maxnodes = 20, ntree = 100)
orf <- ORF$new(param)
orf$generateForest(rf, df.train = dat[ind.gen, ], y.col = "price")
orf$meanTreeSize()

## [1] 39

# and update
for (i in ind.updt) {
 orf$update(dat[i, 1:9], dat[i, 10])

}
orf$meanTreeSize()

## [1] 105.7

# predict and compare
if(!require(Metrics)) install.packages("Metrics")
preds.rf <- predict(rf, newdata = dat[ind.test,])
Metrics::rmse(preds.rf, dat$price[ind.test])

## [1] 988.8055

preds <- orf$predicts(dat[ind.test, 1:9])
Metrics::rmse(preds, dat$price[ind.test]) # make progress

## [1] 869.9613

其他用途

  • 在 Tree 类中

ta <- Tree$new("abc", NULL, NULL)
tb <- Tree$new(1, Tree$new(36), Tree$new(3))
tc <- Tree$new(89, tb, ta)
tc$draw()

# update tc
tc$right$updateChildren(Tree$new("666"), Tree$new(999))
tc$right$right$updateChildren(Tree$new("666"), Tree$new(999))
tc$draw()

  • 通过random Forest包配置一个Online random Tree,并升级

# data preparation
library(randomForest)
dat1 <- iris; dat1[,5] <- as.integer(dat1[,5])
rf <- randomForest(factor(Species) ~ ., data = dat1, maxnodes = 3)
treemat1 <- getTree(rf, 1, labelVar=F)
treemat1 <- cbind(treemat1, node.ind = 1:nrow(treemat1))
x.rng1 <- dataRange(dat1[1:4])
param1 <- list('minSamples'= 5, 'minGain'= 0.1, 'numClasses'= 3, 'x.rng'= x.rng1)
ind.gen <- sample(1:150,50) # for generate ORT
ind.updt <- setdiff(1:150, ind.gen) # for update ORT

# origin
ort2 <- ORT$new(param1)
ort2$draw()

## Root 1
##  Leaf 1

# generate a tree

ort2$generateTree(treemat1, df.node = dat1[ind.gen,])
ort2$draw()

## Root X3 < 2.45
## |----L: Leaf 1
## |----R: X3 < 4.75
##      |----L: Leaf 2
##      |----R: Leaf 3

# update this tree
for(i in ind.updt) {
 ort2$update(dat1[i,1:4], dat1[i,5])
}
ort2$draw()

## Root X3 < 2.45
## |----L: Leaf 1
## |----R: X3 < 4.75
##      |----L: Leaf 2
##      |----R: X4 < 2.19
##           |----L: X2 < 3.68
##                |----L: X1 < 7.12
##                     |----L: X3 < 4.06
##                          |----L: Leaf 1
##                          |----R: Leaf 3
##                     |----R: Leaf 3
##                |----R: Leaf 1
##           |----R: Leaf 3

大家都在看

2017年R语言发展报告(国内)

R语言中文社区历史文章整理(作者篇)

R语言中文社区历史文章整理(类型篇)

公众号后台回复关键字即可学习

回复 R                  R语言快速入门及数据挖掘 
回复 Kaggle案例  Kaggle十大案例精讲(连载中)
回复 文本挖掘      手把手教你做文本挖掘
回复 可视化          R语言可视化在商务场景中的应用 
回复 大数据         大数据系列免费视频教程 
回复 量化投资      张丹教你如何用R语言量化投资 
回复 用户画像      京东大数据,揭秘用户画像
回复 数据挖掘     常用数据挖掘算法原理解释与应用
回复 机器学习     人工智能系列之机器学习与实践
回复 爬虫            R语言爬虫实战案例分享

【译】R包介绍:Online Random Forest相关推荐

  1. 跟着Nature学绘图!基于ggplot2的生存曲线绘制R包

    与传统的语言模型不同,深度学习的进步导致了一种新型的预测(自回归)深度语言模型(DLM).使用自我监督的下一个单词预测任务,这些模型在给定的上下文中生成适当的语言响应.在目前的研究中,九名参与者收听了 ...

  2. 台湾大学林轩田机器学习技法课程学习笔记10 -- Random Forest

    红色石头的个人网站:redstonewill.com 上节课我们主要介绍了Decision Tree模型.Decision Tree算法的核心是通过递归的方式,将数据集不断进行切割,得到子分支,最终形 ...

  3. 林轩田机器学习 | 机器学习技法课程笔记10 --- Random Forest

    上节课我们主要介绍了Decision Tree模型.Decision Tree算法的核心是通过递归的方式,将数据集不断进行切割,得到子分支,最终形成树的结构.C&RT算法是决策树比较简单和常用 ...

  4. 探索R包plyr:脱离R中显式循环

    所有R用户接受的第一个"莫名其妙"的原则就是: 不要在R中写显式循环... 不要写显式循环... 不要写循环... 不循环... 不... 我第一次接受到这个"黄金律&q ...

  5. R语言与数据分析(7)-R包的使用

    R包介绍 R包是函数.数据与编译代码以一种定义完善的格式组成的集合, 计算机上存储包的目录称为库library,==函数.libPaths()==可以显示库所在的位置 library() 可以显示库中 ...

  6. R语言使用caret包构建随机森林模型(random forest)构建回归模型、通过method参数指定算法名称、通过ntree参数指定随机森林中树的个数

    R语言使用caret包构建随机森林模型(random forest)构建回归模型.通过method参数指定算法名称.通过ntree参数指定随机森林中树的个数 目录

  7. R语言xgboost包:使用xgboost算法实现随机森林(random forest)模型

    R语言xgboost包:使用xgboost算法实现随机森林(random forest)模型 目录 R语言xgboost包:使用xgboost算法实现随机森林(random forest)模型

  8. R语言实现可理解的随机森林模型(Random Forest)——iml包

    Random Forest 解释模型 1. 介绍 2. 理解随机森林运行机理 2.1导入需要的包 2.2 构建随机森林模型 2.3 RF特征重要性: 2.4 特征对预测结果的影响 2.5 交互作用 2 ...

  9. 使用R构建随机森林回归模型(Random Forest Regressor)

    使用R构建随机森林回归模型(Random Forest Regressor) 目录 使用R构建随机森林回归模型(Random Forest Regressor) 安装包randomForest 缺失值 ...

  10. r语言degseq2_R语言DESeq 包介绍 -

    R语言DESeq包介绍 分析RNA序列数据的一个主要任务是探测基因的差异表达,DESeq包提供了测试差异表达的方法,应用负二项分布和收缩的分布方程估计. 1. 包的安装 输入如下命令,DESeq和相关 ...

最新文章

  1. 【CCNA考试】2010-06-17-杭州-1000(PASS)
  2. Spring MVC 拦截器 interceptor 详解
  3. SpringBootController控制层接收参数的几种常用方式
  4. Linux运维工程师:30道面试题整理
  5. jQuery—淘宝精品服饰案例
  6. [蓝桥杯2018初赛]第几天-日期计算(水题)
  7. 微信小程序 本地mysql_微信小程序系列之使用缓存在本地模拟服务器数据库
  8. AndroidStudio_安卓原生开发_精美自定义多选控件_多选Spinner_MultiSpinner_拿来即用---Android原生开发工作笔记144
  9. LAMP+LNMP视频教程
  10. 递归法:求n个元素的全排列
  11. centos7下永久修改hostname
  12. BC95(ML5515)连接TCP流程
  13. vue中audio实现微信语音播放动画
  14. 湘潭大学信息安全课作业答案1
  15. c语言编程 甲乙丙丁谁是罪犯的题,犯罪大师第二届推理大赛有甲乙丙丁四人答案是什么...
  16. 探访厦航飞机女“医生” 有机务“熊猫”美誉
  17. 军火库(第一期):无线电硬件安全大牛都用哪些利器?
  18. iOS开发- 文件共享 利用iTunes导入文件 并且显示已有文件
  19. excel转word_PDF一键转Word、转Excel、转PPT、转Html、转图片软件
  20. java tan_Java Math tan()用法及代码示例

热门文章

  1. CTO:不要在代码中写 set/get 方法了,逮一次罚款...
  2. 面试热身:5 亿整数的大文件,排个序 ?
  3. 谷歌如何在设计上脱胎换骨
  4. 微服务架构实战(一):微服务架构的优势与不足
  5. Linux 工程师的 6 类好习惯和 23 个教训
  6. three 天空球_three.js添加场景背景和天空盒(skybox)代码示例
  7. JDBCUtils——DBCP
  8. JAVA笔记整理(五),JAVA中的继承
  9. jquery ajax jsonp跨域调用实例代码
  10. js简单操作Cookie