How to plot a recursive partition from the rpart package

Posted on 2025-02-06 13:59:46

I want to plot a partition of a two-dimensional covariate space constructed by recursive binary splitting. More precisely, I would like to write a function that replicates the following graph (taken from Elements of Statistical Learning, p. 306):

[Figure: partition of a two-dimensional covariate space, from Elements of Statistical Learning]

Displayed above is a two-dimensional covariate space and a partition obtained by recursively splitting the space with axis-aligned binary splits (the approach used by the CART algorithm). What I want to implement is a function that takes the output of the rpart function and generates such a plot.
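
To make "recursive binary splitting with axis-aligned splits" concrete, here is a minimal sketch of one greedy step: it scans every covariate and every candidate cut point and keeps the split that most reduces the residual sum of squares. The helper name best_split, the demo data, and the use of midpoints between consecutive distinct values as candidate cuts are assumptions made for illustration; this is not rpart's actual implementation.

```r
# Hypothetical helper (not part of rpart): one greedy step of recursive
# partitioning for a regression tree with a squared-error criterion.
best_split <- function(X, y) {
  # Start from the unsplit node: RSS around the overall mean.
  best <- list(var = NA_integer_, cut = NA_real_, rss = sum((y - mean(y))^2))
  for (j in seq_len(ncol(X))) {
    vals <- sort(unique(X[, j]))
    if (length(vals) < 2L) next
    # Candidate cut points: midpoints between consecutive distinct values.
    cuts <- (head(vals, -1L) + tail(vals, -1L)) / 2
    for (cut in cuts) {
      left <- X[, j] < cut
      rss <- sum((y[left] - mean(y[left]))^2) +
        sum((y[!left] - mean(y[!left]))^2)
      if (rss < best$rss) best <- list(var = j, cut = cut, rss = rss)
    }
  }
  best
}

# Demo data (assumed): the response jumps when the second covariate crosses 0.5.
set.seed(1)
X_demo <- matrix(runif(400), ncol = 2)
y_demo <- ifelse(X_demo[, 2] < 0.5, 0, 3) + rnorm(200, sd = 0.1)
demo_split <- best_split(X_demo, y_demo)
demo_split$var  # column chosen for the split
demo_split$cut  # cut point chosen for the split
```

Applying the same step again inside each of the two resulting regions, and so on, yields the nested rectangles shown in the figure.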

Here is some example code:

## Generating data.
set.seed(1975)

n <- 5000
p <- 2

X <- matrix(sample(seq(0, 1, by = 0.01), n * p, replace = TRUE), ncol = p)
Y <- X[, 1] + 2 * X[, 2] + rnorm(n)

## Building tree.
library(rpart)
tree <- rpart(Y ~ ., data = data.frame(Y, X), method = "anova", control = rpart.control(cp = 0, maxdepth = 2))

Searching Stack Overflow, I found this function:

rpart_splits <- function(fit, digits = getOption("digits")) {
  splits <- fit$splits
  if (!is.null(splits)) {
    ff <- fit$frame
    is.leaf <- ff$var == "<leaf>"
    n <- nrow(splits)
    nn <- ff$ncompete + ff$nsurrogate + !is.leaf
    ix <- cumsum(c(1L, nn))
    ix_prim <- unlist(mapply(ix, ix + c(ff$ncompete, 0), FUN = seq, SIMPLIFY = FALSE))
    type <- rep.int("surrogate", n)
    type[ix_prim[ix_prim <= n]] <- "primary"
    type[ix[ix <= n]] <- "main"
    left <- character(nrow(splits))
    side <- splits[, 2L]
    for (i in seq_along(left)) {
      left[i] <- if (side[i] == -1L)
                   paste("<", format(signif(splits[i, 4L], digits)))
                 else if (side[i] == 1L)
                   paste(">=", format(signif(splits[i, 4L], digits)))
                 else {
                   catside <- fit$csplit[splits[i, 4L], 1:side[i]]
                   paste(c("L", "-", "R")[catside], collapse = "", sep = "")
                 }
    }
    cbind(data.frame(var = rownames(splits),
                     type = type,
                     node = rep(as.integer(row.names(ff)), times = nn),
                     ix = rep(seq_len(nrow(ff)), nn),
                     left = left),
          as.data.frame(splits, row.names = FALSE))
  }
}

Using this function, I am able to recover all the splitting variables and points:

splits <- rpart_splits(tree)[rpart_splits(tree)$type == "main", ]
splits

#   var type node ix    left count ncat    improve index adj
# 1  X2 main    1  1 < 0.565  5000   -1 0.18110662 0.565   0
# 3  X2 main    2  2 < 0.265  2814   -1 0.06358597 0.265   0
# 6  X1 main    3  5 < 0.645  2186   -1 0.07645851 0.645   0

The column var gives the splitting variable for each non-terminal node, and the column left gives the associated split point. However, I do not know how to use this information to produce the plot I want.

Of course, if you have an alternative strategy that does not involve rpart_splits, feel free to suggest it.

Comments (1)

驱逐舰岛风号 2025-02-13 13:59:46

You could use the (unpublished) parttree package, which you can install from GitHub via:

remotes::install_github("grantmcdermott/parttree")

This allows:

library(ggplot2)
library(parttree)

ggplot() +
  geom_parttree(data = tree, aes(fill = path)) +
  coord_cartesian(xlim = c(0, 1), ylim = c(0, 1)) +
  scale_fill_brewer(palette = "Pastel1", name = "Partitions") +
  theme_bw(base_size = 16) +
  labs(x = "X2", y = "X1")

[Figure: the resulting partition plot]

Incidentally, this package also contains the function parttree, which returns something very similar to the output of your rpart_splits function:

parttree(tree)
  node         Y                        path  xmin  xmax  ymin  ymax
1    4 0.7556079   X2 < 0.565 --> X2 < 0.265  -Inf 0.265  -Inf   Inf
2    5 1.3087679  X2 < 0.565 --> X2 >= 0.265 0.265 0.565  -Inf   Inf
3    6 1.8681143  X2 >= 0.565 --> X1 < 0.645 0.565   Inf  -Inf 0.645
4    7 2.4993361 X2 >= 0.565 --> X1 >= 0.645 0.565   Inf 0.645   Inf
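
If you would rather stay in base R, rectangles like the ones above can in principle be rebuilt from the split table in the question. The sketch below is a rough illustration under several assumptions: leaf_boxes is a hypothetical helper, the split table is retyped by hand from the question's output, both covariates are bounded in [0, 1], all splits are continuous, and the children of node k are numbered 2k and 2k + 1 (the numbering rpart uses), with ncat = -1 meaning that values below the cut go left.

```r
# Main splits as printed in the question (retyped by hand for this sketch).
splits <- data.frame(
  node  = c(1, 2, 3),
  var   = c("X2", "X2", "X1"),
  ncat  = c(-1, -1, -1),
  index = c(0.565, 0.265, 0.645)
)

# Hypothetical helper: walk down from the root, shrinking a bounding box at
# each split, and return one rectangle per terminal node.
# vars[1] is drawn on the x axis, vars[2] on the y axis.
leaf_boxes <- function(splits, vars = c("X1", "X2"), lower = 0, upper = 1) {
  boxes <- list()
  recurse <- function(node, lo, hi) {
    i <- match(node, splits$node)
    if (is.na(i)) {                         # no split recorded: terminal node
      boxes[[length(boxes) + 1L]] <<- data.frame(
        node = node,
        xmin = lo[[vars[1]]], xmax = hi[[vars[1]]],
        ymin = lo[[vars[2]]], ymax = hi[[vars[2]]]
      )
      return(invisible(NULL))
    }
    v <- as.character(splits$var[i])
    cut <- splits$index[i]
    below_hi <- hi; below_hi[[v]] <- cut    # sub-box where v <  cut
    above_lo <- lo; above_lo[[v]] <- cut    # sub-box where v >= cut
    if (splits$ncat[i] < 0) {               # "< cut" goes to the left child
      recurse(2 * node, lo, below_hi)
      recurse(2 * node + 1, above_lo, hi)
    } else {                                # ">= cut" goes to the left child
      recurse(2 * node, above_lo, hi)
      recurse(2 * node + 1, lo, below_hi)
    }
  }
  recurse(1,
          setNames(rep(lower, length(vars)), vars),
          setNames(rep(upper, length(vars)), vars))
  do.call(rbind, boxes)
}

boxes <- leaf_boxes(splits)

# Draw the partition and label each rectangle with its node number.
plot(NULL, xlim = c(0, 1), ylim = c(0, 1), xlab = "X1", ylab = "X2")
rect(boxes$xmin, boxes$ymin, boxes$xmax, boxes$ymax)
text((boxes$xmin + boxes$xmax) / 2, (boxes$ymin + boxes$ymax) / 2,
     labels = boxes$node)
```

Calling leaf_boxes(splits, vars = c("X2", "X1")) instead puts X2 on the x axis, which matches the orientation of the parttree table above.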