Skip to content

Commit ebe6733

Browse files
Fix seed parameter for XGBoost (#41)
1 parent 4dd7f0e commit ebe6733

File tree

4 files changed

+7
-7
lines changed

4 files changed

+7
-7
lines changed

CHANGELOG.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ CHANGELOG
66
========
77

88
* pyspark: SageMakerModel: Fix bugs in creating model from training job, s3 file and endpoint
9-
9+
* spark/pyspark: XGBoostSageMakerEstimator: Fix seed hyperparameter to use correct type (Int)
1010

1111
1.0.4
1212
=====

sagemaker-pyspark-sdk/src/sagemaker_pyspark/algorithms/XGBoostSageMakerEstimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,7 @@ class XGBoostSageMakerEstimator(SageMakerEstimatorBase):
331331
seed = Param(
332332
Params._dummy(), "seed",
333333
"Random number seed",
334-
typeConverter=TypeConverters.toFloat)
334+
typeConverter=TypeConverters.toInt)
335335

336336
num_round = Param(
337337
Params._dummy(), "num_round",

sagemaker-spark-sdk/src/main/scala/com/amazonaws/services/sagemaker/sparksdk/algorithms/XGBoostSageMakerEstimator.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -379,8 +379,8 @@ private[algorithms] trait XGBoostParams extends Params {
379379
/** Random number seed.
380380
* Default = 0
381381
*/
382-
val seed: DoubleParam = new DoubleParam(this, "seed", "Random number seed.")
383-
def getSeed: Double = $(seed)
382+
val seed: IntParam = new IntParam(this, "seed", "Random number seed.")
383+
def getSeed: Int = $(seed)
384384

385385
/**
386386
* Number of rounds for gradient boosting. Must be >= 1. Required.
@@ -614,7 +614,7 @@ class XGBoostSageMakerEstimator(
614614

615615
def setEvalMetric(value: String) : this.type = set(evalMetric, value)
616616

617-
def setSeed(value: Double) : this.type = set(seed, value)
617+
def setSeed(value: Int) : this.type = set(seed, value)
618618

619619
def setNumRound(value: Int) : this.type = set(numRound, value)
620620

sagemaker-spark-sdk/src/test/scala/com/amazonaws/services/sagemaker/sparksdk/algorithms/XGBoostSageMakerEstimatorTests.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ class XGBoostSageMakerEstimatorTests extends FlatSpec with Matchers with Mockito
301301

302302

303303
it should "setSeed" in {
304-
val seed = 10.0
304+
val seed = 10
305305
estimator.setSeed(seed)
306306
assert(seed == estimator.getSeed)
307307
}
@@ -381,7 +381,7 @@ class XGBoostSageMakerEstimatorTests extends FlatSpec with Matchers with Mockito
381381
"objective" -> "reg:logistic",
382382
"base_score" -> "0.5",
383383
"eval_metric" -> "mae",
384-
"seed" -> "0.0"
384+
"seed" -> "0"
385385
)
386386

387387
assert(hyperParamMap.asJava == estimator.makeHyperParameters())

0 commit comments

Comments
 (0)