Tidymodels: How to extract variable importance from the training data

Ore*_*tis 3 r machine-learning random-forest tidymodels

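I have the following code, where I run a grid search over different values of mtry and min_n. I know how to extract the parameters that give the highest accuracy (see the second code block). How can I extract the importance of each feature in the training dataset? The guides I found online only show how to do this on the test dataset using "last_fit". Example guide: https://www.tidymodels.org/start/case-study/#data-split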

    library(tidymodels)

    set.seed(seed_number)
    data_split <- initial_split(node_strength, prop = 0.8, strata = Group)

    train <- training(data_split)
    test <- testing(data_split)
    train_folds <- vfold_cv(train, v = 10)

    # Random forest with mtry and min_n marked for tuning
    rfc <- rand_forest(mode = "classification", mtry = tune(),
                       min_n = tune(), trees = 1500) %>%
        set_engine("ranger", num.threads = 48, importance = "impurity")

    rfc_recipe <- recipe(Group ~ ., data = train)

    rfc_workflow <- workflow() %>%
        add_model(rfc) %>%
        add_recipe(rfc_recipe)

    # Evaluate 40 candidate parameter combinations on the 10 folds
    rfc_result <- rfc_workflow %>%
        tune_grid(train_folds, grid = 40, control = control_grid(save_pred = TRUE),
                  metrics = metric_set(accuracy))

    best <- rfc_result %>%
        select_best(metric = "accuracy")

Jul*_*lge 8

To do this, you need to create a custom extract function, as described in the tidymodels documentation.

For random forest variable importance, your function will look like this:

    get_rf_imp <- function(x) {
        x %>%
            extract_fit_parsnip() %>%
            vip::vi()
    }

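To see what such an extract function returns, it can help to call it on a single fitted workflow first. The following is a minimal sketch, assuming the ranger and vip packages are installed; the `iris` model, the `trees = 100` setting, and the names `rf_spec_toy`, `rf_fit`, and `imp` are illustrative only and not part of the original answer.

```r
library(tidymodels)

# Same extract function as above, repeated so this sketch runs on its own
get_rf_imp <- function(x) {
    x %>%
        extract_fit_parsnip() %>%
        vip::vi()
}

set.seed(123)

# Hypothetical toy model; ranger only stores importance scores if asked to
rf_spec_toy <- rand_forest(mode = "classification", trees = 100) %>%
    set_engine("ranger", importance = "impurity")

rf_fit <- workflow(Species ~ ., rf_spec_toy) %>%
    fit(data = iris)

# A tibble with one row per predictor and columns Variable and Importance
imp <- get_rf_imp(rf_fit)
imp
```

During resampling, the function is handed one fitted workflow per resample in the same way, which is why it starts with `extract_fit_parsnip()`.
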
You can then apply it to your resamples like this (notice that you get a new .extracts column):

    library(tidymodels)
    data(cells, package = "modeldata")

    set.seed(123)
    cell_split <- cells %>% select(-case) %>%
        initial_split(strata = class)
    cell_train <- training(cell_split)
    cell_test  <- testing(cell_split)
    folds <- vfold_cv(cell_train)

    rf_spec <- rand_forest(mode = "classification") %>%
        set_engine("ranger", importance = "impurity")

    ctrl_imp <- control_grid(extract = get_rf_imp)

    cells_res <-
        workflow(class ~ ., rf_spec) %>%
        fit_resamples(folds, control = ctrl_imp)
    cells_res
    #> # Resampling results
    #> # 10-fold cross-validation
    #> # A tibble: 10 × 5
    #>    splits             id     .metrics         .notes           .extracts
    #>    <list>             <chr>  <list>           <list>           <list>
    #>  1 <split [1362/152]> Fold01 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
    #>  2 <split [1362/152]> Fold02 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
    #>  3 <split [1362/152]> Fold03 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
    #>  4 <split [1362/152]> Fold04 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
    #>  5 <split [1363/151]> Fold05 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
    #>  6 <split [1363/151]> Fold06 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
    #>  7 <split [1363/151]> Fold07 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
    #>  8 <split [1363/151]> Fold08 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
    #>  9 <split [1363/151]> Fold09 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
    #> 10 <split [1363/151]> Fold10 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>

Created on 2022-06-19 by the reprex package (v2.0.1)

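Once you have those variable importance score extracts, you can unnest() them (right now you have to do this twice, because the column is deeply nested) and then summarize and visualize them however you prefer: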

    cells_res %>%
        select(id, .extracts) %>%
        unnest(.extracts) %>%
        unnest(.extracts) %>%
        group_by(Variable) %>%
        # Note: despite its name, the Variance column holds the standard deviation
        summarise(Mean = mean(Importance),
                  Variance = sd(Importance)) %>%
        slice_max(Mean, n = 15) %>%
        ggplot(aes(Mean, reorder(Variable, Mean))) +
        geom_crossbar(aes(xmin = Mean - Variance, xmax = Mean + Variance)) +
        labs(x = "Variable importance", y = NULL)

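The question uses tune_grid() rather than fit_resamples(). The same extract function can be passed there as well; the sketch below is one way this might look, not part of the original answer. It assumes that, with tune_grid(), each element of .extracts carries the tuning parameter columns plus a .config identifier, so the extracts can be restricted to the candidate returned by select_best(). The smaller settings (`v = 5`, `grid = 5`, `trees = 500`) and the names `rf_tune_spec`, `rf_wflow`, `tune_res`, `best_imp`, and `final_imp` are illustrative choices.

```r
library(tidymodels)
data(cells, package = "modeldata")

# Same extract function as above, repeated so this sketch runs on its own
get_rf_imp <- function(x) {
    x %>%
        extract_fit_parsnip() %>%
        vip::vi()
}

set.seed(123)
cell_train <- cells %>%
    select(-case) %>%
    initial_split(strata = class) %>%
    training()
folds <- vfold_cv(cell_train, v = 5)

rf_tune_spec <- rand_forest(mode = "classification", mtry = tune(),
                            min_n = tune(), trees = 500) %>%
    set_engine("ranger", importance = "impurity")

rf_wflow <- workflow(class ~ ., rf_tune_spec)

tune_res <- rf_wflow %>%
    tune_grid(folds, grid = 5,
              control = control_grid(extract = get_rf_imp),
              metrics = metric_set(accuracy))

best <- select_best(tune_res, metric = "accuracy")

# Importance across folds, restricted to the best candidate
best_imp <- tune_res %>%
    select(id, .extracts) %>%
    unnest(.extracts) %>%                # one row per fold and candidate
    semi_join(best, by = ".config") %>%  # keep only the best candidate
    unnest(.extracts) %>%                # one row per fold and variable
    group_by(Variable) %>%
    summarise(Mean = mean(Importance), SD = sd(Importance)) %>%
    arrange(desc(Mean))

# Alternative: refit the best candidate once on the whole training set
final_imp <- rf_wflow %>%
    finalize_workflow(best) %>%
    fit(data = cell_train) %>%
    get_rf_imp()
```

`best_imp` averages the scores over the analysis sets of the resamples, while `final_imp` comes from a single fit on all of the training data. Neither touches the test set.
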