public class GradientBoostedTrees extends Object implements scala.Serializable, Logging
Stochastic Gradient Boosting for regression and binary classification.
The implementation is based upon: J.H. Friedman. "Stochastic Gradient Boosting." 1999.
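The core loop of Friedman's method can be sketched in plain Java, independent of Spark. This is a minimal illustration under stated assumptions, not Spark's implementation: squared-error loss, a single numeric feature, depth-1 trees (stumps), and hypothetical names (`SgbSketch`, `fit`, `predict`, `learningRate`, `subsamplingRate`) that are not Spark APIs. Each iteration computes pseudo-residuals, draws a random subsample of rows without replacement (the "stochastic" part), fits a tree to the residuals on that subsample, and adds it to the ensemble scaled by a learning rate.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/**
 * Minimal sketch of stochastic gradient boosting (squared error, one numeric feature,
 * depth-1 trees). All names are hypothetical; this is not Spark's implementation.
 */
final class SgbSketch {

    /** Depth-1 regression tree: predicts left if x <= threshold, else right. */
    static final class Stump {
        final double threshold, left, right;
        Stump(double threshold, double left, double right) {
            this.threshold = threshold; this.left = left; this.right = right;
        }
        double predict(double x) { return x <= threshold ? left : right; }
    }

    private final double f0;             // initial constant model (mean label)
    private final double learningRate;   // shrinkage applied to every tree
    private final List<Stump> stumps = new ArrayList<>();

    private SgbSketch(double f0, double learningRate) {
        this.f0 = f0;
        this.learningRate = learningRate;
    }

    /** Ensemble prediction F(x) = f0 + learningRate * (sum of stump outputs). */
    double predict(double x) {
        double f = f0;
        for (Stump s : stumps) f += learningRate * s.predict(x);
        return f;
    }

    /** Least-squares stump fit to targets r, using only the rows listed in idx. */
    static Stump fitStump(double[] x, double[] r, List<Integer> idx) {
        Stump best = null;
        double bestSse = Double.POSITIVE_INFINITY;
        for (int t : idx) {                      // candidate thresholds: sampled x values
            double sumL = 0, sumR = 0;
            int nL = 0, nR = 0;
            for (int i : idx) {
                if (x[i] <= x[t]) { sumL += r[i]; nL++; } else { sumR += r[i]; nR++; }
            }
            double meanL = nL > 0 ? sumL / nL : 0.0;
            double meanR = nR > 0 ? sumR / nR : 0.0;
            double sse = 0;
            for (int i : idx) {
                double d = r[i] - (x[i] <= x[t] ? meanL : meanR);
                sse += d * d;
            }
            if (sse < bestSse) { bestSse = sse; best = new Stump(x[t], meanL, meanR); }
        }
        return best;
    }

    static SgbSketch fit(double[] x, double[] y, int numIterations,
                         double learningRate, double subsamplingRate, long seed) {
        double mean = 0;
        for (double v : y) mean += v;
        SgbSketch model = new SgbSketch(mean / y.length, learningRate);

        List<Integer> rows = new ArrayList<>();
        for (int i = 0; i < x.length; i++) rows.add(i);
        int sampleSize = Math.min(x.length,
                Math.max(1, (int) Math.ceil(subsamplingRate * x.length)));
        Random rng = new Random(seed);
        double[] residual = new double[y.length];

        for (int m = 0; m < numIterations; m++) {
            // Pseudo-residuals: the negative gradient of (1/2)(y - F)^2 with respect to F is y - F.
            for (int i = 0; i < y.length; i++) residual[i] = y[i] - model.predict(x[i]);
            // Stochastic step: fit this tree on a random subsample drawn without replacement.
            Collections.shuffle(rows, rng);
            model.stumps.add(fitStump(x, residual, rows.subList(0, sampleSize)));
        }
        return model;
    }
}
```

For example, with x = 0..9 and y = 1 for x < 5, 3 otherwise, `SgbSketch.fit(xs, ys, 20, 0.5, 1.0, 42L)` approaches the step function: with the full sample and a learning rate of 0.5, each iteration halves the remaining residual. Values of `subsamplingRate` below 1.0 trade some per-iteration accuracy for randomization between trees.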
Notes on Gradient Boosting vs. TreeBoost:
- This implementation is for Stochastic Gradient Boosting, not for TreeBoost.
- Both algorithms learn tree ensembles by minimizing loss functions.
- TreeBoost (Friedman, 1999) additionally modifies the outputs at tree leaf nodes based on the loss function, whereas the original gradient boosting method does not.
- When the loss is SquaredError, these methods give the same result, but they could differ for other loss functions.
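The difference between the two leaf rules can be made concrete for a single leaf. In the sketch below (hypothetical helper `LeafValues`, not a Spark class), plain gradient boosting fits the tree to pseudo-residuals by least squares, so a leaf outputs their mean; TreeBoost instead outputs the constant that minimizes the loss itself over the leaf's rows. Absolute error is used here only as an example of a non-squared loss.

```java
import java.util.Arrays;

/**
 * Hypothetical helper contrasting the value written into a single tree leaf by plain
 * gradient boosting and by TreeBoost, given the current residuals r_i = y_i - F(x_i)
 * of the training rows that fall into that leaf.
 */
final class LeafValues {
    enum Loss { SQUARED_ERROR, ABSOLUTE_ERROR }

    static double mean(double[] v) {
        double s = 0;
        for (double d : v) s += d;
        return s / v.length;
    }

    static double median(double[] v) {
        double[] s = v.clone();
        Arrays.sort(s);
        int n = s.length;
        return n % 2 == 1 ? s[n / 2] : 0.5 * (s[n / 2 - 1] + s[n / 2]);
    }

    /** Gradient boosting: least-squares fit to pseudo-residuals, so the leaf outputs their mean. */
    static double gradientBoostingLeaf(double[] residuals, Loss loss) {
        double[] pseudo = new double[residuals.length];
        for (int i = 0; i < residuals.length; i++) {
            // Negative gradient of the loss w.r.t. F: r for (1/2) r^2, sign(r) for |r|.
            pseudo[i] = loss == Loss.SQUARED_ERROR ? residuals[i] : Math.signum(residuals[i]);
        }
        return mean(pseudo);
    }

    /** TreeBoost: the leaf outputs the constant gamma minimizing the loss over the leaf's rows. */
    static double treeBoostLeaf(double[] residuals, Loss loss) {
        // argmin sum (r_i - gamma)^2 is the mean; argmin sum |r_i - gamma| is a median.
        return loss == Loss.SQUARED_ERROR ? mean(residuals) : median(residuals);
    }
}
```

With residuals {1, 2, 10} in a leaf, both rules give 13/3 under squared error, because the pseudo-residuals equal the residuals and the loss-minimizing constant is their mean. Under absolute error they diverge: gradient boosting outputs 1.0 (the mean of the signs) while TreeBoost outputs 2.0 (the median).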
Constructor and Description |
---|
GradientBoostedTrees(BoostingStrategy boostingStrategy) |
Modifier and Type | Method and Description |
---|---|
GradientBoostedTreesModel | run(JavaRDD<LabeledPoint> input): Java-friendly API for run(RDD<LabeledPoint>). |
GradientBoostedTreesModel | run(RDD<LabeledPoint> input): Method to train a gradient boosting model. |
static GradientBoostedTreesModel | train(JavaRDD<LabeledPoint> input, BoostingStrategy boostingStrategy): Java-friendly API for train(RDD<LabeledPoint>, BoostingStrategy). |
static GradientBoostedTreesModel | train(RDD<LabeledPoint> input, BoostingStrategy boostingStrategy): Method to train a gradient boosting model. |
Methods inherited from class Object: equals, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
Methods inherited from interface Logging: initializeIfNecessary, initializeLogging, isTraceEnabled, log_, log, logDebug, logDebug, logError, logError, logInfo, logInfo, logName, logTrace, logTrace, logWarning, logWarning
public GradientBoostedTrees(BoostingStrategy boostingStrategy)
public static GradientBoostedTreesModel train(RDD<LabeledPoint> input, BoostingStrategy boostingStrategy)
Method to train a gradient boosting model.
Parameters:
input - Training dataset: RDD of LabeledPoint. For classification, labels should take values {0, 1, ..., numClasses-1}. For regression, labels are real numbers.
boostingStrategy - Configuration options for the boosting algorithm.
public static GradientBoostedTreesModel train(JavaRDD<LabeledPoint> input, BoostingStrategy boostingStrategy)
Java-friendly API for train(RDD<LabeledPoint>, BoostingStrategy).
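The label convention stated for the `input` parameter of `train` can be checked before training. The sketch below is a hypothetical pre-flight helper (`LabelCheck` is not part of Spark) operating on a plain array of labels; reading "real numbers" as "finite doubles" (no NaN or infinity) is an assumption.

```java
/** Hypothetical validation of the documented label convention; not part of Spark. */
final class LabelCheck {

    /** Classification: every label must be one of 0, 1, ..., numClasses - 1. */
    static boolean validClassificationLabels(double[] labels, int numClasses) {
        for (double y : labels) {
            // Rejects fractional values (and NaN, since NaN != NaN) as well as out-of-range classes.
            if (y != Math.rint(y) || y < 0 || y >= numClasses) return false;
        }
        return true;
    }

    /** Regression: any real number is allowed; assumed here to mean a finite double. */
    static boolean validRegressionLabels(double[] labels) {
        for (double y : labels) {
            if (!Double.isFinite(y)) return false;
        }
        return true;
    }
}
```

For binary classification, `numClasses` would be 2, so only the labels 0.0 and 1.0 pass.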
public GradientBoostedTreesModel run(RDD<LabeledPoint> input)
Method to train a gradient boosting model.
Parameters:
input - Training dataset: RDD of LabeledPoint.
public GradientBoostedTreesModel run(JavaRDD<LabeledPoint> input)
Java-friendly API for run(RDD<LabeledPoint>).
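The four entry points presumably relate as follows: the static `train` methods amount to constructing a `GradientBoostedTrees` with the given `BoostingStrategy` and calling `run`, and each `JavaRDD` overload accepts the Java wrapper type (a `JavaRDD` exposes its underlying `RDD` through `rdd()`) and otherwise behaves like the `RDD` version. The sketch below mirrors only that delegation shape with hypothetical stand-in types (`EntryPoints`, `Config`, `Model`); `double[]` plays the role of `RDD<LabeledPoint>` and `List<Double>` the role of `JavaRDD<LabeledPoint>`. It does not train trees.

```java
import java.util.List;

/**
 * Stand-in showing only how the entry points relate; every name is hypothetical.
 * double[] plays the role of RDD<LabeledPoint>, List<Double> the role of JavaRDD<LabeledPoint>.
 */
final class EntryPoints {
    /** Hypothetical configuration (plays the role of BoostingStrategy). */
    static final class Config {
        final int numIterations;
        Config(int numIterations) { this.numIterations = numIterations; }
    }

    /** Hypothetical model: records its configuration and a constant prediction. */
    static final class Model {
        final Config config;
        final double mean;
        Model(Config config, double mean) { this.config = config; this.mean = mean; }
    }

    private final Config config;

    EntryPoints(Config config) { this.config = config; }

    /** Core method (cf. run(RDD)): placeholder "training" that just averages the labels. */
    Model run(double[] labels) {
        double sum = 0;
        for (double y : labels) sum += y;
        return new Model(config, sum / labels.length);
    }

    /** Java-friendly overload (cf. run(JavaRDD)): convert the input, then delegate. */
    Model run(List<Double> labels) {
        double[] arr = new double[labels.size()];
        for (int i = 0; i < arr.length; i++) arr[i] = labels.get(i);
        return run(arr);
    }

    /** Static convenience (cf. train(RDD, ...)): construct with the configuration, then run. */
    static Model train(double[] labels, Config config) {
        return new EntryPoints(config).run(labels);
    }

    /** Static convenience for the Java-friendly input type (cf. train(JavaRDD, ...)). */
    static Model train(List<Double> labels, Config config) {
        return new EntryPoints(config).run(labels);
    }
}
```

The point of this layout is that all real work lives in one method; the overloads only adapt input types or bundle construction, so every entry point yields the same model for the same data and configuration.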