How to save a tidymodels LightGBM model for reuse

gri*_*ngs 5 r save prediction lightgbm tidymodels

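I have the following code that creates a tidymodels workflow with a lightgbm model. However, I run into problems when I save the fitted workflow to an .rds file and then try to predict from it.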

library(AmesHousing)
library(treesnip)
library(lightgbm)
library(tidymodels)
tidymodels_prefer()

### Model ###

# data
data <- make_ames() %>%
  janitor::clean_names()

data <- subset(data, select = c(sale_price, bedroom_abv_gr, bsmt_full_bath, bsmt_half_bath, enclosed_porch, fireplaces,
                                full_bath, half_bath, kitchen_abv_gr, garage_area, garage_cars, gr_liv_area, lot_area,
                                lot_frontage, year_built, year_remod_add, year_sold))

data$id <- c(1:nrow(data))

data <- data %>%
  mutate(id = as.character(id)) %>%
  select(id, everything())

# model specification

lgbm_model <- boost_tree(
  mtry = 7,
  trees = 347,
  min_n = 10,
  tree_depth = 12,
  learn_rate = 0.0106430579211173,
  loss_reduction = 0.000337948798058139
) %>%
  set_mode("regression") %>%
  set_engine("lightgbm", objective = "regression")

# recipe and workflow

lgbm_recipe <- recipe(sale_price ~., data = data) %>%
  update_role(id, new_role = "ID") %>%
  step_corr(all_predictors(), threshold = 0.7) %>%
  prep()

lgbm_workflow <- workflow() %>%
  add_recipe(lgbm_recipe) %>%
  add_model(lgbm_model)

# fit workflow

fit_lgbm_workflow <- lgbm_workflow %>%
  fit(data = data)

# predict

data_predict <- subset(data, select = -c(sale_price))
predict(fit_lgbm_workflow, new_data = data_predict)


### CASE 1: Save the workflow with saveRDS()

saveRDS(object = fit_lgbm_workflow, file = "lgbm_workflow.rds")
new_lgbm_workflow <- readRDS(file = "lgbm_workflow.rds")

# Predict - error: Attempting to use a Booster which no longer exists

predict(new_lgbm_workflow, new_data = data_predict)


### CASE 2: Save the workflow and the fitted model separately

fitted_model <- (fit_lgbm_workflow %>% extract_fit_parsnip())$fit
saveRDS(object = fit_lgbm_workflow, file = "lgbm_workflow.rds")
lightgbm::saveRDS.lgb.Booster(object = fitted_model, file = "lgbm_model.rds")

new_lgbm_workflow <- readRDS(file = "lgbm_workflow.rds")
new_lgbm_model <- lightgbm::readRDS.lgb.Booster(file = "lgbm_model.rds")
new_lgbm_workflow$fit$fit <- new_lgbm_model

# Predict - error: cannot predict on data of class ‘tbl_df’‘tbl’‘data.frame’

predict(new_lgbm_workflow, new_data = data_predict)

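Only workflows containing a lightgbm model seem to have this problem. With other model types (random forest, xgboost, glm, etc.) I can save the fitted workflow with saveRDS(), load it with readRDS(), and predict on new data.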


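In case 2, the underlying prediction function apparently changes to predict.lgb.Booster(), which takes a matrix as input. But my id variable is of type character, and all columns in a matrix must have the same type.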
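The type problem can be reproduced in base R alone. The sketch below uses a small made-up data frame (the values are purely illustrative, not taken from the Ames data) to show that one character column turns the whole matrix into character, while dropping the id column first leaves a numeric matrix.

```r
# Toy data frame with a character id column and two numeric predictors.
# The values are invented for illustration only.
df <- data.frame(
  id = c("1", "2", "3"),
  gr_liv_area = c(1656, 896, 1329),
  fireplaces = c(2L, 0L, 0L),
  stringsAsFactors = FALSE
)

# A matrix can hold only one type, so the character id column
# forces every column to be stored as character.
m_all <- as.matrix(df)
typeof(m_all)

# Dropping the id column before conversion keeps the matrix numeric.
m_num <- as.matrix(df[, setdiff(names(df), "id")])
typeof(m_num)
```

The first matrix has type "character" and the second has type "double", which is presumably why the booster rejects input that still carries the id column.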


Is there a way to save the entire workflow for future use?


nat*_*e-m 1

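I came up with a way to save the lightgbm model for later use. It does not stay within the tidymodels framework; instead, you first have to extract the underlying lightgbm booster. The same applies if you want to evaluate variable importance.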

Building on the code above:

# Extract the underlying lightgbm booster from the fitted workflow
lgb_model <- parsnip::extract_fit_engine(fit_lgbm_workflow)

# If you want, you can now evaluate variable importance.
# Tidymodels does not currently support variable importance for lgb via bonsai

loss_varimp <- lgb_model %>%
    lightgbm::lgb.importance()

# Save the booster out (filename_x is the path to write the model file to)
lightgbm::lgb.save(lgb_model, filename_x)

# Read the booster back in
lgb_model_loaded <- lightgbm::lgb.load(filename_x)
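One thing to keep in mind when predicting with a reloaded booster: it no longer has the workflow around it, so the recipe has to be applied by hand and the predictors passed in as a numeric matrix. Below is a minimal, self-contained sketch of that round trip. It uses mtcars instead of the Ames data, trains directly with lgb.train(), and uses arbitrary parameter values chosen only so that the tiny dataset can be split. With the objects from the question, the trained recipe could presumably be pulled out of the fitted workflow with workflows::extract_recipe() instead.

```r
library(recipes)
library(lightgbm)

# Train a recipe on a small built-in dataset (stand-in for the Ames data).
rec <- recipe(mpg ~ ., data = mtcars) %>%
  step_corr(all_predictors(), threshold = 0.9) %>%
  prep()

# Preprocessed training predictors as a numeric matrix.
x_train <- bake(rec, new_data = NULL, all_predictors(), composition = "matrix")
dtrain <- lgb.Dataset(x_train, label = mtcars$mpg)

# Parameter values are arbitrary; they only make splits possible on 32 rows.
booster <- lgb.train(
  params = list(objective = "regression", min_data_in_leaf = 2, min_data_in_bin = 1),
  data = dtrain,
  nrounds = 10,
  verbose = -1
)

# Save the booster to a text file and load it again.
path <- tempfile(fileext = ".txt")
lgb.save(booster, path)
loaded <- lgb.load(path)

# Apply the same trained recipe to new data, then predict with the loaded booster.
x_new <- bake(rec, new_data = mtcars[1:5, ], all_predictors(), composition = "matrix")
preds <- predict(loaded, x_new)
preds
```

Selecting all_predictors() in bake() is what keeps non-predictor columns, such as an id column with role "ID", out of the matrix.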

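I have not yet figured out whether the loaded lightgbm booster can be merged back into the tidymodels format, but for now you can at least predict with, use, and evaluate the model without re-running the fit every time. Hope this helps, and if you find a cleaner or more recent solution, please post it!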
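One option that may be worth trying for keeping the whole workflow is the bundle package, which is designed to wrap fitted models that hold external pointers so they survive saveRDS() and readRDS(). The sketch below is not from the original thread and rests on several assumptions: it uses bonsai rather than treesnip to register the lightgbm engine, it uses mtcars and a formula preprocessor for brevity, and the tuning values are arbitrary. I have not checked it against a treesnip-based fit.

```r
library(tidymodels)
library(bonsai)   # assumed here: registers the "lightgbm" engine for boost_tree()
library(bundle)

# Small stand-in workflow; parameter values are arbitrary.
spec <- boost_tree(trees = 20, min_n = 2) %>%
  set_mode("regression") %>%
  set_engine("lightgbm")

wf <- workflow() %>%
  add_formula(mpg ~ .) %>%
  add_model(spec)

fit_wf <- fit(wf, data = mtcars)

# Bundle the fitted workflow, then save the bundle like any other R object.
path <- tempfile(fileext = ".rds")
saveRDS(bundle(fit_wf), path)

# Later (or in another session): read the bundle and restore the workflow.
restored <- unbundle(readRDS(path))

preds_before <- predict(fit_wf, new_data = mtcars)
preds_after <- predict(restored, new_data = mtcars)
all.equal(preds_before, preds_after)
```

If this works on your setup, the restored object is an ordinary fitted workflow, so predict() takes a data frame again and the recipe is applied for you.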