nan*_*nue 7 r machine-learning rpart
我想以编程方式测试从树生成的一个规则.在树中,根和叶子(终端节点)之间的路径可以被解释为规则.
在R中,我们可以使用该rpart包并执行以下操作:(在本文中,我将使用iris数据集,仅用于示例目的)
library(rpart)
model <- rpart(Species ~ ., data=iris)
Run Code Online (Sandbox Code Playgroud)
通过这两行,我得到了一个名为的树model,其类是rpart.object(rpart文档,第21页).这个对象有很多信息,并且支持多种方法.特别是,对象有一个frame变量(可以用标准的方式访问: model$frame)(idem)和方法path.rpath(rpart文档,第7页),它给你从根节点到感兴趣的节点的路径(node参数在功能)
在row.names该的frame变量包含了树的节点编号.该var列给出了节点中的split变量,yval拟合值和yval2类概率以及其他信息.
> model$frame
var n wt dev yval complexity ncompete nsurrogate yval2.1 yval2.2 yval2.3 yval2.4 yval2.5 yval2.6 yval2.7
1 Petal.Length 150 150 100 1 0.50 3 3 1.00000000 50.00000000 50.00000000 50.00000000 0.33333333 0.33333333 0.33333333
2 <leaf> 50 50 0 1 0.01 0 0 1.00000000 50.00000000 0.00000000 0.00000000 1.00000000 0.00000000 0.00000000
3 Petal.Width 100 100 50 2 0.44 3 3 2.00000000 0.00000000 50.00000000 50.00000000 0.00000000 0.50000000 0.50000000
6 <leaf> 54 54 5 2 0.00 0 0 2.00000000 0.00000000 49.00000000 5.00000000 0.00000000 0.90740741 0.09259259
7 <leaf> 46 46 1 3 0.01 0 0 3.00000000 0.00000000 1.00000000 45.00000000 0.00000000 0.02173913 0.97826087
Run Code Online (Sandbox Code Playgroud)
但只有被标记为<leaf>在var列是终端节点(叶子).在这种情况下,节点是2,6和7.
如上所述,您可以使用该path.rpart方法提取规则(此方法用于 rattle包和文章Sharma Credit Score中,如下所示:
通常,模型保留预测值的值
predicted.levels <- attr(model, "ylevels")
Run Code Online (Sandbox Code Playgroud)
该值对应yval于model$frame数据集中的列.
对于节点号为7(行号为5)的叶子,预测值为
> ylevels[model$frame[5, ]$yval]
[1] "virginica"
Run Code Online (Sandbox Code Playgroud)
而规则是
> rule <- path.rpart(model, nodes = 7)
node number: 7
root
Petal.Length>=2.45
Petal.Width>=1.75
Run Code Online (Sandbox Code Playgroud)
因此,该规则可以理解为
If Petal.Length >= 2.45 AND Petal.Width >= 1.75 THEN Species = Virginica
Run Code Online (Sandbox Code Playgroud)
我知道我可以测试(在测试数据集中,我将再次使用虹膜数据集)我对此规则有多少真正的正面,对新数据集进行子集化如下
> hits <- subset(iris, Petal.Length >= 2.45 & Petal.Width >= 1.75)
Run Code Online (Sandbox Code Playgroud)
然后计算混淆矩阵
> table(hits$Species, hits$Species == "virginica")
FALSE TRUE
setosa 0 0
versicolor 1 0
virginica 0 45
Run Code Online (Sandbox Code Playgroud)
(注意:我使用相同的虹膜数据集进行测试)
我如何以编程方式评估规则?我可以从规则中提取条件如下
> unlist(rule, use.names = FALSE)[-1]
[1] "Petal.Length>=2.45" "Petal.Width>=1.75"
Run Code Online (Sandbox Code Playgroud)
但是,我怎么能从这里继续?我无法使用该subset功能
提前致谢
注意: 此问题已经过大量编辑,以便更清晰
我可以通过以下方式解决这个问题
免责声明:显然必须是解决这个问题的更好方法,但是这个黑客有效并且做我想做的事......(我对此不是很自豪......是黑客,但有效)
好的,我们开始吧。基本上这个想法是使用包sqldf
如果您检查问题,最后一段代码会将树路径的每一部分放入一个列表中。所以,我将从那里开始
library(sqldf)
library(stringr)
# Transform to a character vector
rule.v <- unlist(rule, use.names=FALSE)[-1]
# Remove all the dots, sqldf doesn't handles dots in names
rule.v <- str_replace_all(rule.v, pattern="([a-zA-Z])\\.([a-zA-Z])", replacement="\\1_\\2")
# We have to remove all the equal signs to 'in ('
rule.v <- str_replace_all(rule.v, pattern="([a-zA-Z0-9])=", replacement="\\1 in ('")
# Embrace all the elements in the lists of values with " ' "
# The last element couldn't be modified in this way (Any ideas?)
rule.v <- str_replace_all(rule.v, pattern=",", replacement="','")
# Close the last element with apostrophe and a ")"
for (i in which(!is.na(str_extract(pattern="in", string=rule.v)))) {
rule.v[i] <- paste(append(rule.v[i], "')"), collapse="")
}
# Collapse all the list in one string joined by " AND "
rule.v <- paste(rule.v, collapse = " AND ")
# Generate the query
# Use any metric that you can get from the data frame
query <- paste("SELECT Species, count(Species) FROM iris WHERE ", rule.v, " group by Species", sep="")
# For debug only...
print(query)
# Execute and print the results
print(sqldf(query))
Run Code Online (Sandbox Code Playgroud)
就这样!
我警告过你,这是黑客行为......
希望这对其他人有帮助......
感谢所有的帮助和建议!