Gradient-Boosted Trees (GBTs): Ensembles of Decision Trees
GBTs are ensembles of decision trees. GBTs iteratively train decision trees in order to minimize a loss function. Like decision trees, GBTs handle categorical features, extend to the multiclass setting without requiring feature scaling, and can capture non-linearities and feature interactions.
MLlib supports GBTs for binary classification and for regression, using both continuous and categorical features. MLlib implements GBTs on top of the existing decision tree implementation; see the decision tree section for details.
Note: GBTs do not yet support multiclass classification. For multiclass problems, please use decision trees or Random Forests.
Basic algorithm
Gradient boosting iteratively trains a sequence of decision trees. On each iteration, the algorithm uses the current ensemble to predict the label of each training instance and compares those predictions with the true labels. The dataset is then re-labeled to put more emphasis on the instances with poor predictions, so the decision tree trained in the next iteration helps correct previous mistakes.
The specific mechanism for re-labeling instances is defined by a loss function (discussed below). With each iteration, GBTs further reduce this loss function on the training data.
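As a rough illustration, the following self-contained Scala sketch implements this loop for squared-error loss. It is purely conceptual, not the MLlib implementation: a constant predictor (fitStump, a depth-0 "tree") stands in for a real regression tree.
object BoostingSketch {
  type Model = Array[Double] => Double

  // A degenerate "tree": predict the mean residual everywhere.
  def fitStump(residuals: Seq[(Double, Array[Double])]): Model = {
    val mean = residuals.map(_._1).sum / residuals.size
    _ => mean
  }

  def boost(data: Seq[(Double, Array[Double])], // (label, features) pairs
            numIterations: Int,
            learningRate: Double): Model = {
    var ensemble: Model = _ => 0.0 // F_0(x) = 0
    for (_ <- 1 to numIterations) {
      // Re-label each instance with its residual y - F(x) (proportional to
      // the negative gradient of squared error), so poorly predicted points
      // dominate the next tree's targets.
      val residuals = data.map { case (y, x) => (y - ensemble(x), x) }
      val tree = fitStump(residuals)
      val prev = ensemble
      ensemble = x => prev(x) + learningRate * tree(x) // F_m = F_{m-1} + lr * h_m
    }
    ensemble
  }
}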
Losses
The table below lists the losses currently supported by GBTs in MLlib. Note that each loss is applicable to either classification or regression, not both. Notation:
- N = number of instances.
- $y_i$ = label of instance i.
- $x_i$ = features of instance i.
- $F(x_i)$ = model’s predicted label for instance i.
Loss | Task | Formula | Description |
---|---|---|---|
Log Loss | Classification | $2 \sum_{i=1}^{N} \log(1 + \exp(-2 y_i F(x_i)))$ | Twice binomial negative log likelihood. |
Squared Error | Regression | $\sum_{i=1}^{N} (y_i - F(x_i))^2$ | Also called L2 loss. Default loss for regression tasks. |
Absolute Error | Regression | $\sum_{i=1}^{N} \lvert y_i - F(x_i) \rvert$ | Also called L1 loss. Can be more robust to outliers than Squared Error. |
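For example, with Squared Error the re-labeled target for instance $i$ is proportional to the negative gradient of the loss at the current prediction, $-\frac{\partial}{\partial F(x_i)} (y_i - F(x_i))^2 = 2\,(y_i - F(x_i))$, i.e. the ordinary residual; instances with large errors therefore receive the most emphasis in the next iteration.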
Usage tips
Below we discuss the various GBT parameters. We omit some decision tree parameters since those are covered in the decision tree section. A sketch of how these parameters map onto the API follows the list.
- loss:See the section above for information on losses and their applicability to tasks (classification vs. regression). Different losses can give significantly different results, depending on the dataset.
- numIterations: This sets the number of trees in the ensemble. Each iteration produces one tree. Increasing this number makes the model more expressive, improving training data accuracy. However, test-time accuracy may suffer if this is too large.
- learningRate: This parameter should not need to be tuned. If the algorithm behavior seems unstable, decreasing this value may improve stability.
- algo: The algorithm or task (classification vs. regression) is set using the tree Strategy parameter.
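A minimal sketch of how the parameters above map onto BoostingStrategy in the Scala API; the values are illustrative, and AbsoluteError stands in for whichever loss suits the task.
import org.apache.spark.mllib.tree.configuration.BoostingStrategy
import org.apache.spark.mllib.tree.loss.AbsoluteError
val strategy = BoostingStrategy.defaultParams("Regression")
strategy.loss = AbsoluteError        // pick a loss matching the task
strategy.numIterations = 100         // number of trees in the ensemble
strategy.learningRate = 0.1          // rarely needs tuning; lower if unstable
strategy.treeStrategy.maxDepth = 5   // per-tree parameter (see decision trees)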
Validation while training
Gradient boosting can overfit when trained with more trees. In order to prevent overfitting, it is useful to validate while training. The method runWithValidation has been provided to make use of this option. It takes a pair of RDDs as arguments, the first being the training dataset and the second being the validation dataset. Training stops when the improvement in the validation error is no more than a certain tolerance (the validationTol argument in BoostingStrategy).
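For example, a minimal sketch using runWithValidation, assuming the same shell context (sc) and sample data file as the examples below; the 80/20 split and iteration count are illustrative.
import org.apache.spark.mllib.tree.GradientBoostedTrees
import org.apache.spark.mllib.tree.configuration.BoostingStrategy
import org.apache.spark.mllib.util.MLUtils
// Hold out part of the data as a validation set.
val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
val Array(train, validation) = data.randomSplit(Array(0.8, 0.2))
val boostingStrategy = BoostingStrategy.defaultParams("Classification")
boostingStrategy.numIterations = 100
// Training stops early once the validation error improves by less than
// boostingStrategy.validationTol between iterations.
val model = new GradientBoostedTrees(boostingStrategy).runWithValidation(train, validation)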
Classification
The example below demonstrates how to load a LIBSVM data file, parse it as an RDD of LabeledPoint and then perform classification using Gradient-Boosted Trees with log loss. The test error is calculated to measure the algorithm accuracy.
val PATH = "file:///Users/lzz/work/SparkML/"
import org.apache.spark.mllib.tree.GradientBoostedTrees
import org.apache.spark.mllib.tree.configuration.BoostingStrategy
import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel
import org.apache.spark.mllib.util.MLUtils
// Load and parse the data file.
val data = MLUtils.loadLibSVMFile(sc, PATH + "data/mllib/sample_libsvm_data.txt")
// Split the data into training and test sets (30% held out for testing)
val splits = data.randomSplit(Array(0.7, 0.3))
val (trainingData, testData) = (splits(0), splits(1))
// Train a GradientBoostedTrees model.
// The defaultParams for Classification use LogLoss by default.
val boostingStrategy = BoostingStrategy.defaultParams("Classification")
boostingStrategy.numIterations = 3 // Note: Use more iterations in practice.
boostingStrategy.treeStrategy.numClasses = 2
boostingStrategy.treeStrategy.maxDepth = 5
// Empty categoricalFeaturesInfo indicates all features are continuous.
boostingStrategy.treeStrategy.categoricalFeaturesInfo = Map[Int, Int]()
val model = GradientBoostedTrees.train(trainingData, boostingStrategy)
// Evaluate model on test instances and compute test error
val labelAndPreds = testData.map { point =>
val prediction = model.predict(point.features)
(point.label, prediction)
}
val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count()
println("Test Error = " + testErr)
println("Learned classification GBT model:\n" + model.toDebugString)
Test Error = 0.02702702702702703
Learned classification GBT model:
TreeEnsembleModel classifier with 3 trees
Tree 0:
If (feature 434 <= 0.0)
If (feature 99 <= 0.0)
Predict: -1.0
Else (feature 99 > 0.0)
Predict: 1.0
Else (feature 434 > 0.0)
Predict: 1.0
Tree 1:
If (feature 434 <= 0.0)
If (feature 352 <= 246.0)
If (feature 400 <= 9.0)
If (feature 124 <= 0.0)
Predict: -0.4768116880884702
Else (feature 124 > 0.0)
Predict: -0.4768116880884703
Else (feature 400 > 9.0)
Predict: -0.4768116880884703
Else (feature 352 > 246.0)
Predict: 0.4768116880884694
Else (feature 434 > 0.0)
If (feature 467 <= 28.0)
If (feature 518 <= 248.0)
Predict: 0.47681168808847024
Else (feature 518 > 248.0)
Predict: 0.47681168808847024
Else (feature 467 > 28.0)
Predict: 0.4768116880884712
Tree 2:
If (feature 434 <= 0.0)
If (feature 242 <= 0.0)
Predict: 0.4381935810427206
Else (feature 242 > 0.0)
Predict: -0.4381935810427206
Else (feature 434 > 0.0)
If (feature 178 <= 0.0)
If (feature 123 <= 0.0)
Predict: 0.4381935810427206
Else (feature 123 > 0.0)
If (feature 124 <= 252.0)
Predict: 0.4381935810427206
Else (feature 124 > 252.0)
Predict: 0.43819358104272055
Else (feature 178 > 0.0)
If (feature 97 <= 0.0)
Predict: 0.43819358104272044
Else (feature 97 > 0.0)
Predict: 0.43819358104272044
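If the trained model needs to be persisted, the GradientBoostedTreesModel imported above supports save and load; a minimal sketch, where the path "myGBTModel" is an illustrative placeholder:
// Save the model, then load it back (path is a placeholder).
model.save(sc, "myGBTModel")
val sameModel = GradientBoostedTreesModel.load(sc, "myGBTModel")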
Regression
The example below demonstrates how to load a LIBSVM data file, parse it as an RDD of LabeledPoint and then perform regression using Gradient-Boosted Trees with Squared Error as the loss. The Mean Squared Error (MSE) is computed at the end to evaluate goodness of fit.
import org.apache.spark.mllib.tree.GradientBoostedTrees
import org.apache.spark.mllib.tree.configuration.BoostingStrategy
import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel
import org.apache.spark.mllib.util.MLUtils
// Load and parse the data file.
val data = MLUtils.loadLibSVMFile(sc, PATH+"data/mllib/sample_libsvm_data.txt")
// Split the data into training and test sets (30% held out for testing)
val splits = data.randomSplit(Array(0.7, 0.3))
val (trainingData, testData) = (splits(0), splits(1))
// Train a GradientBoostedTrees model.
// The defaultParams for Regression use SquaredError by default.
val boostingStrategy = BoostingStrategy.defaultParams("Regression")
boostingStrategy.numIterations = 3 // Note: Use more iterations in practice.
boostingStrategy.treeStrategy.maxDepth = 5
// Empty categoricalFeaturesInfo indicates all features are continuous.
boostingStrategy.treeStrategy.categoricalFeaturesInfo = Map[Int, Int]()
val model = GradientBoostedTrees.train(trainingData, boostingStrategy)
// Evaluate model on test instances and compute test error
val labelsAndPredictions = testData.map { point =>
val prediction = model.predict(point.features)
(point.label, prediction)
}
val testMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean()
println("Test Mean Squared Error = " + testMSE)
println("Learned regression GBT model:\n" + model.toDebugString)
Test Mean Squared Error = 0.13333333333333333
Learned regression GBT model:
TreeEnsembleModel regressor with 3 trees
Tree 0:
If (feature 405 <= 0.0)
If (feature 99 <= 0.0)
Predict: 0.0
Else (feature 99 > 0.0)
Predict: 1.0
Else (feature 405 > 0.0)
Predict: 1.0
Tree 1:
Predict: 0.0
Tree 2:
Predict: 0.0