返回介绍

数学基础

统计学习

深度学习

工具

Scala

五、CART 树

发布于 2023-07-17 23:38:26 字数 15439 浏览 0 评论 0 收藏 0

  1. CART:classfification and regression tree :学习在给定输入随机变量 $ MathJax-Element-454 $ 条件下,输出随机变量 $ MathJax-Element-258 $ 的条件概率分布的模型。

    • 它同样由特征选取、树的生成、剪枝组成。
    • 它既可用于分类,也可用于回归。
  2. CART 假设决策树是二叉树:

    • 内部结点特征的取值为 。其中:左侧分支取 ,右侧分支取
    • 它递归地二分每个特征,将输入空间划分为有限个单元。
  3. CART 树与ID3 决策树和 C4.5 决策树的重要区别:

    • CART 树是二叉树,而后两者是N 叉树。

      由于是二叉树,因此 CART 树的拆分不依赖于特征的取值数量。因此CART 树也就不像ID3 那样倾向于取值数量较多的特征。

    • CART 树的特征可以是离散的,也可以是连续的。

      而后两者的特征是离散的。如果是连续的特征,则需要执行分桶来进行离散化。

      CART 树处理连续特征时,也可以理解为二分桶的离散化。

  4. CART算法分两步:

    • 决策树生成:用训练数据生成尽可能大的决策树。
    • 决策树剪枝:用验证数据基于损失函数最小化的标准对生成的决策树剪枝。

5.1 CART 生成算法

  1. CART生成算法有两个生成准则:

    • CART 回归树:用平方误差最小化准则。
    • CART 分类树:用基尼指数最小化准则。

5.1.1 CART 回归树

5.1.1.1 划分单元和划分点

  1. 一棵CART 回归树对应着输入空间的一个划分,以及在划分单元上的输出值。

    设输出 $ MathJax-Element-258 $ 为连续变量,训练数据集 $ MathJax-Element-259 $ 。

    设已经将输入空间划分为 $ MathJax-Element-260 $ 个单元 $ MathJax-Element-290 $ ,且在每个单元 $ MathJax-Element-266 $ 上有一个固定的输出值 $ MathJax-Element-263 $ 。则CART 回归树模型可以表示为:

    $ f(\mathbf{\vec x})=\sum_{m=1}^{M}c_m I(\mathbf {\vec x} \in R_m) $

    其中 $ MathJax-Element-264 $ 为示性函数。

  2. 如果已知输入空间的单元划分,基于平方误差最小的准则,则CART 回归树在训练数据集上的损失函数为:

    $ \sum_{m=1}^{M}\sum_{\mathbf {\vec x}_i \in R_m}(\tilde y_i-c_m)^{2} $

    根据损失函数最小,则可以求解出每个单元上的最优输出值 $ MathJax-Element-265 $ 为 : $ MathJax-Element-266 $ 上所有输入样本 $ MathJax-Element-267 $ 对应的输出 $ MathJax-Element-268 $ 的平均值。

    即: $ MathJax-Element-196 $ ,其中 $ MathJax-Element-197 $ 表示单元 $ MathJax-Element-266 $ 中的样本数量。

  3. 定义 $ MathJax-Element-266 $ 上样本的方差为 $ MathJax-Element-200 $ ,则有: $ MathJax-Element-201 $ 。则CART 回归树的训练集损失函数重写为: $ MathJax-Element-202 $ ,其中 $ MathJax-Element-255 $ 为训练样本总数。

    定义样本被划分到 $ MathJax-Element-266 $ 中的概率为 $ MathJax-Element-205 $ ,则 $ MathJax-Element-206 $ 。由于 $ MathJax-Element-255 $ 是个常数,因此损失函数重写为:

    $ \sum_{m=1}^M P_m\times \text{Var}_m $

    其物理意义为:经过输入空间的单元划分之后,CART 回归树的方差,通过这个方差来刻画CART 回归树的纯度。

  4. 问题是输入空间的单元划分是未知的。如何对输入空间进行划分?

    设输入为 $ MathJax-Element-425 $ 维: $ MathJax-Element-209 $ 。

    • 选择第 $ MathJax-Element-276 $ 维 $ MathJax-Element-211 $ 和它的取值 $ MathJax-Element-282 $ 作为切分变量和切分点。定义两个区域:
    $ R_1(j,s)=\{\mathbf {\vec x} \mid x_j \le s\}\\ R_2(j,s)=\{\mathbf {\vec x}\mid x_j \gt s\} $
    • 然后寻求最优切分变量 $ MathJax-Element-276 $ 和最优切分点 $ MathJax-Element-282 $ 。即求解:

      $ (j^*,s^*) = \min_{j,s}\left[ \min_{c_1} \sum_{\mathbf{\vec x_i} \in R_1(j,s)}(\tilde y_i-c_1)^{2} +\min_{c_2} \sum_{\mathbf{\vec x_i} \in R_2(j,s)}(\tilde y_i-c_2)^{2}\right] $

      其意义为:

      • 首先假设已知切分变量 $ MathJax-Element-281 $ ,则遍历最优切分点 $ MathJax-Element-282 $ ,则到:

        $ \hat c_1= \frac {1}{N_1} \sum_{\mathbf {\vec x}_i \in R_1(j,s)}\tilde y_i ,\quad \hat c_2= \frac {1}{N_2} \sum_{\mathbf {\vec x}_i \in R_2(j,s)}\tilde y_i $

        其中 $ MathJax-Element-284 $ 和 $ MathJax-Element-285 $ 分别代表区域 $ MathJax-Element-286 $ 和 $ MathJax-Element-287 $ 中的样本数量。

      • 然后遍历所有的特征维度,对每个维度找到最优切分点。从这些(切分维度,最优切分点) 中找到使得损失函数最小的那个。

  5. 依次将输入空间划分为两个区域,然后重复对子区域划分,直到满足停止条件为止。这样的回归树称为最小二乘回归树。

5.1.1.2 生成算法

  1. CART 回归树生成算法:

    • 输入:

      • 训练数据集 $ MathJax-Element-420 $
      • 停止条件
    • 输出: CART回归树 $ MathJax-Element-222 $

    • 步骤:

      • 选择最优切分维度 $ MathJax-Element-276 $ 和切分点 $ MathJax-Element-282 $ 。

        即求解:

      $ (j^*,s^*) = \min_{j,s}\left[ \min_{c_1} \sum_{\mathbf {\vec x}_i \in R_1(j,s)}(\tilde y_i-c_1)^{2} +\min_{c_2} \sum_{\mathbf {\vec x}_i \in R_2(j,s)}(\tilde y_i-c_2)^{2}\right] $
      • 用选定的 $ MathJax-Element-283 $ 划分区域并决定响应的输出值:

        $ R_1(j,s)=\{\mathbf {\vec x} \mid x_j \le s\},\quad R_2(j,s)=\{\mathbf {\vec x}\mid x_j \gt s\}\\ \hat c_1= \frac {1}{N_1} \sum_{\mathbf {\vec x}_i \in R_1(j,s)}\tilde y_i ,\quad \hat c_2= \frac {1}{N_2} \sum_{\mathbf {\vec x}_i \in R_2(j,s)}\tilde y_i $

        其中 $ MathJax-Element-284 $ 和 $ MathJax-Element-285 $ 分别代表区域 $ MathJax-Element-286 $ 和 $ MathJax-Element-287 $ 中的样本数量。

      • 对子区域 $ MathJax-Element-288 $ 递归地切分,直到满足停止条件。

      • 最终将输入空间划分为 $ MathJax-Element-413 $ 个区域 $ MathJax-Element-290 $ ,生成决策树: $ MathJax-Element-291 $ 。

5.1.2 CART 分类树

5.1.2.1 基尼系数

  1. CART 分类树采用基尼指数选择最优特征。

  2. 假设有 $ MathJax-Element-434 $ 个分类,样本属于第 $ MathJax-Element-235 $ 类的概率为 $ MathJax-Element-236 $ 。则概率分布的基尼指数为:

    $ Gini(p)=\sum_{k=1}^{K}p_k(1-p_k)=1-\sum_{k=1}^{K}p_k^{2} $

    基尼指数表示:样本集合中,随机选中一个样本,该样本被分错的概率。基尼指数越小,表示越不容易分错。

    样本被选错概率 = 样本被选中的概率 $ MathJax-Element-237 $ * 样本被分错的概率 $ MathJax-Element-238 $ 。

  3. 对于给定的样本集合 $ MathJax-Element-420 $ ,设属于类 $ MathJax-Element-240 $ 的样本子集为 $ MathJax-Element-244 $ ,则样本集的基尼指数为:

    $ Gini(\mathbb D)=1-\sum_{k=1}^{K}(\frac{N_k}{N})^{2} $

    其中 $ MathJax-Element-255 $ 为样本总数, $ MathJax-Element-243 $ 为子集 $ MathJax-Element-244 $ 的样本数量。

  4. 对于最简单的二项分布,设 $ MathJax-Element-245 $ ,则其基尼系数与熵的图形见下图。

    • 可以看到基尼系数与熵一样,也是度量不确定性的度量。

    • 对于样本集 $ MathJax-Element-420 $ , $ MathJax-Element-247 $ 越小,说明集合中的样本越纯净。

      entropy_and_gini

  5. 若样本集 $ MathJax-Element-248 $ 根据特征 $ MathJax-Element-452 $ 是否小于 $ MathJax-Element-250 $ 而被分为两个子集: $ MathJax-Element-251 $ 和 $ MathJax-Element-252 $ ,其中:

    $ \mathbb D_1=\{(\mathbf{\vec x},y) \in \mathbb D \mid x_A \le a\}\\ \mathbb D_2=\{(\mathbf {\vec x},y) \in \mathbb D \mid x_A \gt a\}=\mathbb D-\mathbb D_1 $

    则在特征 $ MathJax-Element-253 $ 的条件下,集合 $ MathJax-Element-420 $ 的基尼指数为:

    $ Gini(\mathbb D,A:a)=\frac{N_1}{N}Gini(\mathbb D_1)+\frac{N_2}{N}Gini(\mathbb D_2) $

    其中 $ MathJax-Element-255 $ 为样本总数, $ MathJax-Element-256 $ 分别子集 $ MathJax-Element-257 $ 的样本数量。它就是每个子集的基尼系数的加权和,权重是每个子集的大小(以子集占整体集合大小的百分比来表示)。

5.1.2.2 划分单元和划分点

  1. 一棵CART 分类树对应着输入空间的一个划分,以及在划分单元上的输出值。

    设输出 $ MathJax-Element-258 $ 为分类的类别,是离散变量。训练数据集 $ MathJax-Element-259 $ 。

    设已经将输入空间划分为 $ MathJax-Element-260 $ 个单元 $ MathJax-Element-290 $ ,且在每个单元 $ MathJax-Element-266 $ 上有一个固定的输出值 $ MathJax-Element-263 $ 。则CART 分类树模型可以表示为:

    $ f(\mathbf{\vec x})=\sum_{m=1}^{M}c_m I(\mathbf {\vec x} \in R_m) $

    其中 $ MathJax-Element-264 $ 为示性函数。

  2. 如果已知输入空间的单元划分,基于分类误差最小的准则,则CART 分类树在训练数据集上的损失函数为:

    $ \sum_{m=1}^{M}\sum_{\mathbf {\vec x}_i \in R_m}I(\tilde y_i\ne c_m) $

    根据损失函数最小,则可以求解出每个单元上的最优输出值 $ MathJax-Element-265 $ 为 : $ MathJax-Element-266 $ 上所有输入样本 $ MathJax-Element-267 $ 对应的输出 $ MathJax-Element-268 $ 的众数。

    即: $ MathJax-Element-269 $ 。

  3. 问题是输入空间的单元划分是未知的。如何对输入空间进行划分?

    类似CART 回归树,CART 分类树遍历所有可能的维度 $ MathJax-Element-281 $ 和该维度所有可能的取值 $ MathJax-Element-282 $ ,取使得基尼系数最小的那个维度 $ MathJax-Element-281 $ 和切分点 $ MathJax-Element-282 $ 。

    即求解: $ MathJax-Element-274 $ 。

5.1.2.3 生成算法

  1. CART 分类树的生成算法:

    • 输入:

      • 训练数据集 $ MathJax-Element-420 $
      • 停止计算条件
    • 输出: CART 决策树

    • 步骤:

      • 选择最优切分维度 $ MathJax-Element-276 $ 和切分点 $ MathJax-Element-282 $ 。

        即求解: $ MathJax-Element-278 $ 。

        它表示:遍历所有可能的维度 $ MathJax-Element-281 $ 和该维度所有可能的取值 $ MathJax-Element-282 $ ,取使得基尼系数最小的那个维度 $ MathJax-Element-281 $ 和切分点 $ MathJax-Element-282 $ 。

      • 用选定的 $ MathJax-Element-283 $ 划分区域并决定响应的输出值:

        $ R_1(j,s)=\{\mathbf {\vec x} \mid x_j \le s\},\quad R_2(j,s)=\{\mathbf {\vec x}\mid x_j \gt s\}\\ \hat c_1=\arg\max_{c_1} \sum_{\mathbf {\vec x}_i \in R_1}I(c_1 = \tilde y_i),\quad \hat c_2= \arg\max_{c_2} \sum_{\mathbf {\vec x}_i \in R_2}I(c_m = \tilde y_i) $

        其中 $ MathJax-Element-284 $ 和 $ MathJax-Element-285 $ 分别代表区域 $ MathJax-Element-286 $ 和 $ MathJax-Element-287 $ 中的样本数量。

      • 对子区域 $ MathJax-Element-288 $ 递归地切分,直到满足停止条件。

      • 最终将输入空间划分为 $ MathJax-Element-413 $ 个区域 $ MathJax-Element-290 $ ,生成决策树: $ MathJax-Element-291 $ 。

5.1.3 其它讨论

  1. CART 分类树和CART 回归树通常的停止条件为:

    • 结点中样本个数小于预定值,这表示树已经太复杂。
    • 样本集的损失函数或者基尼指数小于预定值,表示结点已经非常纯净。
    • 没有更多的特征可供切分。
  2. 前面讨论的CART 分类树和CART 回归树都假设特征均为连续值。

    • 实际上CART 树的特征可以为离散值,此时切分区域定义为:

      $ R_1(j,s)=\{\mathbf {\vec x} \mid x_j = s\}\\ R_2(j,s)=\{\mathbf {\vec x}\mid x_j \ne s\} $
    • 连续的特征也可以通过分桶来进行离散化,然后当作离散特征来处理。

5.2 CART 剪枝

  1. CART 树的剪枝是从完全生长的CART 树底端减去一些子树,使得CART 树变小(即模型变简单),从而使得它对未知数据有更好的预测能力。

5.2.1 原理

  1. 定义CART 树 $ MathJax-Element-403 $ 的损失函数为( $ MathJax-Element-293 $ ): $ MathJax-Element-294 $ 。其中:

    • $ MathJax-Element-302 $ 为参数是 $ MathJax-Element-408 $ 时树 $ MathJax-Element-403 $ 的整体损失。
    • $ MathJax-Element-298 $ 为树 $ MathJax-Element-403 $ 对训练数据的预测误差。
    • $ MathJax-Element-300 $ 为子树的叶结点个数。
  2. 对固定的 $ MathJax-Element-408 $ ,存在使 $ MathJax-Element-302 $ 最小的子树,令其为 $ MathJax-Element-308 $ 。可以证明 $ MathJax-Element-308 $ 是唯一的。

    • 当 $ MathJax-Element-408 $ 大时, $ MathJax-Element-308 $ 偏小,即叶结点偏少。
    • 当 $ MathJax-Element-408 $ 小时, $ MathJax-Element-308 $ 偏大,即叶结点偏多。
    • 当 $ MathJax-Element-321 $ 时,未剪枝的生成树就是最优的,此时不需要剪枝。
    • 当 $ MathJax-Element-310 $ 时,根结点组成的一个单结点树就是最优的。此时剪枝到极致:只剩下一个结点。
  3. 令从生成树 $ MathJax-Element-389 $ 开始剪枝。对 $ MathJax-Element-389 $ 任意非叶结点 $ MathJax-Element-462 $ ,考虑:需不需要对 $ MathJax-Element-462 $ 进行剪枝?

    • 以 $ MathJax-Element-462 $ 为单结点树:因为此时只有一个叶结点,即为 $ MathJax-Element-462 $ 本身,所以损失函数为: $ MathJax-Element-317 $ 。
    • 以 $ MathJax-Element-462 $ 为根的子树 $ MathJax-Element-337 $ :此时的损失函数为: $ MathJax-Element-320 $ 。
  4. 可以证明:

    • 当 $ MathJax-Element-321 $ 及充分小时,有 $ MathJax-Element-322 $ 。即此时倾向于选择比较复杂的 $ MathJax-Element-337 $ ,因为正则化项的系数 $ MathJax-Element-408 $ 太小。
    • 当 $ MathJax-Element-408 $ 增大到某个值时,有 $ MathJax-Element-326 $ 。
    • 当 $ MathJax-Element-408 $ 再增大时,有 $ MathJax-Element-328 $ 。即此时倾向于选择比较简单的 $ MathJax-Element-462 $ ,因为正则化项的系数 $ MathJax-Element-408 $ 太大。
  5. 令 $ MathJax-Element-331 $ ,此时 $ MathJax-Element-337 $ 与 $ MathJax-Element-462 $ 有相同的损失函数值,但是 $ MathJax-Element-462 $ 的结点更少。

    因此 $ MathJax-Element-462 $ 比 $ MathJax-Element-337 $ 更可取,于是对 $ MathJax-Element-337 $ 进行剪枝。

  6. 对 $ MathJax-Element-389 $ 内部的每一个内部结点 $ MathJax-Element-462 $ ,计算 $ MathJax-Element-340 $ 。

    它表示剪枝后整体损失函数增加的程度(可以为正,可以为负)。则有:

    $ C_\alpha(t)-C_\alpha(T_t)=C(t)+\alpha-C(T_t)-\alpha|T_t|\\ =C(t)-C(T_t)-\alpha(|T_t|-1)\\ =(g(t)-\alpha)(|T_t|-1) $

    因为 $ MathJax-Element-462 $ 是个内部结点,所以 $ MathJax-Element-342 $ ,因此有:

    • $ MathJax-Element-343 $ 时, $ MathJax-Element-344 $ ,表示剪枝后,损失函数增加。
    • $ MathJax-Element-395 $ 时, $ MathJax-Element-346 $ ,表示剪枝后,损失函数不变。
    • $ MathJax-Element-347 $ 时, $ MathJax-Element-348 $ ,表示剪枝后,损失函数减少。
  7. 对 $ MathJax-Element-389 $ 内部的每一个内部结点 $ MathJax-Element-462 $ ,计算最小的 $ MathJax-Element-351 $ :

    $ g*=\min g(t), t\in T_0 \; \text{and $t$ is not a leaf} $

    设 $ MathJax-Element-352 $ 对应的内部结点为 $ MathJax-Element-364 $ ,在 $ MathJax-Element-389 $ 内减去 $ MathJax-Element-355 $ ,得到的子树作为 $ MathJax-Element-373 $ 。

    令 $ MathJax-Element-357 $ ,对于 $ MathJax-Element-358 $ ,有: $ MathJax-Element-359 $ 。

    对于 $ MathJax-Element-360 $ ,有:

    • 对于 $ MathJax-Element-364 $ 剪枝,得到的子树的损失函数一定是减少的。

      它也是所有内部结点剪枝结果中,减少的最多的。因此 $ MathJax-Element-373 $ 是 $ MathJax-Element-363 $ 内的最优子树。

    • 对任意一个非 $ MathJax-Element-364 $ 内部结点的剪枝,得到的子树的损失函数有可能是增加的,也可能是减少的。

      如果损失函数是减少的,它也没有 $ MathJax-Element-373 $ 减少的多。

  8. 如此剪枝下去,直到根结点被剪枝。

    • 此过程中不断产生 $ MathJax-Element-366 $ 的值,产生新区间 $ MathJax-Element-367 $
    • 此过程中不断产生 最优子树 $ MathJax-Element-368 $
    • 其中 $ MathJax-Element-373 $ 是由 $ MathJax-Element-389 $ 产生的、 $ MathJax-Element-371 $ 内的最优子树; $ MathJax-Element-372 $ 是由 $ MathJax-Element-373 $ 产生的、 $ MathJax-Element-374 $ 内的最优子树;...
  9. 上述剪枝的思想就是用递归的方法对树进行剪枝:计算出一个序列 $ MathJax-Element-375 $ ,同时剪枝得到一系列最优子树序列 $ MathJax-Element-376 $ $ MathJax-Element-377 $ 是 $ MathJax-Element-378 $ 时的最优子树。

  10. 上述剪枝的结果只是对于训练集的损失函数较小。

    • 需要用交叉验证的方法在验证集上对子树序列进行测试,挑选中出最优子树。

      交叉验证的本质就是为了挑选超参数 $ MathJax-Element-408 $ 。

    • 验证过程:用独立的验证数据集来测试子树序列 $ MathJax-Element-388 $ 中各子树的平方误差或者基尼指数。

      由于 $ MathJax-Element-381 $ 对应于一个参数序列 $ MathJax-Element-382 $ ,因此当最优子树 $ MathJax-Element-383 $ 确定时,对应的区间 $ MathJax-Element-384 $ 也确定了。

5.2.2 算法

  1. CART剪枝由两步组成:

    • 从生成算法产生的决策树 $ MathJax-Element-389 $ 底端开始不断的剪枝:

      • 每剪枝一次生成一个决策树 $ MathJax-Element-386 $
      • 这一过程直到 $ MathJax-Element-389 $ 的根结点,形成一个子树序列 $ MathJax-Element-388 $ 。
    • 用交叉验证的方法在独立的验证集上对子树序列进行测试,挑选中出最优子树。

  2. CART剪枝算法:

    • 输入:CART算法生成的决策树 $ MathJax-Element-389 $

    • 输出: 最优决策树 $ MathJax-Element-405 $

    • 算法步骤:

      • 初始化: $ MathJax-Element-391 $

      • 自下而上的对各内部结点 $ MathJax-Element-462 $ 计算 : $ MathJax-Element-393 $ 。

      • 自下而上地访问内部结点 $ MathJax-Element-462 $ : 若有 $ MathJax-Element-395 $ ,则进行剪枝,并确定叶结点 $ MathJax-Element-462 $ 的输出,得到树 $ MathJax-Element-403 $ 。

        • 如果为分类树,则叶结点 $ MathJax-Element-462 $ 的输出采取多数表决法:结点 $ MathJax-Element-462 $ 内所有样本的标记的众数。
        • 如果为回归树,则叶结点 $ MathJax-Element-462 $ 的输出为平均法:结点 $ MathJax-Element-462 $ 内所有样本的标记的均值。
      • 令 $ MathJax-Element-402 $ 。

      • 若 $ MathJax-Element-403 $ 不是由根结点单独构成的树,则继续前面的步骤。

      • 采用交叉验证法在子树序列 $ MathJax-Element-404 $ 中选取最优子树 $ MathJax-Element-405 $ 。

  3. CART剪枝算法的优点是:不显式需要指定正则化系数 $ MathJax-Element-408 $ 。

    • CART 剪枝算法自动生成了一系列良好的超参数 $ MathJax-Element-407 $ ,然后利用验证集进行超参数选择。

    • 虽然传统剪枝算法也可以用验证集来进行超参数选择,但是CART 剪枝算法的效率更高。

      因为CART 剪枝算法只需要搜索超参数 $ MathJax-Element-408 $ 的有限数量的区间即可,而传统剪枝算法需要搜索整个数域 $ MathJax-Element-409 $ 。

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。
列表为空,暂无数据
    我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
    原文