Ric*_*loo 13 r machine-learning random-forest
randomForest对象中提取每个CART模型的变量重要性?rf_mod$forest似乎没有此信息,文档也没有提及。
在R的randomForest程序包中,整个CART模型林的平均变量重要性由给出importance(rf_mod)。
library(randomForest)
df <- mtcars
set.seed(1)
rf_mod = randomForest(mpg ~ .,
data = df,
importance = TRUE,
ntree = 200)
importance(rf_mod)
%IncMSE IncNodePurity
cyl 6.0927875 111.65028
disp 8.7730959 261.06991
hp 7.8329831 212.74916
drat 2.9529334 79.01387
wt 7.9015687 246.32633
qsec 0.7741212 26.30662
vs 1.6908975 31.95701
am 2.5298261 13.33669
gear 1.5512788 17.77610
carb 3.2346351 35.69909
Run Code Online (Sandbox Code Playgroud)
我们还可以使用提取单个树结构getTree。这是第一棵树。
head(getTree(rf_mod, k = 1, labelVar = TRUE))
left daughter right daughter split var split point status prediction
1 2 3 wt 2.15 -3 18.91875
2 0 0 <NA> 0.00 -1 31.56667
3 4 5 wt 3.16 -3 17.61034
4 6 7 drat 3.66 -3 21.26667
5 8 9 carb 3.50 -3 15.96500
6 0 0 <NA> 0.00 -1 19.70000
Run Code Online (Sandbox Code Playgroud)
一种解决方法是增长许多CART(即- ntree = 1),获取每棵树的可变重要性,并平均结果%IncMSE:
# number of trees to grow
nn <- 200
# function to run nn CART models
run_rf <- function(rand_seed){
set.seed(rand_seed)
one_tr = randomForest(mpg ~ .,
data = df,
importance = TRUE,
ntree = 1)
return(one_tr)
}
# list to store output of each model
l <- vector("list", length = nn)
l <- lapply(1:nn, run_rf)
Run Code Online (Sandbox Code Playgroud)
提取,平均和比较步骤。
# extract importance of each CART model
library(dplyr); library(purrr)
map(l, importance) %>%
map(as.data.frame) %>%
map( ~ { .$var = rownames(.); rownames(.) <- NULL; return(.) } ) %>%
bind_rows() %>%
group_by(var) %>%
summarise(`%IncMSE` = mean(`%IncMSE`)) %>%
arrange(-`%IncMSE`)
# A tibble: 10 x 2
var `%IncMSE`
<chr> <dbl>
1 wt 8.52
2 cyl 7.75
3 disp 7.74
4 hp 5.53
5 drat 1.65
6 carb 1.52
7 vs 0.938
8 qsec 0.824
9 gear 0.495
10 am 0.355
# compare to the RF model above
importance(rf_mod)
%IncMSE IncNodePurity
cyl 6.0927875 111.65028
disp 8.7730959 261.06991
hp 7.8329831 212.74916
drat 2.9529334 79.01387
wt 7.9015687 246.32633
qsec 0.7741212 26.30662
vs 1.6908975 31.95701
am 2.5298261 13.33669
gear 1.5512788 17.77610
carb 3.2346351 35.69909
Run Code Online (Sandbox Code Playgroud)
我希望能够直接从一个randomForest对象中提取每棵树的变量重要性,而无需这种回旋方法,该方法需要完全重新运行RF以促进可重复的累积变量重要性图(如下图所示)为mtcars。这里最小的例子。
我知道一棵树的可变重要性在统计上是没有意义的,我也不打算孤立地解释树。我希望它们用于可视化和传达信息,即随着森林中树木的增加,重要性可变的指标在稳定之前会跳来跳去。
训练randomForest模型时,将为整个森林计算重要性分数,并将其直接存储在对象内部。特定树的分数不会保留,因此无法直接从randomForest对象中检索。
不幸的是,关于必须逐步构建林是正确的。好消息是,randomForest对象是独立的,您无需实现自己的对象run_rf。相反,您可以使用stats::update一棵树重新拟合随机森林模型,并一次randomForest::grow添加一棵其他树:
## Starting with a random forest having a single tree,
## grow it 9 times, one tree at a time
rfs <- purrr::accumulate( .init = update(rf_mod, ntree=1),
rep(1,9), randomForest::grow )
## Retrieve the importance scores from each random forest
imp <- purrr::map( rfs, ~importance(.x)[,"%IncMSE"] )
## Combine all results into a single data frame
dplyr::bind_rows( !!!imp )
# # A tibble: 10 x 10
# cyl disp hp drat wt qsec vs am gear carb
# <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
# 1 0 18.8 8.63 1.05 0 1.17 0 0 0 0.194
# 2 0 10.0 46.4 0.561 0 -0.299 0 0 0.543 2.05
# 3 0 22.4 31.2 0.955 0 -0.199 0 0 0.362 5.1
# 4 1.55 24.1 23.4 0.717 0 -0.150 0 0 0.272 5.28
# 5 1.24 22.8 23.6 0.573 0 -0.178 0 0 -0.0259 4.98
# 6 1.03 26.2 22.3 0.478 1.25 0.775 0 0 -0.0216 4.1
# 7 0.887 22.5 22.5 0.406 1.79 -0.101 0 0 -0.0185 3.56
# 8 0.776 19.7 21.3 0.944 1.70 0.105 0 0.0225 -0.0162 3.11
# 9 0.690 18.4 19.1 0.839 1.51 1.24 1.01 0.02 -0.0144 2.77
# 10 0.621 18.4 21.2 0.937 1.32 1.11 0.910 0.0725 -0.114 2.49
Run Code Online (Sandbox Code Playgroud)
数据框显示了功能重要性随每棵其他树如何变化。这是绘图示例的右侧面板。树木本身(用于左侧面板)可以从最终的森林中检索,该森林由给出dplyr::last( rfs )。