ctree() - 如何获取每个终端节点的拆分条件列表?

Sri*_*ali 7 r party decision-tree

我有ctree()(party包)的输出,如下所示.如何获取每个终端节点的分割条件列表,如同sns <= 0, dta <= 1; sns <= 0, dta > 1等等?

1) sns <= 0; criterion = 1, statistic = 14655.021
  2) dta <= 1; criterion = 1, statistic = 3286.389
   3)*  weights = 153682 
  2) dta > 1
   4)*  weights = 289415 
1) sns > 0
  5) dta <= 2; criterion = 1, statistic = 1882.439
   6)*  weights = 245457 
  5) dta > 2
   7) dta <= 6; criterion = 1, statistic = 1170.813
     8)*  weights = 328582 
   7) dta > 6
Run Code Online (Sandbox Code Playgroud)

谢谢

Dav*_*urg 9

这个函数应该可以胜任

 CtreePathFunc <- function (ct, data) {
  # List the split conditions that lead to every terminal node of a tree.
  #
  # Args:
  #   ct:   Fitted tree accepted by party's where() and nodes(), e.g. the
  #         result of party::ctree().
  #   data: Data frame containing the split variables. Rows are matched to
  #         the node weights by position (presumably the data `ct` was
  #         fitted on -- confirm at the call site).
  #
  # Returns:
  #   A data frame with one row per terminal node, in order of first
  #   appearance in where(ct): `Node` (node id converted with
  #   as.character()) and `Path` (the conditions from the root down, joined
  #   by ", "). A tree without any split yields an empty `Path`.
  #
  # NOTE(review): the split point is compared with as.numeric(), so this
  # looks limited to numeric split variables -- confirm before using it on
  # factors.

  TerminalNodes <- unique(where(ct))
  Paths <- character(length(TerminalNodes))

  for (k in seq_along(TerminalNodes)) {
    Node <- TerminalNodes[k]

    # Candidate ancestors: every non-terminal id smaller than Node.
    # seq_len() stays empty for Node == 1, whereas 1:(Node - 1) would
    # count backwards and yield c(1, 0).
    NonTerminalNodes <- setdiff(seq_len(Node - 1), TerminalNodes)

    # Getting the weights for that node
    NodeWeights <- nodes(ct, Node)[[1]]$weights

    # Finding the path: a candidate is an ancestor when it shares at least
    # one observation with the terminal node. Testing `> 0` instead of
    # `== 1` keeps this working when case weights larger than 1 are used.
    Path <- NULL
    for (i in NonTerminalNodes) {
      if (any(NodeWeights & nodes(ct, i)[[1]]$weights > 0)) {
        Path <- append(Path, i)
      }
    }

    # Finding the splitting criteria for that path
    Conditions <- character(length(Path))
    for (i in seq_along(Path)) {
      # The node that follows Path[i] on the way down to the terminal node.
      ChildId <- if (i == length(Path)) Node else Path[i + 1]
      ChildRows <- which(as.logical(nodes(ct, ChildId)[[1]]$weights))

      # Flattened primary split of the ancestor: the variable name is the
      # last element and the split point the third one.
      SplitInfo <- unlist(nodes(ct, Path[i])[[1]][[5]])
      VarName <- as.character(SplitInfo[length(SplitInfo)])
      SplitPoint <- SplitInfo[3]

      # The child is on the "<=" side when all of its observations are.
      # na.rm = TRUE keeps missing values in the split variable from
      # turning the condition into NA.
      OnLeft <- all(data[ChildRows, VarName] <= as.numeric(SplitPoint),
                    na.rm = TRUE)
      SB <- if (OnLeft) "<=" else ">"

      Conditions[i] <- paste(VarName, SB, as.character(SplitPoint))
    }

    Paths[k] <- paste(Conditions, collapse = ", ")
  }

  # Output: built once instead of growing a data frame with rbind().
  data.frame(Node = as.character(TerminalNodes), Path = Paths)
}
Run Code Online (Sandbox Code Playgroud)

测试

library(party)

# Keep only the rows with an observed Ozone value.
airq <- airquality[!is.na(airquality$Ozone), ]

# Fit the tree on all remaining columns, passing maxsurrogate = 3 through
# ctree_control().
ct <- ctree(
  Ozone ~ .,
  data = airq,
  controls = ctree_control(maxsurrogate = 3)
)

# Split conditions for each terminal node of the fitted tree.
Result <- CtreePathFunc(ct, airq)
Result

##   Node                               Path
## 1    5 Temp <= 82, Wind > 6.9, Temp <= 77
## 2    3            Temp <= 82, Wind <= 6.9
## 3    6  Temp <= 82, Wind > 6.9, Temp > 77
## 4    9             Temp > 82, Wind > 10.3
## 5    8            Temp > 82, Wind <= 10.3
Run Code Online (Sandbox Code Playgroud)


Ach*_*eis 6

如果您使用新推荐的 partykit 包中的 ctree() 实现,而不是旧的 party 包,那么可以使用函数 .list.rules.party()。该函数尚未正式导出,但可以利用它来提取所需的信息。

library("partykit")

# Keep only the rows with an observed Ozone value.
airq <- airquality[!is.na(airquality$Ozone), ]

# Fit the tree with the partykit implementation of ctree().
ct <- ctree(
  Ozone ~ .,
  data = airq
)

# .list.rules.party() is not exported, hence the ':::' access.
partykit:::.list.rules.party(ct)
##                                      3                                      5 
##             "Temp <= 82 & Wind <= 6.9" "Temp <= 82 & Wind > 6.9 & Temp <= 77" 
##                                      6                                      8 
##  "Temp <= 82 & Wind > 6.9 & Temp > 77"             "Temp > 82 & Wind <= 10.3" 
##                                      9 
##              "Temp > 82 & Wind > 10.3" 
Run Code Online (Sandbox Code Playgroud)


Gal*_*led 5

由于我需要这个函数来处理分类数据,我用下面的函数或多或少回答了 @JoãoDaniel 的问题(我只测试了分类预测变量):

\n\n
# returns string w/o leading or trailing whitespace
# http://stackoverflow.com/questions/2261079/how-to-trim-leading-and-trailing-whitespace-in-r
trim <- function (x) gsub("^\\s+|\\s+$", "", x)

# Returns the text before the first whitespace of a rule such as
# "Temp <= { 82 }", i.e. the variable name.
getVariable <- function (x) sub("(.*?)[[:space:]].*", "\\1", x)

# Returns the second whitespace-delimited token of a rule, i.e. the
# comparison symbol ("<=", ">" or "=").
getSimbolo <- function (x) sub("(.*?)[[:space:]](.*?)[[:space:]].*", "\\2", x)

# Collapses a ";"-separated rule path so that every variable/symbol pair
# appears only once.
#
# Args:
#   elemento: A single path such as
#             "Temp <= { 82 }; Wind > { 6.9 }; Temp <= { 77 }".
#
# Returns:
#   A string in which each variable/symbol pair is represented by its last
#   rule in the path, the pairs being ordered by first appearance and
#   joined with "; " (here "Temp <= { 77 }; Wind > { 6.9 }").
getReglaFinal <- function(elemento) {
  reglas <- trim(strsplit(as.character(elemento), ";")[[1]])

  # Key identifying the kind of cut, e.g. "Temp<=".
  tipo_corte <- paste0(getVariable(reglas), getSimbolo(reglas))

  # Position of the last rule of every kind of cut.
  ultimas <- vapply(
    unique(tipo_corte),
    function(corte) max(which(tipo_corte == corte)),
    integer(1)
  )

  paste(reglas[ultimas], collapse = "; ")
}#getReglaFinal

# Describes every terminal node of a party tree by its split conditions and
# renders the result as an SQL CASE expression.
#
# Args:
#   ct: Fitted tree accepted by party's where(), nodes(), weights() and
#       Predict(), e.g. the result of party::ctree().
#
# Returns:
#   A list with three elements:
#     PreTable: data frame with Node, Path (all conditions from the root
#               down), Means, Counts and TruePath (Path reduced by
#               getReglaFinal()), sorted by Means.
#     Table:    PreTable with Path replaced by TruePath and the TruePath
#               column removed.
#     SQL:      a single " CASE  WHEN ... THEN 'Nodo <id>' ...  END "
#               string built from TruePath.
CtreePathFuncAllCat <- function (ct) {

  nodos_terminales <- unique(where(ct))
  rutas <- character(length(nodos_terminales))

  for (k in seq_along(nodos_terminales)) {
    Node <- nodos_terminales[k]

    # Taking all possible non-Terminal nodes that are smaller than the
    # selected terminal node. seq_len() stays empty for Node == 1, whereas
    # 1:(Node - 1) would count backwards and yield c(1, 0).
    NonTerminalNodes <- setdiff(seq_len(Node - 1), nodos_terminales)

    # Getting the weights for that node
    NodeWeights <- nodes(ct, Node)[[1]]$weights

    # Finding the path: a candidate is an ancestor when it shares at least
    # one observation with the terminal node. Testing `> 0` instead of
    # `== 1` keeps this working when case weights larger than 1 are used.
    Path <- NULL
    for (i in NonTerminalNodes) {
      if (any(NodeWeights & nodes(ct, i)[[1]]$weights > 0)) {
        Path <- append(Path, i)
      }
    }

    # Finding the splitting criteria for that path
    variablesNombres <- character(length(Path))
    variablesPuntos <- vector("list", length(Path))
    condiciones <- character(length(Path))

    for (i in seq_along(Path)) {
      n <- nodes(ct, Path[i])[[1]]

      # The node that follows Path[i] on the way down to the terminal node.
      nextNodeID <- if (i == length(Path)) Node else Path[i + 1]
      va_derecha <- nextNodeID == n$right$nodeID

      division <- n[[5]]
      vec_puntos <- as.vector(division$splitpoint)
      vec_nombre <- division$variableName
      vec_niveles <- attr(division$splitpoint, "levels")

      # Kind of split, decided from the split point itself:
      #   - no levels attribute            -> numeric split,
      #   - one entry per level            -> unordered factor,
      #   - levels but a different length  -> ordered factor (level index).
      # An explicit flag is used because the previous `index == 0` sentinel
      # mistook a numeric split point of exactly 0 for an unordered factor.
      tiene_niveles <- length(vec_niveles) != 0
      es_nominal <- tiene_niveles &&
        length(vec_puntos) == length(vec_niveles)

      if (es_nominal) {
        # Levels sent to the child being followed.
        vec_puntos <- if (va_derecha) !vec_puntos else !!vec_puntos

        # Intersect with earlier splits on the same variable so that only
        # the levels still reachable at this depth are listed.
        for (j in seq_len(i - 1)) {
          if (variablesNombres[j] == vec_nombre) {
            vec_puntos <- vec_puntos & variablesPuntos[[j]]
          }
        }
        SB <- "="
      } else {
        if (tiene_niveles) {
          # Ordered factor: turn the level index into a logical mask.
          indice <- vec_puntos
          vec_puntos <- vector(length = length(vec_niveles))
          vec_puntos[indice] <- TRUE
        } else {
          # Numeric split: keep the split point as it is.
          vec_puntos <- division$splitpoint
        }
        SB <- if (va_derecha) ">" else "<="
      }

      variablesPuntos[[i]] <- vec_puntos
      variablesNombres[i] <- vec_nombre

      if (tiene_niveles) {
        descripcion <- paste(vec_niveles[vec_puntos], collapse = ", ")
      } else {
        descripcion <- vec_puntos
      }
      condiciones[i] <- paste(c(vec_nombre, SB, "{", descripcion, "}"),
                              collapse = " ")
    }

    rutas[k] <- paste(condiciones, collapse = "; ")
  }

  # Output: built once instead of growing a data frame with rbind().
  ResulTable <- data.frame(Node = as.character(nodos_terminales),
                           Path = rutas)

  # Counts: second column of the unique (node id, weight sum) pairs.
  # Means: unique predicted values.
  # NOTE(review): pairing Means/Counts with the rows above relies on
  # unique() yielding exactly one value per terminal node, in the same
  # order as unique(where(ct)) -- confirm for trees where two nodes share
  # a prediction or a size.
  we <- weights(ct)
  c0 <- as.matrix(where(ct))
  c3 <- sapply(we, function(w) sum(w))
  c3 <- as.matrix(unique(cbind(c0, c3)))
  Counts <- as.matrix(c3[, 2])
  c2 <- drop(Predict(ct))
  Means <- as.matrix(unique(c2))

  ResulTable <- data.frame(ResulTable, Means, Counts)
  ResulTable <- ResulTable[order(ResulTable$Means), ]

  ResulTable$TruePath <- apply(as.data.frame(ResulTable$Path), 1,
                               getReglaFinal)

  ResulTable2 <- ResulTable

  # Translate the rule syntax into SQL: ";" -> " AND ", "{ " -> "('",
  # " }" -> "')", ", " -> "','", then drop the quotes around numbers.
  ResulTable2$SQL <- paste("WHEN ", gsub("\\'([-+]?([0-9]*\\.[0-9]+|[0-9]+))\\'", "\\1", gsub("\\, ", "','", gsub(" \\}", "')", gsub("\\{ ", "('", gsub("\\;", " AND ", ResulTable2$TruePath))))), " THEN ")

  # Append the node label to every WHEN clause.
  cols <- c('SQL', 'Node')
  ResulTable2$SQL <- apply(ResulTable2[, cols], 1, paste,
                           collapse = "'Nodo ")

  # Close the label quote and tidy the spacing around the quotes.
  ResulTable2$SQL <- gsub("THEN'", "THEN '", gsub(" '", "'", paste(ResulTable2$SQL, "'")))

  ResultadoFinal <- list()

  ResultadoFinal$PreTable <- ResulTable
  ResultadoFinal$Table <- ResulTable
  ResultadoFinal$Table$Path <- ResultadoFinal$Table$TruePath
  ResultadoFinal$Table$TruePath <- NULL
  ResultadoFinal$SQL <- paste(" CASE ", paste(ResulTable2$SQL, sep = "", collapse = " "), " END ", collapse = "")

  ResultadoFinal
}#CtreePathFuncAllCat
Run Code Online (Sandbox Code Playgroud)\n\n

这是一个测试:

\n\n
library(party)

# NOTE(review): `mammoexp` may have to be loaded first, e.g. with
# data("mammoexp", package = "TH.data") -- confirm for the installed
# version of party.

#With ordered factors
TreeModel1 <- ctree(PB~ME+SYMPT+HIST+BSE+DECT, data = mammoexp)
Result2 <- CtreePathFuncAllCat(TreeModel1)
Result2
##$PreTable
##  Node                                                Path    Means Counts
##3    7    DECT > { Somewhat likely }; SYMPT > { Disagree } 6.526316    114
##2    6   DECT > { Somewhat likely }; SYMPT <= { Disagree } 7.640000    175
##1    4  DECT <= { Somewhat likely }; DECT > { Not likely } 8.161905    105
##4    3 DECT <= { Somewhat likely }; DECT <= { Not likely } 9.833333     18
##                                          TruePath
##3   DECT > { Somewhat likely }; SYMPT > { Disagree }
##2  DECT > { Somewhat likely }; SYMPT <= { Disagree }
##1 DECT <= { Somewhat likely }; DECT > { Not likely }
##4                             DECT <= { Not likely }
##
##$Table
##  Node                                               Path    Means Counts
##3    7   DECT > { Somewhat likely }; SYMPT > { Disagree } 6.526316    114
##2    6  DECT > { Somewhat likely }; SYMPT <= { Disagree } 7.640000    175
##1    4 DECT <= { Somewhat likely }; DECT > { Not likely } 8.161905    105
##4    3                             DECT <= { Not likely } 9.833333     18
##
##$SQL
##[1] " CASE  WHEN  DECT > ('Somewhat likely') AND  SYMPT > ('Disagree')  THEN 'Nodo 7' WHEN  DECT > ('Somewhat likely') AND  SYMPT <= ('Disagree')  THEN 'Nodo 6' WHEN  DECT <= ('Somewhat likely') AND  DECT > ('Not likely')  THEN 'Nodo 4' WHEN  DECT <= ('Not likely')  THEN 'Nodo 3'  END "


#With unordered factors
TreeModel2 <- ctree(count~spray, data = InsectSprays)
plot(TreeModel2, type="simple")
Result2 <- CtreePathFuncAllCat(TreeModel2)
Result2
##$PreTable
##Node                                  Path     Means Counts            TruePath
##2    5 spray = { C, D, E }; spray = { C, E }  2.791667     24    spray = { C, E }
##3    4    spray = { C, D, E }; spray = { D }  4.916667     12       spray = { D }
##1    2                   spray = { A, B, F } 15.500000     36 spray = { A, B, F }
##
##$Table
##Node                Path     Means Counts
##2    5    spray = { C, E }  2.791667     24
##3    4       spray = { D }  4.916667     12
##1    2 spray = { A, B, F } 15.500000     36
##
##$SQL
##[1] " CASE  WHEN  spray = ('C','E')  THEN 'Nodo 5' WHEN  spray = ('D')  THEN 'Nodo 4' WHEN  spray = ('A','B','F')  THEN 'Nodo 2'  END "

#With continuous variables
airq <- subset(airquality, !is.na(Ozone))
TreeModel3 <- ctree(Ozone ~ ., data = airq,  controls = ctree_control(maxsurrogate = 3))
Result2 <- CtreePathFuncAllCat(TreeModel3)
Result2
##$PreTable
##  Node                                           Path    Means Counts
##1    5 Temp <= { 82 }; Wind > { 6.9 }; Temp <= { 77 } 18.47917     48
##3    6  Temp <= { 82 }; Wind > { 6.9 }; Temp > { 77 } 31.14286     21
##4    9                 Temp > { 82 }; Wind > { 10.3 } 48.71429      7
##2    3                Temp <= { 82 }; Wind <= { 6.9 } 55.60000     10
##5    8                Temp > { 82 }; Wind <= { 10.3 } 81.63333     30
##                                     TruePath
##1                Temp <= { 77 }; Wind > { 6.9 }
##3 Temp <= { 82 }; Wind > { 6.9 }; Temp > { 77 }
##4                Temp > { 82 }; Wind > { 10.3 }
##2               Temp <= { 82 }; Wind <= { 6.9 }
##5               Temp > { 82 }; Wind <= { 10.3 }
##
##$Table
##  Node                                          Path    Means Counts
##1    5                Temp <= { 77 }; Wind > { 6.9 } 18.47917     48
##3    6 Temp <= { 82 }; Wind > { 6.9 }; Temp > { 77 } 31.14286     21
##4    9                Temp > { 82 }; Wind > { 10.3 } 48.71429      7
##2    3               Temp <= { 82 }; Wind <= { 6.9 } 55.60000     10
##5    8               Temp > { 82 }; Wind <= { 10.3 } 81.63333     30
##
##$SQL
##[1] " CASE  WHEN  Temp <= (77) AND  Wind > (6.9)  THEN 'Nodo 5' WHEN  Temp <= (82) AND  Wind > (6.9) AND  Temp > (77)  THEN 'Nodo 6' WHEN  Temp > (82) AND  Wind > (10.3)  THEN 'Nodo 9' WHEN  Temp <= (82) AND  Wind <= (6.9)  THEN 'Nodo 3' WHEN  Temp > (82) AND  Wind <= (10.3)  THEN 'Nodo 8'  END "
Run Code Online (Sandbox Code Playgroud)\n\n

更新!现在该函数支持分类变量和数值变量的混合!

\n

  • 很好用,但是,它似乎只适用于分类变量:当我在airct树CtreePathFuncAllCat(ct)的结果上尝试这个时,它返回分割字段,但不返回分割标准。知道如何获取分类变量和连续变量的路径吗? (2认同)