Ore*_*tis 3 · Tags: r, machine-learning, random-forest, tidymodels
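I have the following code, in which I run a grid search over different values of mtry and min_n. I know how to extract the parameters that give the highest accuracy (see the second code block). How can I extract the importance of each feature in the training dataset? The guides I found online only show how to do this on the test dataset, using last_fit. For example, this guide: https://www.tidymodels.org/start/case-study/#data-split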
library(tidymodels)
set.seed(seed_number)
# 80/20 train/test split, stratified by the outcome
data_split <- initial_split(node_strength, prop = 0.8, strata = Group)
train <- training(data_split)
test <- testing(data_split)
# 10-fold cross-validation on the training set only
train_folds <- vfold_cv(train, v = 10)
# Random forest with mtry and min_n marked for tuning
rfc <- rand_forest(mode = "classification", mtry = tune(),
                   min_n = tune(), trees = 1500) %>%
  set_engine("ranger", num.threads = 48, importance = "impurity")
rfc_recipe <- recipe(Group ~ ., data = train)
rfc_workflow <- workflow() %>%
  add_model(rfc) %>%
  add_recipe(rfc_recipe)
# Grid search over 40 candidate combinations, scored by accuracy
rfc_result <- rfc_workflow %>%
  tune_grid(train_folds, grid = 40,
            control = control_grid(save_pred = TRUE),
            metrics = metric_set(accuracy))
best <-
  rfc_result %>%
  select_best(metric = "accuracy")
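To do this, you need to create a custom extract function, as described in this documentation.
For random forest variable importance, your function will look like this: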
get_rf_imp <- function(x) {
  # x is a fitted workflow: pull out the parsnip model and score variable importance with vip
  x %>%
    extract_fit_parsnip() %>%
    vip::vi()
}
Then you can apply it to your resamples like this (note that you get a new .extracts column):
library(tidymodels)
data(cells, package = "modeldata")

set.seed(123)
cell_split <- cells %>% select(-case) %>%
  initial_split(strata = class)
cell_train <- training(cell_split)
cell_test <- testing(cell_split)
folds <- vfold_cv(cell_train)

rf_spec <- rand_forest(mode = "classification") %>%
  set_engine("ranger", importance = "impurity")

ctrl_imp <- control_grid(extract = get_rf_imp)

cells_res <-
  workflow(class ~ ., rf_spec) %>%
  fit_resamples(folds, control = ctrl_imp)
cells_res
#> # Resampling results
#> # 10-fold cross-validation
#> # A tibble: 10 × 5
#>    splits             id     .metrics         .notes           .extracts
#>    <list>             <chr>  <list>           <list>           <list>
#>  1 <split [1362/152]> Fold01 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
#>  2 <split [1362/152]> Fold02 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
#>  3 <split [1362/152]> Fold03 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
#>  4 <split [1362/152]> Fold04 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
#>  5 <split [1363/151]> Fold05 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
#>  6 <split [1363/151]> Fold06 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
#>  7 <split [1363/151]> Fold07 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
#>  8 <split [1363/151]> Fold08 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
#>  9 <split [1363/151]> Fold09 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
#> 10 <split [1363/151]> Fold10 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
Created on 2022-06-19 by the reprex package (v2.0.1)
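Once you have these extracted variable importance scores, you can unnest() them (for now, you have to do this twice, because they are deeply nested) and then summarize and visualize them however you like: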
cells_res %>%
  select(id, .extracts) %>%
  unnest(.extracts) %>%
  unnest(.extracts) %>%
  group_by(Variable) %>%
  # Mean and standard deviation of each variable's importance across folds
  summarise(Mean = mean(Importance),
            SD = sd(Importance)) %>%
  slice_max(Mean, n = 15) %>%
  ggplot(aes(Mean, reorder(Variable, Mean))) +
  geom_crossbar(aes(xmin = Mean - SD, xmax = Mean + SD)) +
  labs(x = "Variable importance", y = NULL)
Created on 2022-06-19 by the reprex package (v2.0.1)
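The reprex above uses fit_resamples() with default hyperparameters, whereas the question tunes mtry and min_n with tune_grid(). As far as I know the same extract function can be passed to control_grid() there too; with tune_grid() the first level of .extracts also carries the mtry and min_n columns, so you can keep only the candidate chosen by select_best() before summarising. This is a minimal sketch on the cells data, assuming deliberately small settings (5 folds, 4 grid candidates, 200 trees) just to keep it fast; tune_res and best_imp are hypothetical names, and with the question's objects you would use rfc_workflow, train_folds and best instead.

```r
library(tidymodels)
data(cells, package = "modeldata")

# Same extractor as above: pull the parsnip fit out of the workflow, then score importance
get_rf_imp <- function(x) {
  x %>%
    extract_fit_parsnip() %>%
    vip::vi()
}

set.seed(123)
cell_split <- initial_split(select(cells, -case), strata = class)
cell_train <- training(cell_split)
folds <- vfold_cv(cell_train, v = 5)

# Assumed small settings (5 folds, 4 candidates, 200 trees) only to keep the sketch quick
rf_tune_spec <- rand_forest(mode = "classification", mtry = tune(),
                            min_n = tune(), trees = 200) %>%
  set_engine("ranger", importance = "impurity")

# extract = get_rf_imp is run on every fitted workflow (each fold x each candidate)
tune_res <- workflow(class ~ ., rf_tune_spec) %>%
  tune_grid(folds, grid = 4, metrics = metric_set(accuracy),
            control = control_grid(extract = get_rf_imp))

best <- select_best(tune_res, metric = "accuracy")

best_imp <- tune_res %>%
  select(id, .extracts) %>%
  unnest(.extracts) %>%                          # one row per fold x candidate
  semi_join(best, by = c("mtry", "min_n")) %>%   # keep only the winning candidate
  unnest(.extracts) %>%                          # one row per fold x variable
  group_by(Variable) %>%
  summarise(Mean = mean(Importance), SD = sd(Importance)) %>%
  arrange(desc(Mean))
best_imp
```

semi_join() is used rather than a regular join so that the .config column is not duplicated; the result has one row per predictor, averaged over the folds for the selected mtry/min_n combination only.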
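If you instead want a single set of importance scores from one model fit on the whole training set (rather than one set per fold), one alternative to the resampling approach above is to finalize the workflow with the selected hyperparameters and fit it on the training data directly. This is a sketch under assumptions: best_params holds hypothetical values standing in for the output of select_best(), and with the question's objects you would call finalize_workflow(rfc_workflow, best) and fit(data = train) instead.

```r
library(tidymodels)
data(cells, package = "modeldata")

set.seed(123)
cell_split <- initial_split(select(cells, -case), strata = class)
cell_train <- training(cell_split)

rf_spec <- rand_forest(mode = "classification", mtry = tune(),
                       min_n = tune(), trees = 500) %>%
  set_engine("ranger", importance = "impurity")

# Hypothetical stand-in for select_best() output; replace with `best` from tune_grid()
best_params <- tibble(mtry = 5, min_n = 10)

# Plug the chosen values into the workflow and fit once on the full training set
final_fit <- workflow(class ~ ., rf_spec) %>%
  finalize_workflow(best_params) %>%
  fit(data = cell_train)

# One impurity-based importance score per predictor, from the single final model
train_imp <- final_fit %>%
  extract_fit_parsnip() %>%
  vip::vi()
train_imp
```

Unlike the per-fold summary, this gives no spread across resamples, so the cross-validated version is the better choice if you want an indication of how stable each variable's importance is.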