
Commit 213e6ff

Fixed subspace, more Doc (#31)

* fixed subspace plus doc
* format

1 parent d82e7b6 commit 213e6ff

File tree

5 files changed: +77 -11 lines changed

core/src/main/scala/org/apache/spark/ml/ensemble/HasSubBag.scala

Lines changed: 9 additions & 5 deletions

```diff
@@ -19,6 +19,7 @@ import java.util.UUID
 import org.apache.spark.SparkException
 import org.apache.spark.ml.ensemble.HasSubBag.SubSpace
+import org.apache.spark.ml.feature.VectorSlicer
 import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared.HasSeed
@@ -112,12 +113,15 @@ private[ml] trait HasSubBag extends Params with HasSeed {
       .withColumn(tmpColName, replicate_row(element_at(col(bagColName), index + 1)))
       .drop(col(tmpColName))
 
-    val slicerUDF = udf { slicer(subspace) }
+    val tmpSubSpaceColName = "bag$tmp" + UUID.randomUUID().toString
+    val vs = new VectorSlicer()
+      .setInputCol(featuresColName)
+      .setOutputCol(tmpSubSpaceColName)
+      .setIndices(subspace)
 
-    replicated.withColumn(
-      featuresColName,
-      slicerUDF(col(featuresColName)),
-      df.schema(df.schema.fieldIndex(featuresColName)).metadata)
+    vs.transform(replicated)
+      .withColumn(featuresColName, col(tmpSubSpaceColName))
+      .drop(tmpSubSpaceColName)
 
   }
```
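The fix swaps a hand-rolled slicing UDF for Spark ML's built-in `VectorSlicer`, which selects a subset of vector indices as a standard transformer. As a self-contained illustration of what `VectorSlicer` does (the toy data, column names, and index choice are assumptions for illustration, not the library's code):

```scala
import org.apache.spark.ml.feature.VectorSlicer
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

// Toy data: two rows of three-dimensional feature vectors.
val df = Seq(
  (0, Vectors.dense(1.0, 2.0, 3.0)),
  (1, Vectors.dense(4.0, 5.0, 6.0))
).toDF("id", "features")

// Keep only the features whose indices fall in the sampled subspace,
// analogous to what the patched trait does to the features column.
val slicer = new VectorSlicer()
  .setInputCol("features")
  .setOutputCol("subspaced")
  .setIndices(Array(0, 2)) // a hypothetical subspace

slicer.transform(df).show(false)
// id=0: features=[1.0,2.0,3.0] -> subspaced=[1.0,3.0]
// id=1: features=[4.0,5.0,6.0] -> subspaced=[4.0,6.0]
```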

core/src/test/scala/org/apache/spark/ml/classification/GBMClassifierSuite.scala

Lines changed: 1 addition & 1 deletion

```diff
@@ -35,7 +35,7 @@ class GBMClassifierSuite extends FunSuite with DatasetSuiteBase {
       .addGrid(gbmc.validationIndicatorCol, Array("val"))
       .addGrid(gbmc.sampleRatio, Array(0.8))
       .addGrid(gbmc.replacement, Array(true))
-      .addGrid(gbmc.subspaceRatio, Array(1.0))
+      .addGrid(gbmc.subspaceRatio, Array(0.8))
       .addGrid(gbmc.optimizedWeights, Array(false,true))
       .addGrid(gbmc.loss, Array("divergence"))
       .addGrid(dr.maxDepth, Array(10))
```

core/src/test/scala/org/apache/spark/ml/regression/GBMRegressorSuite.scala

Lines changed: 3 additions & 3 deletions

```diff
@@ -30,14 +30,14 @@ class GBMRegressorSuite extends FunSuite with DatasetSuiteBase {
 
   time {
     val gbmrParamGrid = new ParamGridBuilder()
-      .addGrid(gmbr.learningRate, Array(1.0))
+      .addGrid(gmbr.learningRate, Array(0.1))
       .addGrid(gmbr.numBaseLearners, Array(30))
       .addGrid(gmbr.validationIndicatorCol, Array("val"))
       .addGrid(gmbr.tol, Array(1E-3))
       .addGrid(gmbr.numRound, Array(8))
       .addGrid(gmbr.sampleRatio, Array(0.8))
       .addGrid(gmbr.replacement, Array(true))
-      .addGrid(gmbr.subspaceRatio, Array(1.0))
+      .addGrid(gmbr.subspaceRatio, Array(0.8))
       .addGrid(gmbr.optimizedWeights, Array(false,true))
       .addGrid(gmbr.loss, Array("squared"))
       .addGrid(gmbr.alpha, Array(0.5))
@@ -73,7 +73,7 @@ class GBMRegressorSuite extends FunSuite with DatasetSuiteBase {
 
   time {
     val paramGrid = new ParamGridBuilder()
-      .addGrid(gbt.stepSize, Array(1.0))
+      .addGrid(gbt.stepSize, Array(0.1))
       .addGrid(gbt.maxDepth, Array(10))
       .addGrid(gbt.maxIter, Array(30))
       .addGrid(gbt.subsamplingRate, Array(0.8))
```
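For context, grids like these are typically handed to a cross-validation harness. A hedged sketch of how the grid above might be consumed (the evaluator, metric, and fold count are assumptions; `gmbr`, `gbmrParamGrid`, and a `train` DataFrame are taken from the suite's scope):

```scala
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.tuning.CrossValidator

// Assumed in scope: `gmbr` (the GBM regressor under test),
// `gbmrParamGrid` (the ParamGridBuilder above), and a `train` DataFrame.
val cv = new CrossValidator()
  .setEstimator(gmbr)
  .setEstimatorParamMaps(gbmrParamGrid.build()) // materialize the grid
  .setEvaluator(new RegressionEvaluator().setMetricName("rmse"))
  .setNumFolds(3)

val cvModel = cv.fit(train) // fits one model per ParamMap per fold
```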

docs/boosting.md

Lines changed: 29 additions & 1 deletion

````diff
@@ -3,4 +3,32 @@ id: boosting
 title: Boosting
 ---
 
-Documentation Coming Soon.
+The original Boosting, the old-fashioned way (à la papa), from Freund and Schapire [[1](#references)].
+
+For classification, SAMME (Multi-class AdaBoost) [[2](#references)] by Zhu et al. is implemented.
+
+For regression, R2 (Improving Regressors using Boosting Techniques) [[3](#references)] by H. Drucker has been chosen.
+
+For convenience, an N-round early stop is available.
+
+## Parameters
+
+The parameters available for Boosting are related to early stopping and to the loss function used for weight computation.
+
+```scala
+import org.apache.spark.ml.classification.{BoostingClassifier, DecisionTreeClassifier}
+
+new BoostingClassifier()
+  .setBaseLearner(new DecisionTreeClassifier()) // Base learner used by the meta-estimator.
+  .setNumBaseLearners(10) // Number of base learners.
+  .setLoss("exponential") // Loss function used for weight computation.
+  .setValidationIndicatorCol("val") // Column name that contains true or false for the early-stop data set.
+  .setTol(1E-3) // Tolerance for optimized step size and gain in loss on the early-stop set.
+  .setNumRound(8) // Number of rounds to wait for the loss on the early-stop set to decrease.
+```
+
+## References
+
+* [[1](https://www.sciencedirect.com/science/article/pii/S002200009791504X)] Freund, Y., & Schapire, R. E. (1997). A decision-theoretic generalization of on-line learning and an application to boosting. Journal of Computer and System Sciences, 55(1), 119-139.
+* [[2](https://web.stanford.edu/~hastie/Papers/samme.pdf)] Hastie, T., Rosset, S., Zhu, J., & Zou, H. (2009). Multi-class AdaBoost. Statistics and Its Interface, 2(3), 349-360.
+* [[3](https://pdfs.semanticscholar.org/8d49/e2dedb817f2c3330e74b63c5fc86d2399ce3.pdf)] Drucker, H. (1997, July). Improving regressors using boosting techniques. In ICML (Vol. 97, pp. 107-115).
````
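To make the early-stop parameters above concrete, here is a minimal sketch of preparing the validation indicator column and fitting; the `data` DataFrame, its column names, and the 80/20 split are assumptions for illustration:

```scala
import org.apache.spark.ml.classification.{BoostingClassifier, DecisionTreeClassifier}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.rand

// Assumed: `data` has the usual "features" and "label" columns.
def fitWithEarlyStop(data: DataFrame) = {
  // Flag roughly 20% of the rows as the early-stop validation set.
  val withVal = data.withColumn("val", rand(42L) > 0.8)

  new BoostingClassifier()
    .setBaseLearner(new DecisionTreeClassifier())
    .setNumBaseLearners(10)
    .setValidationIndicatorCol("val")
    .setNumRound(8) // stop if the validation loss fails to decrease for 8 rounds
    .fit(withVal)
}
```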

docs/gbm.md

Lines changed: 35 additions & 1 deletion

````diff
@@ -3,4 +3,38 @@ id: gbm
 title: GBM
 ---
 
-Documentation Coming Soon.
+God Jerome H. Friedman enlightened mankind with GBM (Gradient Boosting Machines) at the beginning of the third millennium.
+
+The first of his ten commandments was named Greedy Function Approximation: A Gradient Boosting Machine [[1](#references)], introducing a meta-algorithm aimed at performing gradient descent in function space. In the end you were really doing it in error space, but the heuristic was heaven-sent.
+
+The second commandment was Stochastic Gradient Boosting [[2](#references)], introducing randomness at each iteration by using SubBags.
+
+PS: It works for multi-class classification.
+PPS: Early stop is implemented with an N-round variant.
+
+## Parameters
+
+The parameters available for GBM are related to the base framework, the stochastic version, and early stopping.
+
+```scala
+import org.apache.spark.ml.classification.{GBMClassifier, DecisionTreeClassifier}
+
+new GBMClassifier()
+  .setBaseLearner(new DecisionTreeClassifier()) // Base learner used by the meta-estimator.
+  .setNumBaseLearners(10) // Number of base learners.
+  .setLearningRate(0.1) // Shrinkage parameter.
+  .setSampleRatio(0.8) // Sampling ratio of examples.
+  .setReplacement(true) // Whether examples are drawn with replacement.
+  .setSubspaceRatio(0.8) // Sampling ratio of features.
+  .setOptimizedWeights(true) // Line-search the best step size, or use 1 instead.
+  .setLoss("squared") // Loss function used for residuals and the optimized step size.
+  .setAlpha(0.5) // Extra parameter for certain loss functions such as quantile or huber.
+  .setValidationIndicatorCol("val") // Column name that contains true or false for the early-stop data set.
+  .setTol(1E-3) // Tolerance for optimized step size and gain in loss on the early-stop set.
+  .setNumRound(8) // Number of rounds to wait for the loss on the early-stop set to decrease.
+```
+
+## References
+
+* [[1](https://statweb.stanford.edu/~jhf/ftp/trebst.pdf)] Friedman, J. H. (2001). Greedy function approximation: a gradient boosting machine. Annals of Statistics, 1189-1232.
+* [[2](https://astro.temple.edu/~msobel/courses_files/StochasticBoosting(gradient).pdf)] Friedman, J. H. (2002). Stochastic gradient boosting. Computational Statistics & Data Analysis, 38(4), 367-378.
````
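Since the doc describes gradient descent in function space via residuals, a toy sketch of that recursion may help. This is not the library's implementation, just the squared-loss update on plain numbers, with the base learner deliberately oversimplified to a constant predictor:

```scala
// Gradient boosting with squared loss: each base learner fits the
// residuals (the negative gradient), and the ensemble prediction
// moves by a shrunken step: F_m = F_{m-1} + learningRate * h_m.
val y = Array(3.0, 1.0, 4.0, 1.5)
val learningRate = 0.1
val numBaseLearners = 10

// A "base learner" reduced to its simplest form: predict the mean residual.
def fitConstant(residuals: Array[Double]): Double =
  residuals.sum / residuals.length

var preds = Array.fill(y.length)(0.0)
for (_ <- 1 to numBaseLearners) {
  val residuals = y.zip(preds).map { case (yi, fi) => yi - fi }
  val h = fitConstant(residuals)
  preds = preds.map(_ + learningRate * h) // shrunken update
}
// After 10 rounds, preds has moved part of the way toward mean(y).
```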
