tidymodels:如何从培训数据中提取重要性

发布于 2025-02-05 04:18:42 字数 1243 浏览 4 评论 0 原文

我有以下代码,在其中进行一些网格搜索以寻找不同的MTRY和MIN_N。我知道如何提取具有最高精度的参数(请参见第二个代码框)。如何在培训数据集中提取每个功能的重要性?我在网上找到的指南仅在测试数据集中使用“ last_fit”来显示如何进行。例如指南: https://wwww.tidymodels.orgg/start/case/case/case/case/case - 研究/#数据分割

set.seed(seed_number)
    data_split <- initial_split(node_strength,prop = 0.8,strata = Group)
    
    train <- training(data_split)
    test <- testing(data_split)
    train_folds <- vfold_cv(train,v = 10)
    
    
    rfc <- rand_forest(mode = "classification", mtry = tune(),
                       min_n = tune(), trees = 1500) %>%
        set_engine("ranger", num.threads = 48, importance = "impurity")
    
    rfc_recipe <- recipe(data = train, Group~.)
    
    rfc_workflow <- workflow() %>% add_model(rfc) %>%
        add_recipe(rfc_recipe)
    
    rfc_result <- rfc_workflow %>%
        tune_grid(train_folds, grid = 40, control = control_grid(save_pred = TRUE),
                  metrics = metric_set(accuracy))

best <- 
        rfc_result %>% 
        select_best(metric = "accuracy")

I have the following code, where I do some grid search for different mtry and min_n. I know how to extract the parameters that give the highest accuracy (see second code box). How can I extract the importance of each feature in the training dataset? The guides I found online show how to do it only in the test dataset using "last_fit". E.g. of guide: https://www.tidymodels.org/start/case-study/#data-split

set.seed(seed_number)
    data_split <- initial_split(node_strength,prop = 0.8,strata = Group)
    
    train <- training(data_split)
    test <- testing(data_split)
    train_folds <- vfold_cv(train,v = 10)
    
    
    rfc <- rand_forest(mode = "classification", mtry = tune(),
                       min_n = tune(), trees = 1500) %>%
        set_engine("ranger", num.threads = 48, importance = "impurity")
    
    rfc_recipe <- recipe(data = train, Group~.)
    
    rfc_workflow <- workflow() %>% add_model(rfc) %>%
        add_recipe(rfc_recipe)
    
    rfc_result <- rfc_workflow %>%
        tune_grid(train_folds, grid = 40, control = control_grid(save_pred = TRUE),
                  metrics = metric_set(accuracy))

.

best <- 
        rfc_result %>% 
        select_best(metric = "accuracy")

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。

评论(1

榆西 2025-02-12 04:18:42

为此,您将需要创建一个自定义提取函数,例如在本文档中概述了

对于随机森林的重要性,您的功能将看起来像这样:

get_rf_imp <- function(x) {
    x %>% 
        extract_fit_parsnip() %>% 
        vip::vi()
}

然后您可以将其应用于这样的重塑(请注意,您将获得新的 .entracts .entracts 列):

library(tidymodels)
data(cells, package = "modeldata")

set.seed(123)
cell_split <- cells %>% select(-case) %>%
    initial_split(strata = class)
cell_train <- training(cell_split)
cell_test  <- testing(cell_split)
folds <- vfold_cv(cell_train)            

rf_spec <- rand_forest(mode = "classification") %>%
    set_engine("ranger", importance = "impurity")

ctrl_imp <- control_grid(extract = get_rf_imp)

cells_res <-
    workflow(class ~ ., rf_spec) %>%
    fit_resamples(folds, control = ctrl_imp)
cells_res
#> # Resampling results
#> # 10-fold cross-validation 
#> # A tibble: 10 × 5
#>    splits             id     .metrics         .notes           .extracts       
#>    <list>             <chr>  <list>           <list>           <list>          
#>  1 <split [1362/152]> Fold01 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
#>  2 <split [1362/152]> Fold02 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
#>  3 <split [1362/152]> Fold03 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
#>  4 <split [1362/152]> Fold04 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
#>  5 <split [1363/151]> Fold05 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
#>  6 <split [1363/151]> Fold06 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
#>  7 <split [1363/151]> Fold07 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
#>  8 <split [1363/151]> Fold08 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
#>  9 <split [1363/151]> Fold09 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
#> 10 <split [1363/151]> Fold10 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>

在2022上创建-06-19由 preprex package (v2.0.1)

您具有这些可变的重要性分数提取,您可以 unnest()他们(现在,您必须两次这样做,因为它是深嵌套的),然后您可以按照自己的要求进行汇总和可视化:

cells_res %>%
    select(id, .extracts) %>%
    unnest(.extracts) %>%
    unnest(.extracts) %>%
    group_by(Variable) %>%
    summarise(Mean = mean(Importance),
              Variance = sd(Importance)) %>%
    slice_max(Mean, n = 15) %>%
    ggplot(aes(Mean, reorder(Variable, Mean))) +
    geom_crossbar(aes(xmin = Mean - Variance, xmax = Mean + Variance)) +
    labs(x = "Variable importance", y = NULL)

”“

在2022-06-19创建的(v2.0.1)

To do this, you will want to create a custom extract function, as outlined in this documentation.

For random forest variable importance, your function will look something like this:

get_rf_imp <- function(x) {
    x %>% 
        extract_fit_parsnip() %>% 
        vip::vi()
}

And then you can apply it to your resamples like so (notice that you get a new .extracts column):

library(tidymodels)
data(cells, package = "modeldata")

set.seed(123)
cell_split <- cells %>% select(-case) %>%
    initial_split(strata = class)
cell_train <- training(cell_split)
cell_test  <- testing(cell_split)
folds <- vfold_cv(cell_train)            

rf_spec <- rand_forest(mode = "classification") %>%
    set_engine("ranger", importance = "impurity")

ctrl_imp <- control_grid(extract = get_rf_imp)

cells_res <-
    workflow(class ~ ., rf_spec) %>%
    fit_resamples(folds, control = ctrl_imp)
cells_res
#> # Resampling results
#> # 10-fold cross-validation 
#> # A tibble: 10 × 5
#>    splits             id     .metrics         .notes           .extracts       
#>    <list>             <chr>  <list>           <list>           <list>          
#>  1 <split [1362/152]> Fold01 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
#>  2 <split [1362/152]> Fold02 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
#>  3 <split [1362/152]> Fold03 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
#>  4 <split [1362/152]> Fold04 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
#>  5 <split [1363/151]> Fold05 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
#>  6 <split [1363/151]> Fold06 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
#>  7 <split [1363/151]> Fold07 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
#>  8 <split [1363/151]> Fold08 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
#>  9 <split [1363/151]> Fold09 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
#> 10 <split [1363/151]> Fold10 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>

Created on 2022-06-19 by the reprex package (v2.0.1)

Once you have those variable importance score extracts, you can unnest() them (right now, you have to do this twice because it is deeply nested) and then you can summarize and visualize as you prefer:

cells_res %>%
    select(id, .extracts) %>%
    unnest(.extracts) %>%
    unnest(.extracts) %>%
    group_by(Variable) %>%
    summarise(Mean = mean(Importance),
              Variance = sd(Importance)) %>%
    slice_max(Mean, n = 15) %>%
    ggplot(aes(Mean, reorder(Variable, Mean))) +
    geom_crossbar(aes(xmin = Mean - Variance, xmax = Mean + Variance)) +
    labs(x = "Variable importance", y = NULL)

Created on 2022-06-19 by the reprex package (v2.0.1)

~没有更多了~
我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
原文