@@ -111,7 +111,7 @@ def configuration() -> list:
     [{"instance_count": 1, "py_version": "py37"}, {"instance_count": 2, "py_version": "py39"}],
 )
 def test_sagemaker_pyspark_multinode(
-    role, image_uri, configuration, sagemaker_session, region, sagemaker_client, config
+    role, image_uri, configuration, sagemaker_session, region, sagemaker_client, config, instance_type
 ):
     instance_count = config["instance_count"]
     python_version = config["py_version"]
@@ -122,7 +122,7 @@ def test_sagemaker_pyspark_multinode(
         image_uri=image_uri,
         role=role,
         instance_count=instance_count,
-        instance_type="ml.c5.xlarge",
+        instance_type=instance_type,
         max_runtime_in_seconds=1200,
         sagemaker_session=sagemaker_session,
     )
@@ -193,14 +193,14 @@ def test_sagemaker_pyspark_multinode(
 # TODO: similar integ test case for SSE-KMS. This would require test infrastructure bootstrapping a KMS key.
 # Currently, Spark jobs can read data encrypted with SSE-KMS (assuming the execution role has permission),
 # however our Hadoop version (2.8.5) does not support writing data with SSE-KMS (enabled in version 3.0.0).
-def test_sagemaker_pyspark_sse_s3(role, image_uri, sagemaker_session, region, sagemaker_client):
+def test_sagemaker_pyspark_sse_s3(role, image_uri, sagemaker_session, region, sagemaker_client, instance_type):
     """Test that Spark container can read and write S3 data encrypted with SSE-S3 (default AES256 encryption)"""
     spark = PySparkProcessor(
         base_job_name="sm-spark-py",
         image_uri=image_uri,
         role=role,
         instance_count=2,
-        instance_type="ml.c5.xlarge",
+        instance_type=instance_type,
         max_runtime_in_seconds=1200,
         sagemaker_session=sagemaker_session,
     )
@@ -237,14 +237,14 @@ def test_sagemaker_pyspark_sse_s3(role, image_uri, sagemaker_session, region, sa
 
 
 def test_sagemaker_pyspark_sse_kms_s3(
-    role, image_uri, sagemaker_session, region, sagemaker_client, account_id, partition
+    role, image_uri, sagemaker_session, region, sagemaker_client, account_id, partition, instance_type
 ):
     spark = PySparkProcessor(
         base_job_name="sm-spark-py",
         image_uri=image_uri,
         role=role,
         instance_count=2,
-        instance_type="ml.c5.xlarge",
+        instance_type=instance_type,
         max_runtime_in_seconds=1200,
         sagemaker_session=sagemaker_session,
     )
@@ -301,14 +301,16 @@ def test_sagemaker_pyspark_sse_kms_s3(
     assert object_metadata["SSEKMSKeyId"] == f"arn:{partition}:kms:{region}:{account_id}:key/{kms_key_id}"
 
 
-def test_sagemaker_scala_jar_multinode(role, image_uri, configuration, sagemaker_session, sagemaker_client):
+def test_sagemaker_scala_jar_multinode(
+    role, image_uri, configuration, sagemaker_session, sagemaker_client, instance_type
+):
     """Test SparkJarProcessor using Scala application jar with external runtime dependency jars staged by SDK"""
     spark = SparkJarProcessor(
         base_job_name="sm-spark-scala",
         image_uri=image_uri,
         role=role,
         instance_count=2,
-        instance_type="ml.c5.xlarge",
+        instance_type=instance_type,
         max_runtime_in_seconds=1200,
         sagemaker_session=sagemaker_session,
     )
@@ -346,7 +348,14 @@ def test_sagemaker_scala_jar_multinode(role, image_uri, configuration, sagemaker
 
 
 def test_sagemaker_feature_store_ingestion_multinode(
-    sagemaker_session, sagemaker_client, spark_version, framework_version, image_uri, role, is_feature_store_available
+    sagemaker_session,
+    sagemaker_client,
+    spark_version,
+    framework_version,
+    image_uri,
+    role,
+    is_feature_store_available,
+    instance_type,
 ):
     """Test FeatureStore use cases by ingesting data to feature group."""
 
@@ -359,7 +368,7 @@ def test_sagemaker_feature_store_ingestion_multinode(
         image_uri=image_uri,
         role=role,
         instance_count=2,
-        instance_type="ml.c5.xlarge",
+        instance_type=instance_type,
         max_runtime_in_seconds=1200,
         sagemaker_session=sagemaker_session,
     )
@@ -383,15 +392,17 @@ def test_sagemaker_feature_store_ingestion_multinode(
         raise RuntimeError("Feature store Spark job stopped unexpectedly")
 
 
-def test_sagemaker_java_jar_multinode(tag, role, image_uri, configuration, sagemaker_session, sagemaker_client):
+def test_sagemaker_java_jar_multinode(
+    tag, role, image_uri, configuration, sagemaker_session, sagemaker_client, instance_type
+):
     """Test SparkJarProcessor using Java application jar"""
     spark = SparkJarProcessor(
         base_job_name="sm-spark-java",
         framework_version=tag,
         image_uri=image_uri,
         role=role,
         instance_count=2,
-        instance_type="ml.c5.xlarge",
+        instance_type=instance_type,
         max_runtime_in_seconds=1200,
         sagemaker_session=sagemaker_session,
     )