Skip to content

Commit edd3bcb

Browse files
committed
Rename binomial loss to bernoulli loss
1 parent 2942853 commit edd3bcb

File tree

3 files changed

+8
-8
lines changed

3 files changed

+8
-8
lines changed

core/src/main/scala/org/apache/spark/ml/boosting/GBMLoss.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ case object ExponentialLoss
290290

291291
}
292292

293-
case object BinomialLoss extends GBMClassificationLoss with GBMScalarLoss with HasScalarHessian {
293+
case object BernoulliLoss extends GBMClassificationLoss with GBMScalarLoss with HasScalarHessian {
294294

295295
override def dim: Int = 1
296296

core/src/main/scala/org/apache/spark/ml/classification/GBMClassifier.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ private[ml] trait GBMClassifierParams extends GBMParams with HasParallelism {
7676

7777
/**
7878
* Loss function which GBM tries to minimize. (case-insensitive) Supported: "logloss",
79-
* "exponential", "binomial". (default = logloss)
79+
* "exponential", "bernoulli". (default = logloss)
8080
*
8181
* @group param
8282
*/
@@ -100,7 +100,7 @@ private[ml] trait GBMClassifierParams extends GBMParams with HasParallelism {
100100
private[ml] object GBMClassifierParams {
101101

102102
final val supportedLossTypes: Array[String] =
103-
Array("logloss", "exponential", "binomial").map(_.toLowerCase(Locale.ROOT))
103+
Array("logloss", "exponential", "bernoulli").map(_.toLowerCase(Locale.ROOT))
104104

105105
final val supportedInitStrategy: Array[String] =
106106
Array("uniform", "prior").map(_.toLowerCase(Locale.ROOT))
@@ -109,7 +109,7 @@ private[ml] object GBMClassifierParams {
109109
loss match {
110110
case "logloss" => LogLoss(numClasses)
111111
case "exponential" => ExponentialLoss
112-
case "binomial" => BinomialLoss
112+
case "bernoulli" => BernoulliLoss
113113
case _ => throw new RuntimeException(s"GBMClassifier was given bad loss type: $loss")
114114
}
115115

core/src/test/scala/org/apache/spark/ml/classification/GBMClassifierSuite.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ class GBMClassifierSuite extends AnyFunSuite with BeforeAndAfterAll {
8686

8787
}
8888

89-
test("gbm exponential and binomial binary classification is better than baselines") {
89+
test("gbm exponential and bernoulli binary classification is better than baselines") {
9090

9191
val data =
9292
spark.read
@@ -110,7 +110,7 @@ class GBMClassifierSuite extends AnyFunSuite with BeforeAndAfterAll {
110110
.setBaseLearner(dtr)
111111
.setNumBaseLearners(numBaseLearners)
112112
.setLearningRate(1.0)
113-
.setLoss("binomial")
113+
.setLoss("bernoulli")
114114
.setUpdates("newton")
115115
val dtc = new DecisionTreeClassifier()
116116
.setMaxDepth(maxDepth)
@@ -204,15 +204,15 @@ class GBMClassifierSuite extends AnyFunSuite with BeforeAndAfterAll {
204204
val gbmrWithVal = new GBMClassifier()
205205
.setBaseLearner(dtr)
206206
.setNumBaseLearners(10)
207-
.setLoss("binomial")
207+
.setLoss("bernoulli")
208208
.setUpdates("gradient")
209209
.setValidationIndicatorCol("validation")
210210
.setNumRounds(1)
211211

212212
val gbmrNoVal = new GBMClassifier()
213213
.setBaseLearner(dtr)
214214
.setNumBaseLearners(10)
215-
.setLoss("binomial")
215+
.setLoss("bernoulli")
216216
.setUpdates("gradient")
217217

218218
val mce = new MulticlassClassificationEvaluator().setMetricName("logLoss")

0 commit comments

Comments
 (0)