Creating SHAP plots for tidymodels objects

Published 2025-01-17 17:36:58 · 2144 characters · 5 views · 0 comments


This question refers to "Obtaining summary shap plot for catboost model with tidymodels in R". Judging by the comment below that question, the OP found a solution but has not shared it with the community so far.

I want to analyze my tree ensembles, fitted with the tidymodels package, using SHAP value plots: plots for single observations like

https://prnt.sc/CO_PC4aDUQA0

and to summarize the effect of all features of my dataset like

[image: SHAP summary plot]

DALEXtra provides a function to create SHAP values for tidymodels, explain_tidymodels(). force_plot() from the fastshap package provides a wrapper for the plot function of the underlying Python package shap. But I can't work out how to make that function accept the output of explain_tidymodels().

Question: How can one generate such SHAP plots in R using tidymodels and explain_tidymodels()?

MWE (for SHAP values with explain_tidymodels()):

library(MASS)
library(tidyverse)
library(tidymodels)
library(parsnip)
library(treesnip)
library(catboost)
library(fastshap)
library(DALEXtra)
set.seed(1337)
rec <-  recipe(crim ~ ., data = Boston)

split <- initial_split(Boston)

train_data <- training(split)

test_data <- testing(split) %>% dplyr::select(-crim) %>% as.matrix()

model_default<-
  parsnip::boost_tree(
    mode = "regression"
  ) %>%
  set_engine(engine = 'catboost', loss_function = 'RMSE')
# Sometimes catboost is not loaded correctly; the following two lines
# prevent fitting errors.
# See https://github.com/curso-r/treesnip/issues/21 (the error is mentioned in the last post).
set_dependency("boost_tree", eng = "catboost", "catboost")
set_dependency("boost_tree", eng = "catboost", "treesnip")

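# Fit the workflow with the model specification defined above.
model_fit_wf <- workflow() %>%
  add_model(model_default) %>%
  add_recipe(rec) %>%
  fit(data = train_data)

# Background data for the explainer; assumed here to be the training predictors.
X <- train_data %>% dplyr::select(-crim)

SHAP_wf <- explain_tidymodels(model_fit_wf, data = X, y = train_data$crim, new_data = test_data)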
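For context, explain_tidymodels() returns a DALEX explainer, and DALEX's own route to per-observation SHAP attributions is predict_parts() with type = "shap". The sketch below shows that call on a stand-in model (base R lm on mtcars, an assumption chosen so it runs without catboost); it is not a replacement for the force and summary plots asked about, only an illustration of how an explainer object is usually queried.

```r
# Stand-in model (assumption): a plain lm instead of the catboost workflow.
fit <- lm(mpg ~ wt + hp, data = mtcars)
X <- mtcars[, c("wt", "hp")]

if (requireNamespace("DALEX", quietly = TRUE)) {
  # Same kind of object that explain_tidymodels() produces for a workflow.
  explainer <- DALEX::explain(fit, data = X, y = mtcars$mpg, verbose = FALSE)

  # SHAP-style attributions for one observation; B is the number of
  # random feature orderings that are averaged.
  shap_one <- DALEX::predict_parts(
    explainer,
    new_observation = X[1, ],
    type = "shap",
    B = 10
  )
  print(head(as.data.frame(shap_one)))

  # plot(shap_one) would draw the per-observation contribution chart.
}
```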


Comments (1)

哀由 2025-01-24 17:36:58


Perhaps this will help. At the very least, it is a step in the right direction.

First, ensure you have fastshap and reticulate installed (i.e., install.packages("...")). Next, set up a virtual environment and install shap (pip install ...). Also install matplotlib 3.2.2 for the dependence plots (check the GitHub issues on this; an older version of matplotlib is necessary).

RStudio has great information on virtual environment setup. That said, virtual environment setup requires more or less troubleshooting depending on the IDE in use. (Sadly, some work settings restrict the use of open source RStudio due to licensing.)

The fastshap documentation is also helpful on this front.

Here's a workflow for lightgbm (from treesnip docs, lightly modified).

library(tidymodels)
library(treesnip)

data("diamonds", package = "ggplot2")
diamonds <- diamonds %>% sample_n(1000)

# vfold resamples
diamonds_splits <- vfold_cv(diamonds, v = 5)

model_spec <- boost_tree(mtry = 5, trees = 500) %>% set_mode("regression")

# model specs
lightgbm_model <- model_spec %>% 
    set_engine("lightgbm", nthread = 6)

#workflows
lightgbm_wf <- workflow() %>% 
    add_model(
       lightgbm_model
    )

rec_ordered <- recipe(
    price ~ .
      , data = diamonds
) 

lightgbm_fit_ordered <- fit_resamples(
  add_recipe(
    lightgbm_wf, rec_ordered
    ), resamples = diamonds_splits)

Prior to prediction, we want to fit our workflow:

fit_workflow <- lightgbm_wf %>% 
     add_recipe(rec_ordered) %>% 
     fit(data = diamonds)

Now we have a fitted workflow and can predict. To use the fastshap::explain function, we need to create a prediction wrapper (this is not always required: depending on the engine used, it may or may not work out of the box; see the docs).

predict_function_gbm <-  function(model, newdata) {
    predict(model, newdata) %>% pluck(.,1)
}
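The contract behind pred_wrapper is simple: it takes (model, newdata) and returns a plain numeric vector with one value per row. predict() on a tidymodels workflow returns a one-column table named .pred, which is why the wrapper above plucks the first column. A minimal base R sketch of the same contract follows; the lm model and the helper predict_as_table() are stand-ins invented for illustration so the block runs without lightgbm.

```r
# Stand-in model (assumption): base R lm instead of the lightgbm workflow.
fit <- lm(mpg ~ wt + hp, data = mtcars)

# Hypothetical helper that mimics the shape of a tidymodels prediction:
# a one-column data frame named .pred.
predict_as_table <- function(model, newdata) {
  data.frame(.pred = unname(predict(model, newdata = newdata)))
}

# The wrapper shape fastshap expects: (model, newdata) -> numeric vector.
pred_wrapper <- function(model, newdata) {
  predict_as_table(model, newdata)[[".pred"]]
}

X <- mtcars[, c("wt", "hp")]
preds <- pred_wrapper(fit, X)

# One prediction per row, no dimensions, numeric type.
str(preds)
```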

Let's get the mean prediction value (used below) while we're at it. This also serves as a check that the wrapper works.

mean_preds <- mean(
    predict_function_gbm(
       fit_workflow, diamonds %>% select(-price)
   )
)

Now we create our explanations (SHAP values). Note the pred_wrapper and X arguments here (see the fastshap GitHub issues for other examples, e.g., glmnet).

fastshap::explain( 
    fit_workflow, 
    X = as.data.frame(diamonds %>% select(-price)),
    pred_wrapper = predict_function_gbm, 
    nsim = 10
) -> explanations_gbm
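To give a feel for what nsim controls, here is a simplified base R sketch of Monte Carlo Shapley estimation: for each repetition, draw a random background row and a random feature ordering, then measure how the prediction changes when the feature of interest switches from the background value to the observed one. This illustrates the sampling idea only and is not fastshap's actual code; the lm model on mtcars and the function name mc_shap are assumptions made for the example. For a linear model the exact answer is known, so the estimate can be compared against it.

```r
set.seed(1)

# Stand-in model (assumption): base R lm, so the exact value is available.
fit <- lm(mpg ~ wt + hp, data = mtcars)
X <- mtcars[, c("wt", "hp")]
f <- function(newdata) unname(predict(fit, newdata))

# Monte Carlo estimate of the Shapley value of `feature` for the row `x`.
mc_shap <- function(x, feature, X, nsim) {
  diffs <- numeric(nsim)
  for (s in seq_len(nsim)) {
    w <- X[sample(nrow(X), 1), , drop = FALSE]  # random background row
    ord <- colnames(X)[sample(ncol(X))]         # random feature ordering
    pos <- match(feature, ord)
    keep <- ord[seq_len(pos)]                   # features taken from x, incl. `feature`
    with_j <- w
    with_j[, keep] <- x[, keep]
    without_j <- with_j
    without_j[, feature] <- w[, feature]        # same row, feature set back to background
    diffs[s] <- f(with_j) - f(without_j)
  }
  mean(diffs)
}

x <- X[1, , drop = FALSE]
estimate <- mc_shap(x, "wt", X, nsim = 2000)

# Exact SHAP value for a linear model with features treated as independent.
exact <- unname(coef(fit)["wt"]) * (x$wt - mean(X$wt))

print(c(estimate = estimate, exact = exact))
```

A larger nsim shrinks the gap between the two numbers at the cost of more prediction calls, which is the trade-off behind nsim = 10 above.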

This should produce a force plot.

fastshap::force_plot(
    object = explanations_gbm[1,], 
    feature_values = as.data.frame(diamonds %>% select(-price))[1,], 
    display = "viewer", 
    baseline = mean_preds) 

Passing several rows produces multiple force plots, stacked vertically:

fastshap::force_plot(
    object = explanations_gbm[1:20,], 
    feature_values = as.data.frame(diamonds %>% select(-price))[1:20,], 
    display = "viewer", 
    baseline = mean_preds) 

Add link = "logit" for classification. Change display to "html" for R Markdown rendering.
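On why baseline = mean_preds is passed: SHAP values are contributions relative to the average prediction, so the baseline plus a row's SHAP values should reproduce that row's prediction. The sketch below checks this identity in base R with a stand-in linear model (lm on mtcars, an assumption made so it runs without lightgbm), where the exact SHAP value of feature j is beta_j * (x_j - mean(x_j)) when features are treated as independent. Monte Carlo estimates from fastshap will satisfy the identity only approximately.

```r
# Stand-in model (assumption): base R lm with exact, closed-form SHAP values.
fit <- lm(mpg ~ wt + hp, data = mtcars)
X <- mtcars[, c("wt", "hp")]
beta <- coef(fit)[colnames(X)]

# phi[i, j] = beta_j * (x_ij - mean(x_j))
centered <- sweep(as.matrix(X), 2, colMeans(X))
phi <- sweep(centered, 2, beta, `*`)

# Baseline = mean prediction, the same role mean_preds plays above.
baseline <- mean(predict(fit, X))

# Baseline plus the row sums of the SHAP values gives back the predictions.
reconstructed <- baseline + rowSums(phi)
print(all.equal(unname(reconstructed), unname(predict(fit, X))))
```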

Now for summary plots and dependence plots.

The trick is to use reticulate to access the Python functions directly. Note that the same logic holds for libraries like transformers, numpy, etc.

First, the dependence plot.

library(reticulate)
shap = import("shap")
np = import("numpy") 

shap$dependence_plot(
     "rank(3)", 
     data.matrix(explanations_gbm),
     data.matrix(diamonds %>% select(-price))
)

See the shap docs for an explanation of rank(3); rank(1) etc. will also work.

Unfortunately, it threw an error when I attempted to name the feature directly (i.e., "cut").
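One plausible cause of that error, offered as a guess rather than a diagnosis: reticulate hands an R matrix to Python as a NumPy array, which carries no column names, so shap has nothing to match "cut" against. shap's dependence_plot and summary_plot both accept a feature_names argument, and passing the R column names there may allow lookup by name. The sketch below uses toy stand-ins (iris features and closed-form SHAP values from an lm, both assumptions) and only calls Python when reticulate and shap are available.

```r
# Toy stand-ins (assumptions): X plays the role of the feature matrix,
# phi the role of the SHAP value matrix.
X <- data.matrix(iris[, c("Sepal.Width", "Petal.Length", "Species")])
fit <- lm(Sepal.Length ~ ., data = data.frame(Sepal.Length = iris$Sepal.Length, X))
phi <- sweep(sweep(X, 2, colMeans(X)), 2, coef(fit)[colnames(X)], `*`)

# data.matrix() keeps the column names on the R side and turns factors
# into integer codes; the names are what Python never gets to see.
feature_names <- colnames(X)
print(feature_names)

if (requireNamespace("reticulate", quietly = TRUE) &&
    reticulate::py_module_available("shap")) {
  shap <- reticulate::import("shap")
  # Supplying feature_names lets shap resolve a feature by its name.
  shap$dependence_plot("Species", phi, X, feature_names = feature_names)
}
```

The same feature_names argument can be passed to shap$summary_plot so that the axis shows real column names instead of generic labels.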

Now for the summary plot:

shap$summary_plot( 
    data.matrix(explanations_gbm),
    data.matrix(diamonds %>% select(-price))
)
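If the Python route is unavailable, the global ranking that a summary plot conveys can be approximated in plain R by averaging the absolute SHAP values per feature. The sketch below does this for a stand-in SHAP matrix (closed-form values from an lm on mtcars, an assumption for illustration); with the objects above, the matrix of explanations would take the place of phi.

```r
# Stand-in SHAP matrix (assumption): exact values for a linear model.
fit <- lm(mpg ~ wt + hp + qsec, data = mtcars)
X <- as.matrix(mtcars[, c("wt", "hp", "qsec")])
beta <- coef(fit)[colnames(X)]
phi <- sweep(sweep(X, 2, colMeans(X)), 2, beta, `*`)

# Global importance: mean absolute SHAP value per feature, largest first.
importance <- sort(colMeans(abs(phi)), decreasing = TRUE)
print(importance)

# Horizontal bar chart with the most important feature on top.
barplot(rev(importance), horiz = TRUE, las = 1, xlab = "mean(|SHAP value|)")
```

This gives only the ranking; it drops the per-observation spread and feature-value colouring that the shap summary plot shows.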

Final note: rendering the plot repeatedly will produce buggy visualizations. Hopefully this provides a point of departure for catboost visualizations.
