如何使用 tidymodels 和工作流集在同一数据集上拟合多个不同的线性模型

发布于 2025-01-09 06:01:33 字数 2037 浏览 0 评论 0原文

我想评估同一数据集上多个(主要是)线性回归模型的性能。我想也许使用 tidymodels 包和 workflowsets::workflow_set() 可能会起作用。我按照此处的示例进行操作,但我无法弄清楚如何从代码中实际获得合适的结果。

# Load packages
  library("tidyverse")
  library('workflowsets')
  library("parsnip")
  library("recipes")

# Data
  dat <- 
    structure(list(q = c(66.65, 75.58, 83.06, 91.28, 119.26, 133.14, 
    146.32, 153.39, 168.57, 182.36, 210.09, 188.19, 213.42, 296.95, 
    326.33, 358.63, 475.99, 475.99, 683.44, 683.44, 838.49, 1282.1, 
    1648.97, 1572.97, 2055.14, 2521.39, 2685.11, 2859.46, 3242.87, 
    6899.19, 6377.42, 7581.96, 9599.32), c = c(317.06, 283.99, 279.56, 
    283.99, 227.84, 227.84, 262.5, 242.64, 270.9, 266.67, 210.6, 
    235.12, 235.12, 210.6, 207.31, 227.84, 220.78, 194.67, 177.13, 
    207.31, 179.94, 177.13, 182.79, 139.89, 148.98, 144.36, 137.71, 
    158.66, 142.11, 142.11, 119.52, 110.48, 158.66), c_less_c_nought = c(300.06, 
    266.99, 262.56, 266.99, 210.84, 210.84, 245.5, 225.64, 253.9, 
    249.67, 193.6, 218.12, 218.12, 193.6, 190.31, 210.84, 203.78, 
    177.67, 160.13, 190.31, 162.94, 160.13, 165.79, 122.89, 131.98, 
    127.36, 120.71, 141.66, 125.11, 125.11, 102.52, 93.48, 141.66
    )), row.names = c(NA, -33L), class = c("tbl_df", "tbl", "data.frame"
    )) 

  # Recipes for models
  eq1_mod1_recipe <-
    recipes::recipe(c ~ q, data = dat) %>% 
    step_log(c, q, base = 10)
  
  eq2_mod2_a_recipe <- 
    recipes::recipe(c_less_c_nought ~ q, data = dat) %>% 
    step_log(c_less_c_nought, q, base = 10)  

  # Define model types
  lm_model <-
    parsnip::linear_reg() %>% 
    parsnip::set_engine("lm") %>% 
    parsnip::set_mode("regression")

  # Run the models?
  cq_models <-
    workflowsets::workflow_set(
      preproc = list(eq1m1 = eq1_mod1_recipe, e2m2a = eq2_mod2_a_recipe),
      models = list(lm = lm_model)
    )

看来这实际上并不适合模型本身。我需要在什么/哪里添加代码才能适应线性模型?

或者,是否有更好但仍然“整洁”的方法来做到这一点?接受建议。

I want to evaluate the performance of several (mostly) linear regression models on the same dataset. I thought maybe using tidymodels packages along with the workflowsets::workflow_set() might work. I followed the example here, but I cannot figure out how to actually get fit results from the code.

# Load packages
  library("tidyverse")
  library('workflowsets')
  library("parsnip")
  library("recipes")

# Data
  dat <- 
    structure(list(q = c(66.65, 75.58, 83.06, 91.28, 119.26, 133.14, 
    146.32, 153.39, 168.57, 182.36, 210.09, 188.19, 213.42, 296.95, 
    326.33, 358.63, 475.99, 475.99, 683.44, 683.44, 838.49, 1282.1, 
    1648.97, 1572.97, 2055.14, 2521.39, 2685.11, 2859.46, 3242.87, 
    6899.19, 6377.42, 7581.96, 9599.32), c = c(317.06, 283.99, 279.56, 
    283.99, 227.84, 227.84, 262.5, 242.64, 270.9, 266.67, 210.6, 
    235.12, 235.12, 210.6, 207.31, 227.84, 220.78, 194.67, 177.13, 
    207.31, 179.94, 177.13, 182.79, 139.89, 148.98, 144.36, 137.71, 
    158.66, 142.11, 142.11, 119.52, 110.48, 158.66), c_less_c_nought = c(300.06, 
    266.99, 262.56, 266.99, 210.84, 210.84, 245.5, 225.64, 253.9, 
    249.67, 193.6, 218.12, 218.12, 193.6, 190.31, 210.84, 203.78, 
    177.67, 160.13, 190.31, 162.94, 160.13, 165.79, 122.89, 131.98, 
    127.36, 120.71, 141.66, 125.11, 125.11, 102.52, 93.48, 141.66
    )), row.names = c(NA, -33L), class = c("tbl_df", "tbl", "data.frame"
    )) 

  # Recipes for models
  eq1_mod1_recipe <-
    recipes::recipe(c ~ q, data = dat) %>% 
    step_log(c, q, base = 10)
  
  eq2_mod2_a_recipe <- 
    recipes::recipe(c_less_c_nought ~ q, data = dat) %>% 
    step_log(c_less_c_nought, q, base = 10)  

  # Define model types
  lm_model <-
    parsnip::linear_reg() %>% 
    parsnip::set_engine("lm") %>% 
    parsnip::set_mode("regression")

  # Run the models?
  cq_models <-
    workflowsets::workflow_set(
      preproc = list(eq1m1 = eq1_mod1_recipe, e2m2a = eq2_mod2_a_recipe),
      models = list(lm = lm_model)
    )

It appears this doesn't actually fit the models themselves. What/where do I need to add code to also fit the linear models?

Alternatively, is there a better, but still "tidy" way to do this? Open to recommendations.

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。

评论(1

堇年纸鸢 2025-01-16 06:01:34

工作流程仅适用于重新采样的数据,例如交叉验证或引导折叠。这是设计使然,这样人们就不会尝试比较将单个时间拟合到数据集的性能指标。

library("tidymodels")

# Data
dat <- 
  structure(list(q = c(66.65, 75.58, 83.06, 91.28, 119.26, 133.14, 
                       146.32, 153.39, 168.57, 182.36, 210.09, 188.19, 213.42, 296.95, 
                       326.33, 358.63, 475.99, 475.99, 683.44, 683.44, 838.49, 1282.1, 
                       1648.97, 1572.97, 2055.14, 2521.39, 2685.11, 2859.46, 3242.87, 
                       6899.19, 6377.42, 7581.96, 9599.32), c = c(317.06, 283.99, 279.56, 
                                                                  283.99, 227.84, 227.84, 262.5, 242.64, 270.9, 266.67, 210.6, 
                                                                  235.12, 235.12, 210.6, 207.31, 227.84, 220.78, 194.67, 177.13, 
                                                                  207.31, 179.94, 177.13, 182.79, 139.89, 148.98, 144.36, 137.71, 
                                                                  158.66, 142.11, 142.11, 119.52, 110.48, 158.66), c_less_c_nought = c(300.06, 
                                                                                                                                       266.99, 262.56, 266.99, 210.84, 210.84, 245.5, 225.64, 253.9, 
                                                                                                                                       249.67, 193.6, 218.12, 218.12, 193.6, 190.31, 210.84, 203.78, 
                                                                                                                                       177.67, 160.13, 190.31, 162.94, 160.13, 165.79, 122.89, 131.98, 
                                                                                                                                       127.36, 120.71, 141.66, 125.11, 125.11, 102.52, 93.48, 141.66
                                                                  )), row.names = c(NA, -33L), class = c("tbl_df", "tbl", "data.frame"
                                                                  )) 


folds <- bootstraps(dat, times = 10)

eq1_mod1_recipe <-
  recipe(c ~ q, data = dat) %>% 
  step_log(c, q, base = 10)

eq2_mod2_a_recipe <- 
  recipe(c_less_c_nought ~ q, data = dat) %>% 
  step_log(c_less_c_nought, q, base = 10)  

lm_model <- linear_reg()

res <-
  workflow_set(
    preproc = list(eq1m1 = eq1_mod1_recipe, e2m2a = eq2_mod2_a_recipe),
    models = list(lm = lm_model)
  ) %>%
  workflow_map("fit_resamples", resamples = folds)


collect_metrics(res)
#> # A tibble: 4 × 9
#>   wflow_id .config         preproc model .metric .estimator   mean     n std_err
#>   <chr>    <chr>           <chr>   <chr> <chr>   <chr>       <dbl> <int>   <dbl>
#> 1 eq1m1_lm Preprocessor1_… recipe  line… rmse    standard   0.0454    10 0.00214
#> 2 eq1m1_lm Preprocessor1_… recipe  line… rsq     standard   0.857     10 0.0220 
#> 3 e2m2a_lm Preprocessor1_… recipe  line… rmse    standard   0.0502    10 0.00245
#> 4 e2m2a_lm Preprocessor1_… recipe  line… rsq     standard   0.856     10 0.0221

reprex 包 (v2.0.1) 创建于 2022 年 2 月 23 日

workflowsets only work with resampled data, like cross-validation or bootstrap folds. This is by design, so that folks don't try to compare performance metrics from fitting a single time to a dataset.

library("tidymodels")

# Data
dat <- 
  structure(list(q = c(66.65, 75.58, 83.06, 91.28, 119.26, 133.14, 
                       146.32, 153.39, 168.57, 182.36, 210.09, 188.19, 213.42, 296.95, 
                       326.33, 358.63, 475.99, 475.99, 683.44, 683.44, 838.49, 1282.1, 
                       1648.97, 1572.97, 2055.14, 2521.39, 2685.11, 2859.46, 3242.87, 
                       6899.19, 6377.42, 7581.96, 9599.32), c = c(317.06, 283.99, 279.56, 
                                                                  283.99, 227.84, 227.84, 262.5, 242.64, 270.9, 266.67, 210.6, 
                                                                  235.12, 235.12, 210.6, 207.31, 227.84, 220.78, 194.67, 177.13, 
                                                                  207.31, 179.94, 177.13, 182.79, 139.89, 148.98, 144.36, 137.71, 
                                                                  158.66, 142.11, 142.11, 119.52, 110.48, 158.66), c_less_c_nought = c(300.06, 
                                                                                                                                       266.99, 262.56, 266.99, 210.84, 210.84, 245.5, 225.64, 253.9, 
                                                                                                                                       249.67, 193.6, 218.12, 218.12, 193.6, 190.31, 210.84, 203.78, 
                                                                                                                                       177.67, 160.13, 190.31, 162.94, 160.13, 165.79, 122.89, 131.98, 
                                                                                                                                       127.36, 120.71, 141.66, 125.11, 125.11, 102.52, 93.48, 141.66
                                                                  )), row.names = c(NA, -33L), class = c("tbl_df", "tbl", "data.frame"
                                                                  )) 


folds <- bootstraps(dat, times = 10)

eq1_mod1_recipe <-
  recipe(c ~ q, data = dat) %>% 
  step_log(c, q, base = 10)

eq2_mod2_a_recipe <- 
  recipe(c_less_c_nought ~ q, data = dat) %>% 
  step_log(c_less_c_nought, q, base = 10)  

lm_model <- linear_reg()

res <-
  workflow_set(
    preproc = list(eq1m1 = eq1_mod1_recipe, e2m2a = eq2_mod2_a_recipe),
    models = list(lm = lm_model)
  ) %>%
  workflow_map("fit_resamples", resamples = folds)


collect_metrics(res)
#> # A tibble: 4 × 9
#>   wflow_id .config         preproc model .metric .estimator   mean     n std_err
#>   <chr>    <chr>           <chr>   <chr> <chr>   <chr>       <dbl> <int>   <dbl>
#> 1 eq1m1_lm Preprocessor1_… recipe  line… rmse    standard   0.0454    10 0.00214
#> 2 eq1m1_lm Preprocessor1_… recipe  line… rsq     standard   0.857     10 0.0220 
#> 3 e2m2a_lm Preprocessor1_… recipe  line… rmse    standard   0.0502    10 0.00245
#> 4 e2m2a_lm Preprocessor1_… recipe  line… rsq     standard   0.856     10 0.0221

Created on 2022-02-23 by the reprex package (v2.0.1)

~没有更多了~
我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
原文