Skip to content

Commit f5615f8

Browse files
authored
feature: add support for cn-north-1 and cn-northwest-1 (#113)
1 parent 2ee3217 commit f5615f8

File tree

6 files changed

+92
-3
lines changed

6 files changed

+92
-3
lines changed

sagemaker-spark-sdk/src/main/scala/com/amazonaws/services/sagemaker/sparksdk/algorithms/ImageURIProvider.scala

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,24 @@ import com.amazonaws.regions.Regions
1919

2020
private[algorithms] object SageMakerImageURIProvider {
2121

22+
def isChinaRegion(region: String): Boolean = {
23+
val chinaRegions = Set(
24+
Regions.CN_NORTH_1.getName,
25+
Regions.CN_NORTHWEST_1.getName
26+
)
27+
chinaRegions.contains(region)
28+
}
29+
2230
def getImage(region: String, regionAccountMap: Map[String, String],
2331
algorithmName: String, algorithmTag: String): String = {
2432
val account = regionAccountMap.get(region)
2533
account match {
2634
case None => throw new RuntimeException(s"The region $region is not supported." +
2735
s"Supported Regions: ${regionAccountMap.keys.mkString(", ")}")
28-
case _ => s"${account.get}.dkr.ecr.${region}.amazonaws.com/${algorithmName}:${algorithmTag}"
36+
case _ if isChinaRegion(region) =>
37+
s"${account.get}.dkr.ecr.${region}.amazonaws.com.cn/${algorithmName}:${algorithmTag}"
38+
case _ =>
39+
s"${account.get}.dkr.ecr.${region}.amazonaws.com/${algorithmName}:${algorithmTag}"
2940
}
3041
}
3142
}
@@ -52,7 +63,9 @@ private[algorithms] object SagerMakerRegionAccountMaps {
5263
Regions.EU_NORTH_1.getName -> "669576153137",
5364
Regions.EU_WEST_3.getName -> "749696950732",
5465
Regions.EU_WEST_3.getName -> "749696950732",
55-
Regions.ME_SOUTH_1.getName -> "249704162688"
66+
Regions.ME_SOUTH_1.getName -> "249704162688",
67+
Regions.CN_NORTH_1.getName -> "390948362332",
68+
Regions.CN_NORTHWEST_1.getName -> "387376663083"
5669
)
5770

5871
// For LDA
@@ -94,7 +107,9 @@ private[algorithms] object SagerMakerRegionAccountMaps {
94107
Regions.EU_NORTH_1.getName -> "669576153137",
95108
Regions.EU_WEST_3.getName -> "749696950732",
96109
Regions.EU_WEST_3.getName -> "749696950732",
97-
Regions.ME_SOUTH_1.getName -> "249704162688"
110+
Regions.ME_SOUTH_1.getName -> "249704162688",
111+
Regions.CN_NORTH_1.getName -> "390948362332",
112+
Regions.CN_NORTHWEST_1.getName -> "387376663083"
98113
)
99114
}
100115

sagemaker-spark-sdk/src/test/scala/com/amazonaws/services/sagemaker/sparksdk/algorithms/FactorizationMachinesSageMakerEstimatorTests.scala

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,16 @@ class FactorizationMachinesSageMakerEstimatorTests extends FlatSpec with Mockito
151151
createFactorizationMachinesBinaryClassifier(region = Regions.ME_SOUTH_1.getName)
152152
assert(estimatorMESouth1.trainingImage ==
153153
"249704162688.dkr.ecr.me-south-1.amazonaws.com/factorization-machines:1")
154+
155+
val estimatorCNNorth1 =
156+
createFactorizationMachinesBinaryClassifier(region = Regions.CN_NORTH_1.getName)
157+
assert(estimatorCNNorth1.trainingImage ==
158+
"390948362332.dkr.ecr.cn-north-1.amazonaws.com.cn/factorization-machines:1")
159+
160+
val estimatorCNNorthWest1 =
161+
createFactorizationMachinesBinaryClassifier(region = Regions.CN_NORTHWEST_1.getName)
162+
assert(estimatorCNNorthWest1.trainingImage ==
163+
"387376663083.dkr.ecr.cn-northwest-1.amazonaws.com.cn/factorization-machines:1")
154164
}
155165

156166
it should "use the correct defaults for regressor" in {
@@ -253,6 +263,16 @@ class FactorizationMachinesSageMakerEstimatorTests extends FlatSpec with Mockito
253263
createFactorizationMachinesRegressor(region = Regions.ME_SOUTH_1.getName)
254264
assert(estimatorMESouth1.trainingImage ==
255265
"249704162688.dkr.ecr.me-south-1.amazonaws.com/factorization-machines:1")
266+
267+
val estimatorCNNorth1 =
268+
createFactorizationMachinesRegressor(region = Regions.CN_NORTH_1.getName)
269+
assert(estimatorCNNorth1.trainingImage ==
270+
"390948362332.dkr.ecr.cn-north-1.amazonaws.com.cn/factorization-machines:1")
271+
272+
val estimatorCNNorthWest1 =
273+
createFactorizationMachinesRegressor(region = Regions.CN_NORTHWEST_1.getName)
274+
assert(estimatorCNNorthWest1.trainingImage ==
275+
"387376663083.dkr.ecr.cn-northwest-1.amazonaws.com.cn/factorization-machines:1")
256276
}
257277

258278
it should "setFeatureDim" in {

sagemaker-spark-sdk/src/test/scala/com/amazonaws/services/sagemaker/sparksdk/algorithms/KMeansSageMakerEstimatorTests.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,14 @@ class KMeansSageMakerEstimatorTests extends FlatSpec with Matchers with MockitoS
128128
val estimatorMESouth1 = createKMeansEstimator(region = Regions.ME_SOUTH_1.getName)
129129
assert(estimatorMESouth1.trainingImage ==
130130
"249704162688.dkr.ecr.me-south-1.amazonaws.com/kmeans:1")
131+
132+
val estimatorCNNorth1 = createKMeansEstimator(region = Regions.CN_NORTH_1.getName)
133+
assert(estimatorCNNorth1.trainingImage ==
134+
"390948362332.dkr.ecr.cn-north-1.amazonaws.com.cn/kmeans:1")
135+
136+
val estimatorCNNorthWest1 = createKMeansEstimator(region = Regions.CN_NORTHWEST_1.getName)
137+
assert(estimatorCNNorthWest1.trainingImage ==
138+
"387376663083.dkr.ecr.cn-northwest-1.amazonaws.com.cn/kmeans:1")
131139
}
132140

133141
it should "setK" in {

sagemaker-spark-sdk/src/test/scala/com/amazonaws/services/sagemaker/sparksdk/algorithms/LinearLearnerSageMakerEstimatorTests.scala

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,16 @@ class LinearLearnerSageMakerEstimatorTests extends FlatSpec with MockitoSugar {
156156
createLinearLearnerBinaryClassifier(region = Regions.ME_SOUTH_1.getName)
157157
assert(estimatorMESouth1.trainingImage ==
158158
"249704162688.dkr.ecr.me-south-1.amazonaws.com/linear-learner:1")
159+
160+
val estimatorCNNorth1 =
161+
createLinearLearnerBinaryClassifier(region = Regions.CN_NORTH_1.getName)
162+
assert(estimatorCNNorth1.trainingImage ==
163+
"390948362332.dkr.ecr.cn-north-1.amazonaws.com.cn/linear-learner:1")
164+
165+
val estimatorCNNorthWest1 =
166+
createLinearLearnerBinaryClassifier(region = Regions.CN_NORTHWEST_1.getName)
167+
assert(estimatorCNNorthWest1.trainingImage ==
168+
"387376663083.dkr.ecr.cn-northwest-1.amazonaws.com.cn/linear-learner:1")
159169
}
160170

161171
it should "use the correct defaults for multiclass classifier" in {
@@ -261,6 +271,16 @@ class LinearLearnerSageMakerEstimatorTests extends FlatSpec with MockitoSugar {
261271
createLinearLearnerMultiClassClassifier(region = Regions.ME_SOUTH_1.getName)
262272
assert(estimatorMESouth1.trainingImage ==
263273
"249704162688.dkr.ecr.me-south-1.amazonaws.com/linear-learner:1")
274+
275+
val estimatorCNNorth1 =
276+
createLinearLearnerMultiClassClassifier(region = Regions.CN_NORTH_1.getName)
277+
assert(estimatorCNNorth1.trainingImage ==
278+
"390948362332.dkr.ecr.cn-north-1.amazonaws.com.cn/linear-learner:1")
279+
280+
val estimatorCNNorthWest1 =
281+
createLinearLearnerMultiClassClassifier(region = Regions.CN_NORTHWEST_1.getName)
282+
assert(estimatorCNNorthWest1.trainingImage ==
283+
"387376663083.dkr.ecr.cn-northwest-1.amazonaws.com.cn/linear-learner:1")
264284
}
265285

266286
it should "use the correct defaults for regressor" in {
@@ -362,6 +382,16 @@ class LinearLearnerSageMakerEstimatorTests extends FlatSpec with MockitoSugar {
362382
createLinearLearnerRegressor(region = Regions.ME_SOUTH_1.getName)
363383
assert(estimatorMESouth1.trainingImage ==
364384
"249704162688.dkr.ecr.me-south-1.amazonaws.com/linear-learner:1")
385+
386+
val estimatorCNNorth1 =
387+
createLinearLearnerRegressor(region = Regions.CN_NORTH_1.getName)
388+
assert(estimatorCNNorth1.trainingImage ==
389+
"390948362332.dkr.ecr.cn-north-1.amazonaws.com.cn/linear-learner:1")
390+
391+
val estimatorCNNorthWest1 =
392+
createLinearLearnerRegressor(region = Regions.CN_NORTHWEST_1.getName)
393+
assert(estimatorCNNorthWest1.trainingImage ==
394+
"387376663083.dkr.ecr.cn-northwest-1.amazonaws.com.cn/linear-learner:1")
365395
}
366396

367397
it should "setFeatureDim" in {

sagemaker-spark-sdk/src/test/scala/com/amazonaws/services/sagemaker/sparksdk/algorithms/PCASageMakerEstimatorTests.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,14 @@ class PCASageMakerEstimatorTests extends FlatSpec with MockitoSugar {
108108
val estimatorMESouth1 = createPCAEstimator(region = Regions.ME_SOUTH_1.getName)
109109
assert(estimatorMESouth1.trainingImage ==
110110
"249704162688.dkr.ecr.me-south-1.amazonaws.com/pca:1")
111+
112+
val estimatorCNNorth1 = createPCAEstimator(region = Regions.CN_NORTH_1.getName)
113+
assert(estimatorCNNorth1.trainingImage ==
114+
"390948362332.dkr.ecr.cn-north-1.amazonaws.com.cn/pca:1")
115+
116+
val estimatorCNNorthWest1 = createPCAEstimator(region = Regions.CN_NORTHWEST_1.getName)
117+
assert(estimatorCNNorthWest1.trainingImage ==
118+
"387376663083.dkr.ecr.cn-northwest-1.amazonaws.com.cn/pca:1")
111119
}
112120

113121
it should "use the correct defaults" in {

sagemaker-spark-sdk/src/test/scala/com/amazonaws/services/sagemaker/sparksdk/algorithms/XGBoostSageMakerEstimatorTests.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,14 @@ class XGBoostSageMakerEstimatorTests extends FlatSpec with Matchers with Mockito
127127
val estimatorMESouth1 = createXGBoostEstimator(region = Regions.ME_SOUTH_1.getName)
128128
assert(estimatorMESouth1.trainingImage ==
129129
"249704162688.dkr.ecr.me-south-1.amazonaws.com/xgboost:1")
130+
131+
val estimatorCNNorth1 = createXGBoostEstimator(region = Regions.CN_NORTH_1.getName)
132+
assert(estimatorCNNorth1.trainingImage ==
133+
"390948362332.dkr.ecr.cn-north-1.amazonaws.com.cn/xgboost:1")
134+
135+
val estimatorCNNorthWest1 = createXGBoostEstimator(region = Regions.CN_NORTHWEST_1.getName)
136+
assert(estimatorCNNorthWest1.trainingImage ==
137+
"387376663083.dkr.ecr.cn-northwest-1.amazonaws.com.cn/xgboost:1")
130138
}
131139

132140
it should "setBooster" in {

0 commit comments

Comments
 (0)