如何从rpart软件包中绘制递归分区
我想绘制由递归二进制拆分构建的二维协变量空间的分区。更确切地说,我想编写一个复制以下图的函数(从,PAG):
购物车算法)。我要实现的是一个函数,该函数将输出 rpart
函数并生成此类图。
它遵循一些示例代码:
## Generating data.
set.seed(1975)
n <- 5000
p <- 2
X <- matrix(sample(seq(0, 1, by = 0.01), n * p, replace = TRUE), ncol = p)
Y <- X[, 1] + 2 * X[, 2] + rnorm(n)
## Building tree.
tree <- rpart(Y ~ ., data = data.frame(Y, X), method = "anova", control = rpart.control(cp = 0, maxdepth = 2))
rpart_splits <- function(fit, digits = getOption("digits")) {
splits <- fit$splits
if (!is.null(splits)) {
ff <- fit$frame
is.leaf <- ff$var == "<leaf>"
n <- nrow(splits)
nn <- ff$ncompete + ff$nsurrogate + !is.leaf
ix <- cumsum(c(1L, nn))
ix_prim <- unlist(mapply(ix, ix + c(ff$ncompete, 0), FUN = seq, SIMPLIFY = F))
type <- rep.int("surrogate", n)
type[ix_prim[ix_prim <= n]] <- "primary"
type[ix[ix <= n]] <- "main"
left <- character(nrow(splits))
side <- splits[, 2L]
for (i in seq_along(left)) {
left[i] <- if (side[i] == -1L)
paste("<", format(signif(splits[i, 4L], digits)))
else if (side[i] == 1L)
paste(">=", format(signif(splits[i, 4L], digits)))
else {
catside <- fit$csplit[splits[i, 4L], 1:side[i]]
paste(c("L", "-", "R")[catside], collapse = "", sep = "")
}
}
cbind(data.frame(var = rownames(splits),
type = type,
node = rep(as.integer(row.names(ff)), times = nn),
ix = rep(seq_len(nrow(ff)), nn),
left = left),
as.data.frame(splits, row.names = F))
}
}
使用此函数,我能够恢复所有分裂变量和点:
splits <- rpart_splits(tree)[rpart_splits(tree)$type == "main", ]
splits
# var type node ix left count ncat improve index adj
# 1 X2 main 1 1 < 0.565 5000 -1 0.18110662 0.565 0
# 3 X2 main 2 2 < 0.265 2814 -1 0.06358597 0.265 0
# 6 X1 main 3 5 < 0.645 2186 -1 0.07645851 0.645 0
列 var
告诉我每个非末端节点的分裂变量,列 left
告诉相关的分裂点。但是,我不知道如何使用这些信息来生成我所需的情节。
当然,如果您有任何替代策略不涉及使用 rpart_splits
随时提出建议。
如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

绑定邮箱获取回复消息
由于您还没有绑定你的真实邮箱,如果其他用户或者作者回复了您的评论,将不能在第一时间通知您!
发布评论
评论(1)
您可以使用(未发表)
parttree
软件包,您可以从GitHub通过:this允许:
顺便说一句,此软件包还包含函数
parttree ,返回与您非常相似的东西
rpart_splits
函数:You could use the (unpublished)
parttree
package, which you can install from GitHub via:This allows:
Incidentally, this package also contains the function
parttree
, which returns something very similar to yourrpart_splits
function: