Tidymodels - Getting predictions and metrics on the training data with a workflow/recipe

Asked by Ind*_*led (score 4). Tags: r, tidymodels

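The code below works fine and has no errors that I know of, but I would like to add more to it.

The two things I want to add are:

1 - The model's predictions on the training data, for the final plot. I want to run collect_predictions() on the model fitted to the training data.

2 - Code to see the model's metrics on the training data. I want to run collect_metrics() on the model fitted to the training data.

How can I get this information?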

# Setup
library(tidyverse)
library(tidymodels)

parks <- readr::read_csv('https://raw.githubusercontent.com/rfordatascience/tidytuesday/master/data/2021/2021-06-22/parks.csv')

modeling_df <- parks %>% 
  select(pct_near_park_data, spend_per_resident_data, med_park_size_data) %>% 
  rename(nearness = "pct_near_park_data",
         spending = "spend_per_resident_data",
         acres = "med_park_size_data") %>% 
  mutate(nearness = (parse_number(nearness)/100)) %>% 
  mutate(spending = parse_number(spending))

# Start building models
set.seed(123)
park_split <- initial_split(modeling_df)
park_train <- training(park_split)
park_test <- testing(park_split)

tree_rec <- recipe(nearness ~., data = park_train)
tree_prep <- prep(tree_rec)
juiced <- juice(tree_prep)

tune_spec <- rand_forest(
  mtry = tune(),
  trees = 1000,
  min_n = tune()
) %>% 
  set_mode("regression") %>% 
  set_engine("ranger")

tune_wf <- workflow() %>% 
  add_recipe(tree_rec) %>% 
  add_model(tune_spec)

set.seed(234)
park_folds <- vfold_cv(park_train)

# Make a grid of various different models
doParallel::registerDoParallel()

set.seed(345)
tune_res <- tune_grid(
  tune_wf,
  resamples = park_folds,
  grid = 20,
  control = control_grid(verbose = TRUE)
)

best_rmse <- select_best(tune_res, metric = "rmse")

# Finalize a model with the best grid
final_rf <- finalize_model(
  tune_spec,
  best_rmse
)

final_wf <- workflow() %>% 
  add_recipe(tree_rec) %>% 
  add_model(final_rf)

final_res <- final_wf %>% 
  last_fit(park_split)

# Visualize the performance
# My issue here is that this is only the testing data
# How can I also get this model's performance on the training data?
# I want to plot both, using facet_wrap or a color indication, and also see the difference numerically with collect_metrics

final_res %>% 
  collect_predictions() %>% 
  ggplot(aes(nearness, .pred)) +
    geom_point() +
    geom_abline()

Answer (score 8)

What you can do is extract the trained workflow object from final_res and use it to create predictions on the training dataset.

final_model <- final_res$.workflow[[1]]
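Recent versions of tune also provide extract_workflow(), which should return the same fitted workflow as indexing final_res$.workflow[[1]] (this assumes tune 0.1.6 or later; check your installed version). The sketch below is self-contained so it runs offline: the built-in mtcars data and a plain lm linear regression stand in for the parks data and the tuned random forest, and the names car_split, lm_wf, car_res and fitted_wf are illustrative only.

```r
library(tidymodels)

# Stand-ins (assumptions): mtcars instead of parks,
# an lm linear regression instead of the tuned ranger forest
set.seed(123)
car_split <- initial_split(mtcars)
car_train <- training(car_split)

lm_wf <- workflow() %>%
  add_formula(mpg ~ wt + hp) %>%
  add_model(linear_reg() %>% set_engine("lm"))

# Fit on the training set and evaluate on the test set in one step
car_res <- last_fit(lm_wf, car_split)

# Accessor for the fitted workflow (assumed available in tune >= 0.1.6)
fitted_wf <- extract_workflow(car_res)

# Predictions on the training rows; augment() appends a .pred column
train_preds <- augment(fitted_wf, new_data = car_train)
head(train_preds[, c("mpg", ".pred")])
```

Calling augment() with new_data = testing(car_split) instead gives the test-set predictions, which should carry the same .pred values as collect_predictions(car_res), just in a different column layout.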

Now you can use augment() on both the test and the training datasets to visualize performance.

final_model %>% 
  augment(new_data = park_test) %>%
  ggplot(aes(nearness, .pred)) +
  geom_point() +
  geom_abline()

final_model %>% 
  augment(new_data = park_train) %>%
  ggplot(aes(nearness, .pred)) +
  geom_point() +
  geom_abline()

You can also combine the results with bind_rows() to make the comparison easier.

all_predictions <- bind_rows(
  augment(final_model, new_data = park_train) %>% 
    mutate(type = "train"),
  augment(final_model, new_data = park_test) %>% 
    mutate(type = "test")
)

all_predictions %>%
  ggplot(aes(nearness, .pred)) +
  geom_point() +
  geom_abline() +
  facet_wrap(~type)

All yardstick metric functions also work with grouped data frames.


all_predictions %>%
  group_by(type) %>%
  metrics(nearness, .pred)
#> # A tibble: 6 x 4
#>   type  .metric .estimator .estimate
#>   <chr> <chr>   <chr>          <dbl>
#> 1 test  rmse    standard      0.0985
#> 2 train rmse    standard      0.0473
#> 3 test  rsq     standard      0.725 
#> 4 train rsq     standard      0.943 
#> 5 test  mae     standard      0.0706
#> 6 train mae     standard      0.0350

Created on 2021-06-24 by the reprex package (v2.0.0)
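The grouped-metrics behaviour can be checked on a few made-up numbers. This sketch assumes only dplyr and yardstick; metric_set() bundles rmse() and mae() into a single function that, like metrics(), computes one row per group and metric after group_by(). The toy_preds values are invented for illustration and are not from the parks data.

```r
library(dplyr)
library(yardstick)

# Invented predictions for two groups, mimicking all_predictions
toy_preds <- tibble(
  type     = c("train", "train", "test", "test"),
  nearness = c(0.50, 0.70, 0.40, 0.80),
  .pred    = c(0.52, 0.68, 0.50, 0.60)
)

# One function that computes both RMSE and MAE
train_test_metrics <- metric_set(rmse, mae)

# Grouping by type yields one estimate per (type, metric) pair
toy_metrics <- toy_preds %>%
  group_by(type) %>%
  train_test_metrics(truth = nearness, estimate = .pred)

toy_metrics
```

Here both training errors are 0.02 in absolute value, so the train RMSE and MAE are both 0.02; the test errors are 0.10 and 0.20 in absolute value, giving a test MAE of 0.15 and a test RMSE of sqrt(0.025), about 0.158.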