Commit d0c10da

Add instance_type parameter to test functions (#162)
* Update SageMaker tests to use the new instance-type parameter, with a default value of "ml.c5.xlarge"
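In practical terms, the integration suites can now target a different instance type straight from the command line. Assuming the repository's usual pytest entry point (the override value below is purely illustrative), an invocation might look like:

    pytest test/integration --instance-type ml.m5.xlarge

Omitting the flag preserves the previous behavior, since the option defaults to "ml.c5.xlarge".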
Parent: 6ba2ff0

File tree

5 files changed: +40 −20 lines


test/integration/conftest.py

Lines changed: 7 additions & 0 deletions
@@ -29,6 +29,7 @@ def pytest_addoption(parser) -> str:
     parser.addoption("--spark-version")
     parser.addoption("--framework-version")
     parser.addoption("--domain", default="amazonaws.com")
+    parser.addoption("--instance-type", default="ml.c5.xlarge")
 
 
 @pytest.fixture(scope="session")
@@ -102,6 +103,12 @@ def role(request) -> str:
     return request.config.getoption("--role")
 
 
+@pytest.fixture(scope="session")
+def instance_type(request) -> str:
+    """Return the SageMaker Processing instance type to use in tests."""
+    return request.config.getoption("--instance-type")
+
+
 @pytest.fixture(scope="session")
 def boto_session(region) -> boto3.session.Session:
     """Return a boto session for use in constructing clients in integration tests."""

test/integration/history/test_spark_history_server.py

Lines changed: 6 additions & 4 deletions
@@ -25,7 +25,7 @@
 SPARK_APPLICATION_URL_SUFFIX = "/history/application_1594922484246_0001/1/jobs/"
 
 
-def test_history_server(tag, framework_version, role, image_uri, sagemaker_session, region):
+def test_history_server(tag, framework_version, role, image_uri, sagemaker_session, region, instance_type):
     print(
         f"PySparkProcessor args: tag={tag}, framework_version={framework_version}, "
         f"role={role}, image_uri={image_uri}, region={region}"
@@ -39,7 +39,7 @@ def test_history_server(tag, framework_version, role, image_uri, sagemaker_session, region):
         image_uri=image_uri,
         role=role,
         instance_count=1,
-        instance_type="ml.c5.xlarge",
+        instance_type=instance_type,
         max_runtime_in_seconds=1200,
         sagemaker_session=sagemaker_session,
     )
@@ -72,14 +72,16 @@ def test_history_server(tag, framework_version, role, image_uri, sagemaker_session, region):
     spark.terminate_history_server()
 
 
-def test_history_server_with_expected_failure(tag, framework_version, role, image_uri, sagemaker_session, caplog):
+def test_history_server_with_expected_failure(
+    tag, framework_version, role, image_uri, sagemaker_session, caplog, instance_type
+):
     spark = PySparkProcessor(
         base_job_name="sm-spark",
         framework_version=framework_version,
         image_uri=image_uri,
         role=role,
         instance_count=1,
-        instance_type="ml.c5.xlarge",
+        instance_type=instance_type,
         max_runtime_in_seconds=1200,
         sagemaker_session=sagemaker_session,
     )

test/integration/prod/test_default_tag.py

Lines changed: 2 additions & 2 deletions
@@ -16,14 +16,14 @@
 from sagemaker.spark.processing import PySparkProcessor
 
 
-def test_sagemaker_spark_processor_default_tag(spark_version, role, sagemaker_session, sagemaker_client):
+def test_sagemaker_spark_processor_default_tag(spark_version, role, sagemaker_session, sagemaker_client, instance_type):
     """Test that spark processor works with default tag"""
     spark = PySparkProcessor(
         base_job_name="sm-spark-py",
         framework_version=spark_version,
         role=role,
         instance_count=1,
-        instance_type="ml.c5.xlarge",
+        instance_type=instance_type,
         max_runtime_in_seconds=1200,
         sagemaker_session=sagemaker_session,
     )

test/integration/sagemaker/test_sagemaker_spark_errors.py

Lines changed: 2 additions & 2 deletions
@@ -13,15 +13,15 @@
 from sagemaker.spark.processing import PySparkProcessor
 
 
-def test_spark_app_error(tag, role, image_uri, sagemaker_session):
+def test_spark_app_error(tag, role, image_uri, sagemaker_session, instance_type):
     """Submits a PySpark app which is scripted to exit with error code 1"""
     spark = PySparkProcessor(
         base_job_name="sm-spark-app-error",
         framework_version=tag,
         image_uri=image_uri,
         role=role,
         instance_count=1,
-        instance_type="ml.c5.xlarge",
+        instance_type=instance_type,
         max_runtime_in_seconds=1200,
         sagemaker_session=sagemaker_session,
     )

test/integration/sagemaker/test_spark.py

Lines changed: 23 additions & 12 deletions
@@ -111,7 +111,7 @@ def configuration() -> list:
     [{"instance_count": 1, "py_version": "py37"}, {"instance_count": 2, "py_version": "py39"}],
 )
 def test_sagemaker_pyspark_multinode(
-    role, image_uri, configuration, sagemaker_session, region, sagemaker_client, config
+    role, image_uri, configuration, sagemaker_session, region, sagemaker_client, config, instance_type
 ):
     instance_count = config["instance_count"]
     python_version = config["py_version"]
@@ -122,7 +122,7 @@ def test_sagemaker_pyspark_multinode(
         image_uri=image_uri,
         role=role,
         instance_count=instance_count,
-        instance_type="ml.c5.xlarge",
+        instance_type=instance_type,
         max_runtime_in_seconds=1200,
         sagemaker_session=sagemaker_session,
     )
@@ -193,14 +193,14 @@ def test_sagemaker_pyspark_multinode(
 # TODO: similar integ test case for SSE-KMS. This would require test infrastructure bootstrapping a KMS key.
 # Currently, Spark jobs can read data encrypted with SSE-KMS (assuming the execution role has permission),
 # however our Hadoop version (2.8.5) does not support writing data with SSE-KMS (enabled in version 3.0.0).
-def test_sagemaker_pyspark_sse_s3(role, image_uri, sagemaker_session, region, sagemaker_client):
+def test_sagemaker_pyspark_sse_s3(role, image_uri, sagemaker_session, region, sagemaker_client, instance_type):
     """Test that Spark container can read and write S3 data encrypted with SSE-S3 (default AES256 encryption)"""
     spark = PySparkProcessor(
         base_job_name="sm-spark-py",
         image_uri=image_uri,
         role=role,
         instance_count=2,
-        instance_type="ml.c5.xlarge",
+        instance_type=instance_type,
         max_runtime_in_seconds=1200,
         sagemaker_session=sagemaker_session,
     )
@@ -237,14 +237,14 @@ def test_sagemaker_pyspark_sse_s3(role, image_uri, sagemaker_session, region, sagemaker_client):
 
 
 def test_sagemaker_pyspark_sse_kms_s3(
-    role, image_uri, sagemaker_session, region, sagemaker_client, account_id, partition
+    role, image_uri, sagemaker_session, region, sagemaker_client, account_id, partition, instance_type
 ):
     spark = PySparkProcessor(
         base_job_name="sm-spark-py",
         image_uri=image_uri,
         role=role,
         instance_count=2,
-        instance_type="ml.c5.xlarge",
+        instance_type=instance_type,
         max_runtime_in_seconds=1200,
         sagemaker_session=sagemaker_session,
     )
@@ -301,14 +301,16 @@ def test_sagemaker_pyspark_sse_kms_s3(
     assert object_metadata["SSEKMSKeyId"] == f"arn:{partition}:kms:{region}:{account_id}:key/{kms_key_id}"
 
 
-def test_sagemaker_scala_jar_multinode(role, image_uri, configuration, sagemaker_session, sagemaker_client):
+def test_sagemaker_scala_jar_multinode(
+    role, image_uri, configuration, sagemaker_session, sagemaker_client, instance_type
+):
     """Test SparkJarProcessor using Scala application jar with external runtime dependency jars staged by SDK"""
     spark = SparkJarProcessor(
         base_job_name="sm-spark-scala",
         image_uri=image_uri,
         role=role,
         instance_count=2,
-        instance_type="ml.c5.xlarge",
+        instance_type=instance_type,
         max_runtime_in_seconds=1200,
         sagemaker_session=sagemaker_session,
     )
@@ -346,7 +348,14 @@ def test_sagemaker_scala_jar_multinode(role, image_uri, configuration, sagemaker_session, sagemaker_client):
 
 
 def test_sagemaker_feature_store_ingestion_multinode(
-    sagemaker_session, sagemaker_client, spark_version, framework_version, image_uri, role, is_feature_store_available
+    sagemaker_session,
+    sagemaker_client,
+    spark_version,
+    framework_version,
+    image_uri,
+    role,
+    is_feature_store_available,
+    instance_type,
 ):
     """Test FeatureStore use cases by ingesting data to feature group."""
 
@@ -359,7 +368,7 @@ def test_sagemaker_feature_store_ingestion_multinode(
         image_uri=image_uri,
         role=role,
         instance_count=2,
-        instance_type="ml.c5.xlarge",
+        instance_type=instance_type,
         max_runtime_in_seconds=1200,
         sagemaker_session=sagemaker_session,
     )
@@ -383,15 +392,17 @@
         raise RuntimeError("Feature store Spark job stopped unexpectedly")
 
 
-def test_sagemaker_java_jar_multinode(tag, role, image_uri, configuration, sagemaker_session, sagemaker_client):
+def test_sagemaker_java_jar_multinode(
+    tag, role, image_uri, configuration, sagemaker_session, sagemaker_client, instance_type
+):
     """Test SparkJarProcessor using Java application jar"""
     spark = SparkJarProcessor(
         base_job_name="sm-spark-java",
         framework_version=tag,
         image_uri=image_uri,
         role=role,
         instance_count=2,
-        instance_type="ml.c5.xlarge",
+        instance_type=instance_type,
         max_runtime_in_seconds=1200,
         sagemaker_session=sagemaker_session,
     )