R:rpart树使用两个解释变量增长,但在删除不太重要的变量之后不会增长

sde*_*188 5 tree r cart rpart r-caret

数据:我正在使用rsample包中的" attrition "数据集.

问题:使用磨损数据集和rpart库,我可以使用公式"Attrition~OverTime + JobRole"生成树,其中选择OverTime作为第一个分割.但是当我尝试在没有JobRole变量的情况下生长树(即"Attrition~OverTime")时,树不会分裂并仅返回根节点.使用rpart函数以及使用method ="rpart"的插入符号列函数都会发生这种情况.

我对此感到困惑,因为我认为在rpart中实现的CART算法选择了最佳变量以迭代贪婪的方式进行分割,并且没有"向前看"以查看其他变量的存在如何影响其对最佳选择的选择分裂.如果在具有两个解释变量的情况下算法选择OverTime作为值得的第一次拆分,为什么在删除JobRole变量后它不选择OverTime作为值得的第一次拆分?

我在Windows 7中使用R版本3.4.2和RStudio版本1.1.442.

研究:我在这里这里找到了类似的Stack Overflow问题,但都没有完整的答案.

我可以说,rpart文档似乎在第5页说rpart算法不使用"向前看"规则:

解决这两个问题的一种方法是使用预见规则; 但这些在计算上非常昂贵.相反,rpart使用节点的几种杂质或多样性度量之一.

此外,这里这里也有类似的描述.

代码:这是一个代表.任何见解都会很棒 - 谢谢!

suppressPackageStartupMessages(library(rsample))                                                                                                           
#> Warning: package 'rsample' was built under R version 3.4.4
suppressPackageStartupMessages(library(rpart))                                                                                                             
suppressPackageStartupMessages(library(caret))                                                                                                             
suppressPackageStartupMessages(library(dplyr))                                                                                                             
#> Warning: package 'dplyr' was built under R version 3.4.3
suppressPackageStartupMessages(library(purrr))                                                                                                             

#################################################                                                                                                          

# look at data                                                                                                                                             
data(attrition)                                                                                                                                            
attrition_subset <- attrition %>% select(Attrition, OverTime, JobRole)                                                                                     
attrition_subset %>% glimpse()                                                                                                                             
#> Observations: 1,470
#> Variables: 3
#> $ Attrition <fctr> Yes, No, Yes, No, No, No, No, No, No, No, No, No, N...
#> $ OverTime  <fctr> Yes, No, Yes, Yes, No, No, Yes, No, No, No, No, Yes...
#> $ JobRole   <fctr> Sales_Executive, Research_Scientist, Laboratory_Tec...
map_dfr(.x = attrition_subset, .f = ~ sum(is.na(.x)))                                                                                                      
#> # A tibble: 1 x 3
#>   Attrition OverTime JobRole
#>       <int>    <int>   <int>
#> 1         0        0       0

#################################################                                                                                                          

# with rpart                                                                                                                                               
attrition_rpart_w_JobRole <- rpart(Attrition ~ OverTime + JobRole, data = attrition_subset, method = "class", cp = .01)                                    
attrition_rpart_w_JobRole                                                                                                                                  
#> n= 1470 
#> 
#> node), split, n, loss, yval, (yprob)
#>       * denotes terminal node
#> 
#>  1) root 1470 237 No (0.83877551 0.16122449)  
#>    2) OverTime=No 1054 110 No (0.89563567 0.10436433) *
#>    3) OverTime=Yes 416 127 No (0.69471154 0.30528846)  
#>      6) JobRole=Healthcare_Representative,Manager,Manufacturing_Director,Research_Director 126  11 No (0.91269841 0.08730159) *
#>      7) JobRole=Human_Resources,Laboratory_Technician,Research_Scientist,Sales_Executive,Sales_Representative 290 116 No (0.60000000 0.40000000)  
#>       14) JobRole=Human_Resources,Research_Scientist,Sales_Executive 204  69 No (0.66176471 0.33823529) *
#>       15) JobRole=Laboratory_Technician,Sales_Representative 86  39 Yes (0.45348837 0.54651163) *

attrition_rpart_wo_JobRole <- rpart(Attrition ~ OverTime, data = attrition_subset, method = "class", cp = .01)                                             
attrition_rpart_wo_JobRole                                                                                                                                 
#> n= 1470 
#> 
#> node), split, n, loss, yval, (yprob)
#>       * denotes terminal node
#> 
#> 1) root 1470 237 No (0.8387755 0.1612245) *

#################################################                                                                                                          

# with caret                                                                                                                                               
attrition_caret_w_JobRole_non_dummies <- train(x = attrition_subset[ , -1], y = attrition_subset[ , 1], method = "rpart", tuneGrid = expand.grid(cp = .01))
attrition_caret_w_JobRole_non_dummies$finalModel                                                                                                           
#> n= 1470 
#> 
#> node), split, n, loss, yval, (yprob)
#>       * denotes terminal node
#> 
#>  1) root 1470 237 No (0.83877551 0.16122449)  
#>    2) OverTime=No 1054 110 No (0.89563567 0.10436433) *
#>    3) OverTime=Yes 416 127 No (0.69471154 0.30528846)  
#>      6) JobRole=Healthcare_Representative,Manager,Manufacturing_Director,Research_Director 126  11 No (0.91269841 0.08730159) *
#>      7) JobRole=Human_Resources,Laboratory_Technician,Research_Scientist,Sales_Executive,Sales_Representative 290 116 No (0.60000000 0.40000000)  
#>       14) JobRole=Human_Resources,Research_Scientist,Sales_Executive 204  69 No (0.66176471 0.33823529) *
#>       15) JobRole=Laboratory_Technician,Sales_Representative 86  39 Yes (0.45348837 0.54651163) *

attrition_caret_w_JobRole <- train(Attrition ~ OverTime + JobRole, data = attrition_subset, method = "rpart", tuneGrid = expand.grid(cp = .01))            
attrition_caret_w_JobRole$finalModel                                                                                                                       
#> n= 1470 
#> 
#> node), split, n, loss, yval, (yprob)
#>       * denotes terminal node
#> 
#> 1) root 1470 237 No (0.8387755 0.1612245)  
#>   2) OverTimeYes< 0.5 1054 110 No (0.8956357 0.1043643) *
#>   3) OverTimeYes>=0.5 416 127 No (0.6947115 0.3052885)  
#>     6) JobRoleSales_Representative< 0.5 392 111 No (0.7168367 0.2831633) *
#>     7) JobRoleSales_Representative>=0.5 24   8 Yes (0.3333333 0.6666667) *

attrition_caret_wo_JobRole <- train(Attrition ~ OverTime, data = attrition_subset, method = "rpart", tuneGrid = expand.grid(cp = .01))                     
attrition_caret_wo_JobRole$finalModel                                                                                                                      
#> n= 1470 
#> 
#> node), split, n, loss, yval, (yprob)
#>       * denotes terminal node
#> 
#> 1) root 1470 237 No (0.8387755 0.1612245) *
Run Code Online (Sandbox Code Playgroud)

G5W*_*G5W 1

这是完全有道理的。上面有很多额外的代码,所以我将重复重要的部分。

library(rsample)
library(rpart)
data(attrition)

rpart(Attrition ~ OverTime + JobRole, data=attrition)
n= 1470 
node), split, n, loss, yval, (yprob)
      * denotes terminal node

 1) root 1470 237 No (0.83877551 0.16122449)  
   2) OverTime=No 1054 110 No (0.89563567 0.10436433) *
   3) OverTime=Yes 416 127 No (0.69471154 0.30528846)  
     6) JobRole=Healthcare_Representative,Manager,Manufacturing_Director,Research_Director 126  11 No (0.91269841 0.08730159) *
     7) JobRole=Human_Resources,Laboratory_Technician,Research_Scientist,Sales_Executive,Sales_Representative 290 116 No (0.60000000 0.40000000)  
      14) JobRole=Human_Resources,Research_Scientist,Sales_Executive 204  69 No (0.66176471 0.33823529) *
      15) JobRole=Laboratory_Technician,Sales_Representative 86  39 Yes (0.45348837 0.54651163) *

rpart(Attrition ~ OverTime, data=attrition)
n= 1470 
node), split, n, loss, yval, (yprob)
      * denotes terminal node

1) root 1470 237 No (0.8387755 0.1612245) *
Run Code Online (Sandbox Code Playgroud)

看一下第一个模型(有两个变量)。就在根下面我们有:

1) root 1470 237 No (0.83877551 0.16122449)        
    2) OverTime=No 1054 110 No (0.89563567 0.10436433) *      
    3) OverTime=Yes 416 127 No (0.69471154 0.30528846)
Run Code Online (Sandbox Code Playgroud)

该模型继续拆分节点 3 (OverTime=Yes),但使用 JobRole。由于我们在第二个模型中没有 JobRole,因此 rpart 无法进行其他拆分。但请注意,在节点 2 和 3 处,Attrition=No 是多数类。在节点 3 处,69.5% 的实例为“否”,30.5% 为“是”。因此,对于节点 2 和 3,我们将预测“否”。由于分割两侧的预测相同,因此分割是不必要的并被剪掉。你只需要根节点就可以预测所有实例都是No。