使用MLR3Proba模型的Survxai解释器
我正在尝试通过使用MLR3Proba构建的生存模型来构建Survxai解释器。我很难为解释器创建必要的预测函数。有没有人尝试过这样的东西?
到目前为止,我的代码如下:
require(survxai)
require(survival)
require(survivalmodels)
require(mlr3proba)
require(mlr3pipelines)
create_pipeops <- function(learner) {
GraphLearner$new(po("encode") %>>% po("scale") %>>% po("learner", learner))
}
fit<-lrn("surv.deepsurv")
fit<-create_pipeops(fit)
data<-veteran
survival_task<-TaskSurv$new("veteran", veteran, time = "time", event = "status")
fit$train(survival_task)
predict_function<-function(model, newdata, times=NULL){
if(!is.data.frame(newdata)){
newdata <- data.frame(newdata)
}
surv_task<-TaskSurv$new("task", newdata, time = "time",
event = "status")
pred<-model$predict(surv_task)
mat<-matrix(pred$data$distr, nrow = nrow(pred$data$distr))
colnames(mat)<-colnames(pred$data$distr)
return(mat)
}
explainer<-survxai::explain(model = learner$model, data = veteran[,-c(3,4)],
y = Surv(veteran$time, veteran$status),
predict_function = predict_function)
pred_breakdown<-prediction_breakdown(explainer, veteran[1,])
它会引发以下错误:中的错误[。data.table
(r6_private(backend)$。data,event,event,with = false): 找不到列:状态,但我怀疑一旦解决了一个问题,可能会有更多。我不完全了解该函数返回的对象的结构。
在preadive_function中,我包括times
参数,因为根据r帮助页面,该函数必须获取三个参数。
I am trying to build a survxai explainer from a survival model built with mlr3proba. I'm having trouble creating the predict_function necessary for the explainer. Has anyone ever tried to build something like this?
So far, my code is the following:
require(survxai)
require(survival)
require(survivalmodels)
require(mlr3proba)
require(mlr3pipelines)
create_pipeops <- function(learner) {
GraphLearner$new(po("encode") %>>% po("scale") %>>% po("learner", learner))
}
fit<-lrn("surv.deepsurv")
fit<-create_pipeops(fit)
data<-veteran
survival_task<-TaskSurv$new("veteran", veteran, time = "time", event = "status")
fit$train(survival_task)
predict_function<-function(model, newdata, times=NULL){
if(!is.data.frame(newdata)){
newdata <- data.frame(newdata)
}
surv_task<-TaskSurv$new("task", newdata, time = "time",
event = "status")
pred<-model$predict(surv_task)
mat<-matrix(pred$data$distr, nrow = nrow(pred$data$distr))
colnames(mat)<-colnames(pred$data$distr)
return(mat)
}
explainer<-survxai::explain(model = learner$model, data = veteran[,-c(3,4)],
y = Surv(veteran$time, veteran$status),
predict_function = predict_function)
pred_breakdown<-prediction_breakdown(explainer, veteran[1,])
It throws the following error: Error in [.data.table
(r6_private(backend)$.data, , event, with = FALSE) :
column(s) not found: status, but I suspect that once that one is solved there may be more. I don't fully understand the structure of the object that the function returns.
In the predict_function, I included the times
argument because according to the R help page, the function must take the three arguments.
如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

绑定邮箱获取回复消息
由于您还没有绑定你的真实邮箱,如果其他用户或者作者回复了您的评论,将不能在第一时间通知您!
发布评论
评论(1)
在此处使用RandomForestSrc的示例,您可以将
surv.rfsrc
更改为surch.deepsurv
for horpe> hove>。顺便说一句,我们正计划在MLR3Proba中实施此功能,否则我可能会直接将其添加到生存模型中,仍然决定!由
Working example with randomForestSRC here, you can just change
surv.rfsrc
tosurv.deepsurv
for your example. BTW we are planning on implementing this within mlr3proba soon, or I might just add it directly to survivalmodels, still deciding!Created on 2022-06-07 by the reprex package (v2.0.1)