返回介绍

数学基础

统计学习

深度学习

工具

Scala

七、梯度提升树

发布于 2023-07-17 23:38:23 字数 6388 浏览 0 评论 0 收藏 0

7.1 GradientBoostingClassifier

  1. GradientBoostingClassifierGBDT 分类模型,其原型为:

    
    
    xxxxxxxxxx
    class sklearn.ensemble.GradientBoostingClassifier(loss='deviance', learning_rate=0.1, n_estimators=100, subsample=1.0, min_samples_split=2, min_samples_leaf=1, min_weight_fraction_leaf=0.0, max_depth=3, init=None, random_state=None, max_features=None, verbose=0, max_leaf_nodes=None, warm_start=False, presort='auto')
    • loss:一个字符串,指定损失函数。可以为:

      • 'deviance'(默认值):此时损失函数为对数损失函数: $ MathJax-Element-97 $ 。
      • 'exponential':此时使用指数损失函数。
    • n_estimators:一个整数,指定基础决策树的数量(默认为100),值越大越好。

    • learning_rate:一个浮点数,表示学习率,默认为1。它就是下式中的 $ MathJax-Element-98 $ : $ MathJax-Element-99 $ 。

      • 它用于减少每一步的步长,防止步长太大而跨过了极值点。
      • 通常学习率越小,则需要的基础分类器数量会越多,因此在learning_raten_estimators之间会有所折中。
    • max_depth:一个整数或者None,指定了每个基础决策树模型的max_depth参数。

      • 调整该参数可以获得最佳性能。
      • 如果max_leaf_nodes不是None,则忽略本参数。
    • min_samples_split:一个整数,指定了每个基础决策树模型的min_samples_split参数。

    • min_samples_leaf:一个整数,指定了每个基础决策树模型的min_samples_leaf参数。

    • min_weight_fraction_leaf:一个浮点数,指定了每个基础决策树模型的min_weight_fraction_leaf参数。

    • subsample:一个大于 0 小于等于 1.0 的浮点数,指定了提取原始训练集中多大比例的一个子集用于训练基础决策树。

      • 如果 subsample小于1.0,则梯度提升决策树模型就是随机梯度提升决策树。

        此时会减少方差但是提高了偏差。

      • 它会影响n_estimators参数。

    • max_features:一个整数或者浮点数或者字符串或者None,指定了每个基础决策树模型的max_features参数。

      如果 max_features< n_features,则会减少方差但是提高了偏差。

    • max_leaf_nodes:为整数或者None,指定了每个基础决策树模型的max_leaf_nodes参数。

    • init:一个基础分类器对象或者None,该分类器对象用于执行初始的预测。

      如果为None,则使用loss.init_estimator

    • verbose:一个正数。用于开启/关闭迭代中间输出日志功能。

    • warm_start:一个布尔值。用于指定是否继续使用上一次训练的结果。

    • random_state:一个随机数种子。

    • presort:一个布尔值或者'auto'。指定了每个基础决策树模型的presort参数。

  2. 模型属性:

    • feature_importances_:每个特征的重要性。
    • oob_improvement_:给出训练过程中,每增加一个基础决策树,在测试集上损失函数的改善情况(即:损失函数的减少值)。
    • train_score_:给出训练过程中,每增加一个基础决策树,在训练集上的损失函数的值。
    • init:初始预测使用的分类器。
    • estimators_:所有训练过的基础决策树。
  3. 模型方法:

    • fit(X, y[, sample_weight, monitor]):训练模型。

      其中monitor是一个可调用对象,它在当前迭代过程结束时调用。如果它返回True,则训练过程提前终止。

    • predict(X):用模型进行预测,返回预测值。

    • predict_log_proba(X):返回一个数组,数组的元素依次是X预测为各个类别的概率的对数值。

    • predict_proba(X):返回一个数组,数组的元素依次是X预测为各个类别的概率值。

    • score(X,y[,sample_weight]):返回模型的预测性能得分。

    • staged_predict(X):返回一个数组,数组元素依次是:GBDT 在每一轮迭代结束时的预测值。

    • staged_predict_proba(X):返回一个二维数组,数组元素依次是:GBDT 在每一轮迭代结束时,预测X为各个类别的概率值。

    • staged_score(X, y[, sample_weight]):返回一个数组,数组元素依次是:GBDT 在每一轮迭代结束时,该GBDT 的预测性能得分。

7.2 GradientBoostingRegressor

  1. GradientBoostingRegressorGBRT 回归模型,其原型为:

    
    
    xxxxxxxxxx
    class sklearn.ensemble.GradientBoostingRegressor(loss='ls', learning_rate=0.1, n_estimators=100, subsample=1.0, min_samples_split=2, min_samples_leaf=1, min_weight_fraction_leaf=0.0, max_depth=3, init=None, random_state=None, max_features=None, alpha=0.9, verbose=0, max_leaf_nodes=None, warm_start=False, presort='auto')
    • loss:一个字符串,指定损失函数。可以为:

      • 'ls':损失函数为平方损失函数。

      • 'lad':损失函数为绝对值损失函数。

      • 'huber':损失函数为上述两者的结合,通过alpha参数指定比例,该损失函数的定义为:

        $ L_{Huber}=\begin{cases} \frac 12(y-f(x))^{2},& \text{if} \quad |y-f(x)| \le \alpha\\ \alpha|y-f(x)|-\frac 12 \alpha^{2},& \text{else} \end{cases} $

        即误差较小时,采用平方损失;在误差较大时,采用绝对值损失。

      • 'quantile':分位数回归(分位数指得是百分之几),通过alpha参数指定分位数。

    • alpha:一个浮点数,只有当loss='huber'或者loss='quantile'时才有效。

    • n_estimators: 其它参数参考GradientBoostingClassifier

  2. 模型属性:

    • feature_importances_:每个特征的重要性。
    • oob_improvement_:给出训练过程中,每增加一个基础决策树,在测试集上损失函数的改善情况(即:损失函数的减少值)。
    • train_score_ :给出训练过程中,每增加一个基础决策树,在训练集上的损失函数的值。
    • init:初始预测使用的回归器。
    • estimators_:所有训练过的基础决策树。
  3. 模型方法:

    • fit(X, y[, sample_weight, monitor]):训练模型。

      其中monitor是一个可调用对象,它在当前迭代过程结束时调用。如果它返回True,则训练过程提前终止。

    • predict(X):用模型进行预测,返回预测值。

    • score(X,y[,sample_weight]):返回模型的预测性能得分。

    • staged_predict(X):返回一个数组,数组元素依次是:GBRT 在每一轮迭代结束时的预测值。

    • staged_score(X, y[, sample_weight]):返回一个数组,数组元素依次是:GBRT在每一轮迭代结束时,该GBRT的预测性能得分。

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。
列表为空,暂无数据
    我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
    原文