在ctree(),partykit包中修改终端节点

Gin*_*kin 1 binary-tree r party

我有一个因变量来通过决策树进行分类.它由三类频率组成:738(19%),426(15%)和1800(66%).正如你想象的那样,预测的类别总是第三个,但树的目的是描述性的,所以它实际上并不重要.问题是,当通过ctree()功能(包partykit)绘制树时,终端节点显示直方图,其显示三个类的出现概率.我需要修改这个输出:我想获得终端节点中每个类相对于类的绝对频率的出现比例.例如,class1中738个参与者中的哪一个属于某个终端节点?每个终端节点将为组成因变量的所有三个类显示该值.

Bellow一个树的图,默认情况下报告终端节点中每个类的普遍性.

Ach*_*eis 6

您始终可以定义自己的面板功能,以绘制每个终端面板窗口的内容.如果您对grid图形有一点了解,并且看看当前终端面板功能是如何定义的,您将看到它是如何工作的.

一个应该做你想做的面板功能是node_terminal()partykit包中(旧party包的重新实现得到了很大改进).但是,因为ctree()不在每个终端节点中存储其预测,所以该node_terminal()功能目前无法开箱即用.我将尝试改进未来版本中的实现,以便可以促进这一点.下面是一个有点参与的例子,应该做你想要的,我希望.

首先,我们使用iris数据拟合分类树(对于一个简单的可重现的例子):

library("partykit")
(ct <- ctree(Species ~ ., data = iris))
## Model formula:
## Species ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width
## 
## Fitted party:
## [1] root
## |   [2] Petal.Length <= 1.9: setosa (n = 50, err = 0.0%)
## |   [3] Petal.Length > 1.9
## |   |   [4] Petal.Width <= 1.7
## |   |   |   [5] Petal.Length <= 4.8: versicolor (n = 46, err = 2.2%)
## |   |   |   [6] Petal.Length > 4.8: versicolor (n = 8, err = 50.0%)
## |   |   [7] Petal.Width > 1.7: virginica (n = 46, err = 2.2%)
## 
## Number of inner nodes:    3
## Number of terminal nodes: 4
Run Code Online (Sandbox Code Playgroud)

然后我们计算每个终端节点的预测概率表:

(pred <- aggregate(predict(ct, type = "prob"),
  list(predict(ct, type = "node")), FUN = mean))
##   Group.1 setosa versicolor  virginica
## 1       2      1 0.00000000 0.00000000
## 2       5      0 0.97826087 0.02173913
## 3       6      0 0.50000000 0.50000000
## 4       7      0 0.02173913 0.97826087
Run Code Online (Sandbox Code Playgroud)

然后是不那么明显的部分:我们希望将这些预测的概率包含在树本身的终端节点中.为此,我们将递归节点结构强制转换为平面列表,插入预测(适当格式化),并将列表转换回节点结构:

ct_node <- as.list(ct$node)
for(i in 1:nrow(pred)) {
  ct_node[[pred[i,1]]]$info$prediction <- paste(
    format(names(pred)[-1]),
    format(round(pred[i, -1], digits = 3), nsmall = 3)
  )
}
ct$node <- as.partynode(ct_node)
Run Code Online (Sandbox Code Playgroud)

然后,我们可以使用node_terminal面板功能轻松绘制树的图片并插入我们预先格式化的预测:

plot(ct, terminal_panel = node_terminal, tp_args = list(
  FUN = function(node) c("Predictions", node$prediction)))
Run Code Online (Sandbox Code Playgroud)

自定义树

编辑:在a list和a 之间来回强制party实际上已经在包中实现了...我只是忘了它;-)如果你这样做

st <- as.simpleparty(ct)
Run Code Online (Sandbox Code Playgroud)

然后,结果party在每个节点中具有关于预测等的更详细信息.例如,$distribution然后包含每个响应级别的绝对频率.这可以像以前一样轻松地格式化

pred <- function(i) {
  tab <- i$distribution
  tab <- round(prop.table(tab), 3)
  tab <- paste0(names(tab), ":", format(tab, nsmall = 3))
  c("Predictions", tab)
}
Run Code Online (Sandbox Code Playgroud)

这可以传递到node_terminal基本上创建上面的情节.如果希望所有终端节点显示在底行中drop = FALSE,drop = TRUE则可能需要更改为.

plot(st, terminal_panel = node_terminal, tp_args = list(FUN = pred))
Run Code Online (Sandbox Code Playgroud)