测试由Rpart包生成的规则

nan*_*nue 7 r machine-learning rpart

我想以编程方式测试从树生成的一个规则.在树中,根和叶子(终端节点)之间的路径可以被解释为规则.

在R中,我们可以使用该rpart包并执行以下操作:(在本文中,我将使用iris数据集,仅用于示例目的)

library(rpart)
model <- rpart(Species ~ ., data=iris)
Run Code Online (Sandbox Code Playgroud)

通过这两行,我得到了一个名为的树model,其类是rpart.object(rpart文档,第21页).这个对象有很多信息,并且支持多种方法.特别是,对象有一个frame变量(可以用标准的方式访问: model$frame)(idem)和方法path.rpath(rpart文档,第7页),它给你从根节点到感兴趣的节点的路径(node参数在功能)

row.names该的frame变量包含了树的节点编号.该var列给出了节点中的split变量,yval拟合值和yval2类概率以及其他信息.

> model$frame
           var   n  wt dev yval complexity ncompete nsurrogate     yval2.1     yval2.2     yval2.3     yval2.4     yval2.5     yval2.6     yval2.7
1 Petal.Length 150 150 100    1       0.50        3          3  1.00000000 50.00000000 50.00000000 50.00000000  0.33333333  0.33333333  0.33333333
2       <leaf>  50  50   0    1       0.01        0          0  1.00000000 50.00000000  0.00000000  0.00000000  1.00000000  0.00000000  0.00000000
3  Petal.Width 100 100  50    2       0.44        3          3  2.00000000  0.00000000 50.00000000 50.00000000  0.00000000  0.50000000  0.50000000
6       <leaf>  54  54   5    2       0.00        0          0  2.00000000  0.00000000 49.00000000  5.00000000  0.00000000  0.90740741  0.09259259
7       <leaf>  46  46   1    3       0.01        0          0  3.00000000  0.00000000  1.00000000 45.00000000  0.00000000  0.02173913  0.97826087
Run Code Online (Sandbox Code Playgroud)

但只有被标记为<leaf>var列是终端节点(叶子).在这种情况下,节点是2,6和7.

如上所述,您可以使用该path.rpart方法提取规则(此方法用于 rattle包和文章Sharma Credit Score中,如下所示:

通常,模型保留预测值的值

predicted.levels <- attr(model, "ylevels")
Run Code Online (Sandbox Code Playgroud)

该值对应yvalmodel$frame数据集中的列.

对于节点号为7(行号为5)的叶子,预测值为

> ylevels[model$frame[5, ]$yval]
[1] "virginica"
Run Code Online (Sandbox Code Playgroud)

而规则是

> rule <- path.rpart(model, nodes = 7)

 node number: 7 
   root
   Petal.Length>=2.45
   Petal.Width>=1.75
Run Code Online (Sandbox Code Playgroud)

因此,该规则可以理解为

If Petal.Length >= 2.45 AND Petal.Width >= 1.75 THEN Species = Virginica
Run Code Online (Sandbox Code Playgroud)

我知道我可以测试(在测试数据集中,我将再次使用虹膜数据集)我对此规则有多少真正的正面,对新数据集进行子集化如下

> hits <- subset(iris, Petal.Length >= 2.45 & Petal.Width >= 1.75)
Run Code Online (Sandbox Code Playgroud)

然后计算混淆矩阵

> table(hits$Species, hits$Species == "virginica")

             FALSE TRUE
  setosa         0    0
  versicolor     1    0
  virginica      0   45
Run Code Online (Sandbox Code Playgroud)

(注意:我使用相同的虹膜数据集进行测试)

我如何以编程方式评估规则?我可以从规则中提取条件如下

> unlist(rule, use.names = FALSE)[-1]
[1] "Petal.Length>=2.45" "Petal.Width>=1.75" 
Run Code Online (Sandbox Code Playgroud)

但是,我怎么能从这里继续?我无法使用该subset功能

提前致谢

注意: 此问题已经过大量编辑,以便更清晰

nan*_*nue 3

我可以通过以下方式解决这个问题

免责声明:显然必须是解决这个问题的更好方法,但是这个黑客有效并且做我想做的事......(我对此不是很自豪......是黑客,但有效)

好的,我们开始吧。基本上这个想法是使用包sqldf

如果您检查问题,最后一段代码会将树路径的每一部分放入一个列表中。所以,我将从那里开始

        library(sqldf)
        library(stringr)

        # Transform to a character vector
        rule.v <- unlist(rule, use.names=FALSE)[-1]
        # Remove all the dots, sqldf doesn't handles dots in names 
        rule.v <- str_replace_all(rule.v, pattern="([a-zA-Z])\\.([a-zA-Z])", replacement="\\1_\\2")
        # We have to remove all the equal signs to 'in ('
        rule.v <- str_replace_all(rule.v, pattern="([a-zA-Z0-9])=", replacement="\\1 in ('")
        # Embrace all the elements in the lists of values with " ' " 
        # The last element couldn't be modified in this way (Any ideas?) 
        rule.v <- str_replace_all(rule.v, pattern=",", replacement="','")

        # Close the last element with apostrophe and a ")" 
        for (i in which(!is.na(str_extract(pattern="in", string=rule.v)))) {
          rule.v[i] <- paste(append(rule.v[i], "')"), collapse="")
        }

        # Collapse all the list in one string joined by " AND "
        rule.v <- paste(rule.v, collapse = " AND ")

        # Generate the query
        # Use any metric that you can get from the data frame
        query <- paste("SELECT Species, count(Species) FROM iris WHERE ", rule.v, " group by Species", sep="")

        # For debug only...
        print(query)

        # Execute and print the results
        print(sqldf(query))
Run Code Online (Sandbox Code Playgroud)

就这样!

我警告过你,这是黑客行为......

希望这对其他人有帮助......

感谢所有的帮助和建议!