使用R包插入符号,如何根据train()函数的交叉验证结果生成ROC曲线?
说,我做以下事情:
data(Sonar)
ctrl <- trainControl(method="cv",
summaryFunction=twoClassSummary,
classProbs=T)
rfFit <- train(Class ~ ., data=Sonar,
method="rf", preProc=c("center", "scale"),
trControl=ctrl)
Run Code Online (Sandbox Code Playgroud)
训练函数遍历一系列mtry参数并计算ROC AUC.我想看看相关的ROC曲线 - 我该怎么做?
注意:如果用于采样的方法是LOOCV,那么rfFit
将在rfFit$pred
插槽中包含一个非空数据帧,这似乎正是我所需要的.但是,我需要用于"cv"方法(k-fold验证)而不是LOO.
另外:不,roc
曾经包含在以前版本的插入符中的函数不是答案 - 这是一个低级函数,如果您没有每个交叉验证样本的预测概率,则不能使用它.
thi*_*e1e 36
只有savePredictions = TRUE
参数缺失ctrl
(这也适用于其他重采样方法):
library(caret)
library(mlbench)
data(Sonar)
ctrl <- trainControl(method="cv",
summaryFunction=twoClassSummary,
classProbs=T,
savePredictions = T)
rfFit <- train(Class ~ ., data=Sonar,
method="rf", preProc=c("center", "scale"),
trControl=ctrl)
library(pROC)
# Select a parameter setting
selectedIndices <- rfFit$pred$mtry == 2
# Plot:
plot.roc(rfFit$pred$obs[selectedIndices],
rfFit$pred$M[selectedIndices])
Run Code Online (Sandbox Code Playgroud)
也许我遗漏了一些东西,但一个小问题是,train
总是估计AUC值略微不同plot.roc
而且pROC::auc
(绝对差<0.005),尽管twoClassSummary
用于pROC::auc
估算AUC.编辑:我认为这是因为ROC来自train
使用单独的CV集的AUC的平均值,这里我们同时计算所有重采样的AUC以获得总AUC.
更新由于这引起了一些关注,这里有一个解决方案plotROC::geom_roc()
用于ggplot2
:
library(ggplot2)
library(plotROC)
ggplot(rfFit$pred[selectedIndices, ],
aes(m = M, d = factor(obs, levels = c("R", "M")))) +
geom_roc(hjust = -0.4, vjust = 1.5) + coord_equal()
Run Code Online (Sandbox Code Playgroud)
Meg*_*ron 14
在这里,我正在修改@ thei1e的情节,其他人可能会觉得有帮助.
训练模型并做出预测
library(caret)
library(ggplot2)
library(mlbench)
library(plotROC)
data(Sonar)
ctrl <- trainControl(method="cv", summaryFunction=twoClassSummary, classProbs=T,
savePredictions = T)
rfFit <- train(Class ~ ., data=Sonar, method="rf", preProc=c("center", "scale"),
trControl=ctrl)
# Select a parameter setting
selectedIndices <- rfFit$pred$mtry == 2
Run Code Online (Sandbox Code Playgroud)
更新了ROC曲线图
g <- ggplot(rfFit$pred[selectedIndices, ], aes(m=M, d=factor(obs, levels = c("R", "M")))) +
geom_roc(n.cuts=0) +
coord_equal() +
style_roc()
g + annotate("text", x=0.75, y=0.25, label=paste("AUC =", round((calc_auc(g))$AUC, 4)))
Run Code Online (Sandbox Code Playgroud)