Survxai explainer using an mlr3proba model

Posted on 2025-02-04 06:48:37

I am trying to build a survxai explainer from a survival model built with mlr3proba. I'm having trouble creating the predict_function necessary for the explainer. Has anyone ever tried to build something like this?

So far, my code is the following:

require(survxai)
require(survival)
require(survivalmodels)
require(mlr3proba)
require(mlr3pipelines)

create_pipeops <- function(learner) {
  GraphLearner$new(po("encode") %>>% po("scale") %>>% po("learner", learner))
}

fit<-lrn("surv.deepsurv")
fit<-create_pipeops(fit)

data<-veteran
survival_task<-TaskSurv$new("veteran", veteran, time = "time", event = "status")
fit$train(survival_task)

predict_function<-function(model, newdata, times=NULL){
  if(!is.data.frame(newdata)){
    newdata <- data.frame(newdata)
  }
  surv_task<-TaskSurv$new("task", newdata, time = "time", 
                          event = "status")
  pred<-model$predict(surv_task)
  mat<-matrix(pred$data$distr, nrow = nrow(pred$data$distr))
  colnames(mat)<-colnames(pred$data$distr)
  return(mat)
}

explainer<-survxai::explain(model = learner$model, data = veteran[,-c(3,4)],
                            y = Surv(veteran$time, veteran$status),
                            predict_function = predict_function)

pred_breakdown<-prediction_breakdown(explainer, veteran[1,])

It throws the following error:

Error in [.data.table(r6_private(backend)$.data, , event, with = FALSE) : column(s) not found: status

I suspect that once that one is solved there may be more. I don't fully understand the structure of the object that the function returns.

In the predict_function, I included the times argument because according to the R help page, the function must take the three arguments.

Comments (1)

秋风の叶未落 2025-02-11 06:48:37

Here is a working example with randomForestSRC; for your case you can just change surv.rfsrc to surv.deepsurv. BTW, we are planning on implementing this within mlr3proba soon, or I might just add it directly to survivalmodels. Still deciding!

library(mlr3proba)
#> Loading required package: mlr3
#> Warning: package 'mlr3' was built under R version 4.1.3
library(mlr3extralearners)
#> 
#> Attaching package: 'mlr3extralearners'
#> The following objects are masked from 'package:mlr3':
#> 
#>     lrn, lrns
library(survxai)
#> Loading required package: prodlim
#> Welcome to survxai (version: 0.2.1).
#> Information about the package can be found in the GitHub repository: https://github.com/MI2DataLab/survxai
library(survival)
data(pbc, package = "randomForestSRC")
pbc <- pbc[complete.cases(pbc), ]
task <- as_task_surv(pbc, event = "status", time = "days")
split <- partition(task)
predict_times <- function(model, data, times) {
  t(model$predict_newdata(data)$distr$survival(times))
}
model <- lrn("surv.rfsrc")$train(task, row_ids = split$train)
surve_cph <- explain(
  model = model, data = pbc[, -c(1, 2)],
  y = Surv(pbc$days, pbc$status),
  predict_function = predict_times
)
prediction_breakdown(surve_cph, pbc[1, -c(1, 2)])
#>             contribution
#> bili            -35.079%
#> edema           -10.278%
#> ascites          -5.505%
#> copper           -1.084%
#> stage            -0.773%
#> prothrombin      -0.421%
#> albumin          -0.247%
#> sgot             -0.143%
#> hepatom          -0.098%
#> spiders          -0.086%
#> alk              -0.043%
#> trig             -0.041%
#> age              -0.035%

Created on 2022-06-07 by the reprex package (v2.0.1)
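
For anyone wondering why the original attempt failed: the explainer is built on veteran[,-c(3,4)], which drops the time and status columns, so TaskSurv$new() inside the predict function cannot find status. The answer avoids this by calling predict_newdata() on the feature columns directly. What the predict function has to hand back is a matrix with one row per observation and one column per requested time. The sketch below shows that shape using only the survival package, with a Cox model standing in for the mlr3proba learner; the model, the name predict_surv_matrix and the chosen times are illustrative assumptions, not part of survxai or mlr3proba.

```r
library(survival)

# Stand-in model: a plain Cox fit on the veteran data from the question.
# Any fitted survival model could take its place.
cox_fit <- coxph(Surv(time, status) ~ age + karno, data = veteran)

# Hypothetical predict function with the (model, data, times) signature.
# `data` holds features only; no time or status column is needed.
# `times` is assumed to be sorted in increasing order.
predict_surv_matrix <- function(model, data, times) {
  curves <- survfit(model, newdata = data)
  surv <- summary(curves, times = times, extend = TRUE)$surv
  # summary() gives times in rows and observations in columns (or a plain
  # vector for a single observation); coerce to a matrix and transpose to
  # get one row per observation and one column per time point.
  t(matrix(surv, nrow = length(times)))
}

features <- veteran[1:2, c("age", "karno")]
m <- predict_surv_matrix(cox_fit, features, times = c(30, 90, 180))
dim(m)
```

Each row of m is a survival curve evaluated at the requested times, so the values lie between 0 and 1 and do not increase from left to right. The t(...$distr$survival(times)) call in the answer produces the same layout for an mlr3proba prediction.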
