Adr*_*baz 10 r formula predict random-forest r-caret
在64位Linux机器上使用R 3.2.0 with caret 6.0-41和randomForest 4.6-10.
当尝试使用公式对使用包中的函数训练predict()的randomForest对象使用该方法时,该函数返回错误.当通过训练和/或使用和而不是公式,这一切都顺利进行.train()caretrandomForest()x=y=
这是一个工作示例:
library(randomForest)
library(caret)
data(imports85)
imp85 <- imports85[, c("stroke", "price", "fuelType", "numOfDoors")]
imp85 <- imp85[complete.cases(imp85), ]
imp85[] <- lapply(imp85, function(x) if (is.factor(x)) x[,drop=TRUE] else x) ## Drop empty levels for factors.
modRf1 <- randomForest(numOfDoors~., data=imp85)
caretRf <- train( numOfDoors~., data=imp85, method = "rf" )
modRf2 <- caretRf$finalModel
modRf3 <- randomForest(x=imp85[,c("stroke", "price", "fuelType")], y=imp85[, "numOfDoors"])
caretRf <- train(x=imp85[,c("stroke", "price", "fuelType")], y=imp85[, "numOfDoors"], method = "rf")
modRf4 <- caretRf$finalModel
p1 <- predict(modRf1, newdata=imp85)
p2 <- predict(modRf2, newdata=imp85)
p3 <- predict(modRf3, newdata=imp85)
p4 <- predict(modRf4, newdata=imp85)
Run Code Online (Sandbox Code Playgroud)
在最后4行中,只有第二行p2 <- predict(modRf2, newdata=imp85)返回以下错误:
Error in predict.randomForest(modRf2, newdata = imp85) :
variables in the training data missing in newdata
Run Code Online (Sandbox Code Playgroud)
似乎这个错误的原因是该predict.randomForest方法用于rownames(object$importance)确定用于训练随机森林的变量的名称object.当看着
rownames(modRf1$importance)
rownames(modRf2$importance)
rownames(modRf3$importance)
rownames(modRf4$importance)
Run Code Online (Sandbox Code Playgroud)
我们看:
[1] "stroke" "price" "fuelType"
[1] "stroke" "price" "fuelTypegas"
[1] "stroke" "price" "fuelType"
[1] "stroke" "price" "fuelType"
Run Code Online (Sandbox Code Playgroud)
所以,不知何故,当使用caret train()带有公式的函数时,会更改对象importance字段中(因子)变量的名称randomForest.
插入符train()函数的公式和非公式版本之间是否真的不一致?或者我错过了什么?
top*_*epo 29
首先,几乎从不使用该$finalModel对象进行预测.使用predict.train.这是一个很好的例子.
某些函数(包括randomForest和train)如何处理虚拟变量之间存在一些不一致.R中使用公式方法的大多数函数会将因子预测变量转换为虚拟变量,因为它们的模型需要数据的数字表示.对此的例外是基于树和规则的模型(可以分类为分类预测变量),朴素贝叶斯和其他一些模型.
所以randomForest会不会当您使用创建虚拟变量randomForest(y ~ ., data = dat),但train(和大多数人)会使用类似呼叫train(y ~ ., data = dat).
发生错误是因为fuelType是一个因素.创建的虚拟变量train不具有相同的名称,因此predict.randomForest无法找到它们.
使用非公式方法train将通过因子预测变量randomForest,一切都会起作用.
TL; DR
train如果您想要相同的级别或使用,请使用非公式方法predict.train
马克斯
| 归档时间: |
|
| 查看次数: |
12194 次 |
| 最近记录: |