Sri*_*ali 7 r party decision-tree
我有ctree()(party包)的输出,如下所示.如何获取每个终端节点的分割条件列表,如同sns <= 0, dta <= 1; sns <= 0, dta > 1等等?
1) sns <= 0; criterion = 1, statistic = 14655.021
2) dta <= 1; criterion = 1, statistic = 3286.389
3)* weights = 153682
2) dta > 1
4)* weights = 289415
1) sns > 0
5) dta <= 2; criterion = 1, statistic = 1882.439
6)* weights = 245457
5) dta > 2
7) dta <= 6; criterion = 1, statistic = 1170.813
8)* weights = 328582
7) dta > 6
Run Code Online (Sandbox Code Playgroud)
谢谢
这个函数应该可以胜任:
#' List the split conditions leading to every terminal node of a party ctree.
#'
#' Only numeric split variables are handled: the direction of each split is
#' decided by comparing the observations of the child node against the split
#' point with `<=`.
#'
#' @param ct   A tree fitted with `party::ctree()`.
#' @param data The data frame the tree was fitted on (same row order).
#' @return A data frame with one row per terminal node and two character
#'   columns: `Node` (terminal node id) and `Path` (conditions joined by
#'   ", "). A tree without any split yields an empty `Path`.
CtreePathFunc <- function (ct, data) {
  # Terminal node ids, in order of first appearance among the observations.
  terminal_ids <- unique(where(ct))

  path_for <- function(node_id) {
    # Candidate ancestors: every inner node with a smaller id. seq_len()
    # keeps this empty for the root (1:(1 - 1) would yield c(1, 0)).
    inner_ids <- setdiff(seq_len(node_id - 1),
                         terminal_ids[terminal_ids < node_id])
    node_weights <- nodes(ct, node_id)[[1]]$weights

    # An inner node is on the path when it shares at least one observation
    # with the terminal node.
    on_path <- vapply(
      inner_ids,
      function(i) any(node_weights & (nodes(ct, i)[[1]]$weights == 1)),
      logical(1)
    )
    path <- inner_ids[on_path]
    if (length(path) == 0) {
      return("")
    }

    # The child reached from path[k] is the next node on the path, or the
    # terminal node itself for the last split.
    child_ids <- c(path[-1], node_id)

    conditions <- vapply(seq_along(path), function(k) {
      psplit <- nodes(ct, path[k])[[1]][[5]]
      variable <- as.character(psplit$variableName)
      threshold <- psplit$splitpoint
      child_rows <- which(as.logical(nodes(ct, child_ids[k])[[1]]$weights))
      # na.rm = TRUE: observations with a missing split variable are routed
      # by surrogate splits and must not turn the test into NA.
      goes_left <- all(data[[variable]][child_rows] <= as.numeric(threshold),
                       na.rm = TRUE)
      paste(variable, if (goes_left) "<=" else ">", as.character(threshold))
    }, character(1))

    paste(conditions, collapse = ", ")
  }

  paths <- vapply(terminal_ids, path_for, character(1), USE.NAMES = FALSE)
  data.frame(Node = as.character(terminal_ids), Path = paths,
             stringsAsFactors = FALSE)
}
Run Code Online (Sandbox Code Playgroud)
测试
# Demo: fit a conditional inference tree on the airquality data (rows with a
# missing Ozone value removed) and print the split path of each terminal node.
library(party)
airq <- airquality[!is.na(airquality$Ozone), ]
ct <- ctree(Ozone ~ ., data = airq,
            controls = ctree_control(maxsurrogate = 3))
Result <- CtreePathFunc(ct, airq)
print(Result)
## Node Path
## 1 5 Temp <= 82, Wind > 6.9, Temp <= 77
## 2 3 Temp <= 82, Wind <= 6.9
## 3 6 Temp <= 82, Wind > 6.9, Temp > 77
## 4 9 Temp > 82, Wind > 10.3
## 5 8 Temp > 82, Wind <= 10.3
Run Code Online (Sandbox Code Playgroud)
如果您使用新推荐的 partykit 包中的 ctree() 实现,而不是旧的 party 包,那么可以使用函数 .list.rules.party()。该函数尚未正式导出(因此需要通过 partykit::: 访问),但可以用它来提取所需的信息。
# Demo: the same tree fitted with partykit; the unexported helper
# .list.rules.party() returns one rule string per terminal node.
library(partykit)
airq <- airquality[!is.na(airquality$Ozone), ]
ct <- ctree(Ozone ~ ., data = airq)
partykit:::.list.rules.party(ct)
## 3 5
## "Temp <= 82 & Wind <= 6.9" "Temp <= 82 & Wind > 6.9 & Temp <= 77"
## 6 8
## "Temp <= 82 & Wind > 6.9 & Temp > 77" "Temp > 82 & Wind <= 10.3"
## 9
## "Temp > 82 & Wind > 10.3"
Run Code Online (Sandbox Code Playgroud)
由于我需要这个函数来处理分类数据,我或多或少地回答了 @JoãoDaniel 的问题(我只测试了分类预测变量),函数如下:
# Returns a string without leading or trailing whitespace.
# http://stackoverflow.com/questions/2261079/how-to-trim-leading-and-trailing-whitespace-in-r
trim <- function (x) gsub("^\\s+|\\s+$", "", x)

# Variable name of a rule: the text before the first whitespace.
getVariable <- function (x) sub("(.*?)[[:space:]].*", "\\1", x)

# Operator of a rule: the second whitespace-delimited token.
getSimbolo <- function (x) sub("(.*?)[[:space:]](.*?)[[:space:]].*", "\\2", x)

# Simplifies a "; "-separated rule path. For every (variable, operator) pair
# only its last rule along the path is kept; the kept rules are emitted in the
# order in which their pair first appeared. An empty path yields "".
getReglaFinal <- function(elemento) {
  reglas <- trim(strsplit(as.character(elemento), ";")[[1]])
  tipo_corte <- paste0(getVariable(reglas), getSimbolo(reglas))
  cortes <- unique(tipo_corte)
  ultimos <- vapply(cortes, function(corte) max(which(tipo_corte == corte)),
                    integer(1), USE.NAMES = FALSE)
  paste(reglas[ultimos], collapse = "; ")
}#getReglaFinal

# Lists, for every terminal node of a party ctree, the split path leading to
# it, for numeric, ordered and unordered predictors.
#
# ct: a tree fitted with party::ctree().
#
# Returns a list with
#   PreTable: Node, Path (full path), Means, Counts and TruePath (path
#             simplified by getReglaFinal), ordered by Means;
#   Table:    same rows, with Path replaced by the simplified path;
#   SQL:      a single "CASE WHEN ... THEN 'Nodo <id>' ... END" string.
CtreePathFuncAllCat <- function (ct) {

  NodeOfObs <- where(ct)
  TerminalNodes <- unique(NodeOfObs)
  Paths <- character(length(TerminalNodes))

  for (t in seq_along(TerminalNodes)) {
    Node <- TerminalNodes[t]

    # All inner nodes with a smaller id than the selected terminal node.
    # seq_len() keeps this empty for a root-only tree.
    NonTerminalNodes <- setdiff(seq_len(Node - 1),
                                TerminalNodes[TerminalNodes < Node])

    # Observation weights of the terminal node
    NodeWeights <- nodes(ct, Node)[[1]]$weights

    # Inner nodes sharing observations with the terminal node form its path
    Path <- NULL
    for (i in NonTerminalNodes) {
      if (any(NodeWeights & nodes(ct, i)[[1]][2][[1]] == 1)) {
        Path <- append(Path, i)
      }
    }

    # Splitting criteria along the path
    Path2 <- SB <- NULL
    variablesNombres <- character(0)
    variablesPuntos <- list()

    for (i in seq_along(Path)) {
      n <- nodes(ct, Path[i])[[1]]
      nextNodeID <- if (i == length(Path)) Node else Path[i + 1]

      vec_puntos <- as.vector(n[[5]]$splitpoint)
      vec_nombre <- n[[5]]$variableName
      vec_niveles <- attr(n[[5]]$splitpoint, "levels")

      # Classify the split explicitly instead of using `index == 0` as a
      # sentinel, which mistook a numeric split at exactly 0 for an
      # unordered factor.
      has_levels <- length(vec_niveles) != 0
      is_unordered <- has_levels && length(vec_puntos) == length(vec_niveles)

      if (has_levels && !is_unordered) {
        # Ordered factor: the split point is a level index; turn it into a
        # logical mask over the levels.
        level_index <- vec_puntos
        vec_puntos <- vector(length = length(vec_niveles))
        vec_puntos[level_index] <- TRUE
      }
      if (!has_levels) {
        # Numeric predictor: keep the raw split value.
        vec_puntos <- n[[5]]$splitpoint
      }

      if (is_unordered) {
        # The mask marks the levels sent left; negate it for the right child.
        if (nextNodeID == n$right$nodeID) {
          vec_puntos <- !vec_puntos
        } else {
          vec_puntos <- !!vec_puntos
        }
        # Intersect with earlier splits on the same variable along the path.
        for (j in seq_len(i - 1)) {
          if (variablesNombres[j] == vec_nombre) {
            vec_puntos <- vec_puntos * variablesPuntos[[j]]
          }
        }
        vec_puntos <- vec_puntos == 1
        SB <- "="
      } else {
        SB <- if (nextNodeID == n$right$nodeID) ">" else "<="
      }

      variablesPuntos[[i]] <- vec_puntos
      variablesNombres[i] <- vec_nombre

      if (!has_levels) {
        descripcion <- vec_puntos
      } else {
        descripcion <- paste(vec_niveles[vec_puntos], collapse = ", ")
      }
      Path2 <- paste(c(Path2,
                       paste(c(variablesNombres[i], SB, "{", descripcion, "}"),
                             collapse = " ")),
                     collapse = "; ")
    }

    # A tree without splits has an empty path
    Paths[t] <- if (is.null(Path2)) "" else Path2
  }

  # Node ids are kept as character so that the row-wise paste() below does
  # not pad them.
  ResulTable <- data.frame(Node = as.character(TerminalNodes), Path = Paths,
                           stringsAsFactors = FALSE)

  # Per-node count and prediction, read from the first observation of each
  # node. unique(Predict(ct)) would drop a row whenever two terminal nodes
  # share the same prediction.
  PrimeraObs <- match(TerminalNodes, NodeOfObs)
  we <- weights(ct)
  Counts <- vapply(we, function(w) sum(w), numeric(1))[PrimeraObs]
  Means <- unname(drop(Predict(ct))[PrimeraObs])

  ResulTable <- data.frame(ResulTable, Means = Means, Counts = unname(Counts))
  ResulTable <- ResulTable[order(ResulTable$Means), ]

  ResulTable$TruePath <- vapply(as.character(ResulTable$Path), getReglaFinal,
                                character(1), USE.NAMES = FALSE)

  ResulTable2 <- ResulTable

  # Rewrite "var op { a, b }" into "var op ('a','b')"; the outermost gsub
  # removes the quotes around purely numeric values.
  # NOTE(review): level names are pasted into the SQL text unescaped.
  ResulTable2$SQL <- paste("WHEN ", gsub("\\'([-+]?([0-9]*\\.[0-9]+|[0-9]+))\\'", "\\1", gsub("\\, ", "','", gsub(" \\}", "')", gsub("\\{ ", "('", gsub("\\;", " AND ", ResulTable2$TruePath))))), " THEN ")

  cols <- c('SQL', 'Node')
  ResulTable2$SQL <- apply(ResulTable2[, cols], 1, paste, collapse = "'Nodo ")

  ResulTable2$SQL <- gsub("THEN'", "THEN '", gsub(" '", "'", paste(ResulTable2$SQL, "'")))

  ResultadoFinal <- list()

  ResultadoFinal$PreTable <- ResulTable
  ResultadoFinal$Table <- ResulTable
  ResultadoFinal$Table$Path <- ResultadoFinal$Table$TruePath
  ResultadoFinal$Table$TruePath <- NULL
  ResultadoFinal$SQL <- paste(" CASE ", paste(ResulTable2$SQL, sep = "", collapse = " "), " END ", collapse = "")

  return(ResultadoFinal)
}#CtreePathFuncAllCat
# Page text following this listing: "Here is a test:"
# Demo of CtreePathFuncAllCat() on three trees; the `##` lines reproduce the
# console output of each call.
library(party)

# With ordered factors
TreeModel1 <- ctree(PB ~ ME + SYMPT + HIST + BSE + DECT, data = mammoexp)
Result2 <- CtreePathFuncAllCat(TreeModel1)
Result2
##$PreTable
## Node Path Means Counts
##3 7 DECT > { Somewhat likely }; SYMPT > { Disagree } 6.526316 114
##2 6 DECT > { Somewhat likely }; SYMPT <= { Disagree } 7.640000 175
##1 4 DECT <= { Somewhat likely }; DECT > { Not likely } 8.161905 105
##4 3 DECT <= { Somewhat likely }; DECT <= { Not likely } 9.833333 18
## TruePath
##3 DECT > { Somewhat likely }; SYMPT > { Disagree }
##2 DECT > { Somewhat likely }; SYMPT <= { Disagree }
##1 DECT <= { Somewhat likely }; DECT > { Not likely }
##4 DECT <= { Not likely }
##
##$Table
## Node Path Means Counts
##3 7 DECT > { Somewhat likely }; SYMPT > { Disagree } 6.526316 114
##2 6 DECT > { Somewhat likely }; SYMPT <= { Disagree } 7.640000 175
##1 4 DECT <= { Somewhat likely }; DECT > { Not likely } 8.161905 105
##4 3 DECT <= { Not likely } 9.833333 18
##
##$SQL
##[1] " CASE WHEN DECT > ('Somewhat likely') AND SYMPT > ('Disagree') THEN 'Nodo 7' WHEN DECT > ('Somewhat likely') AND SYMPT <= ('Disagree') THEN 'Nodo 6' WHEN DECT <= ('Somewhat likely') AND DECT > ('Not likely') THEN 'Nodo 4' WHEN DECT <= ('Not likely') THEN 'Nodo 3' END "


# With unordered factors
TreeModel2 <- ctree(count ~ spray, data = InsectSprays)
plot(TreeModel2, type = "simple")
Result2 <- CtreePathFuncAllCat(TreeModel2)
Result2
##$PreTable
##Node Path Means Counts TruePath
##2 5 spray = { C, D, E }; spray = { C, E } 2.791667 24 spray = { C, E }
##3 4 spray = { C, D, E }; spray = { D } 4.916667 12 spray = { D }
##1 2 spray = { A, B, F } 15.500000 36 spray = { A, B, F }
##
##$Table
##Node Path Means Counts
##2 5 spray = { C, E } 2.791667 24
##3 4 spray = { D } 4.916667 12
##1 2 spray = { A, B, F } 15.500000 36
##
##$SQL
##[1] " CASE WHEN spray = ('C','E') THEN 'Nodo 5' WHEN spray = ('D') THEN 'Nodo 4' WHEN spray = ('A','B','F') THEN 'Nodo 2' END "

# With continuous variables
airq <- airquality[!is.na(airquality$Ozone), ]
TreeModel3 <- ctree(Ozone ~ ., data = airq,
                    controls = ctree_control(maxsurrogate = 3))
Result2 <- CtreePathFuncAllCat(TreeModel3)
Result2
##$PreTable
## Node Path Means Counts
##1 5 Temp <= { 82 }; Wind > { 6.9 }; Temp <= { 77 } 18.47917 48
##3 6 Temp <= { 82 }; Wind > { 6.9 }; Temp > { 77 } 31.14286 21
##4 9 Temp > { 82 }; Wind > { 10.3 } 48.71429 7
##2 3 Temp <= { 82 }; Wind <= { 6.9 } 55.60000 10
##5 8 Temp > { 82 }; Wind <= { 10.3 } 81.63333 30
## TruePath
##1 Temp <= { 77 }; Wind > { 6.9 }
##3 Temp <= { 82 }; Wind > { 6.9 }; Temp > { 77 }
##4 Temp > { 82 }; Wind > { 10.3 }
##2 Temp <= { 82 }; Wind <= { 6.9 }
##5 Temp > { 82 }; Wind <= { 10.3 }
##
##$Table
## Node Path Means Counts
##1 5 Temp <= { 77 }; Wind > { 6.9 } 18.47917 48
##3 6 Temp <= { 82 }; Wind > { 6.9 }; Temp > { 77 } 31.14286 21
##4 9 Temp > { 82 }; Wind > { 10.3 } 48.71429 7
##2 3 Temp <= { 82 }; Wind <= { 6.9 } 55.60000 10
##5 8 Temp > { 82 }; Wind <= { 10.3 } 81.63333 30
##
##$SQL
##[1] " CASE WHEN Temp <= (77) AND Wind > (6.9) THEN 'Nodo 5' WHEN Temp <= (82) AND Wind > (6.9) AND Temp > (77) THEN 'Nodo 6' WHEN Temp > (82) AND Wind > (10.3) THEN 'Nodo 9' WHEN Temp <= (82) AND Wind <= (6.9) THEN 'Nodo 3' WHEN Temp > (82) AND Wind <= (10.3) THEN 'Nodo 8' END "
# Page text following this listing: "Update! The function now supports a mix
# of categorical and numeric variables!"
\n