Skip to content

Commit 64e138b

Browse files
cj-zhangJoseph ZhangEC2 Default Usergwang111benieric
authored
feat: optimization technique related validations. (#4921)
* Enable quantization and compilation in the same optimization job via ModelBuilder and add validations to block compilation jobs using TRTLLM an Llama-3.1. * Require EULA acceptance when using a gated 1p draft model via ModelBuilder. * add accept_draft_model_eula to JumpStartModel when deployment config with gated draft model is selected * add map of valid optimization combinations * Add ModelBuilder support for JumpStart-provided draft models. * Tweak draft model EULA validations and messaging. Remove redundant deployment_config flow validation in optimize_utils in favor of the one directly on jumpstart/factory/model. * Add "Auto" speculative decoding ModelProvider option; add validations to differentiate SageMaker/JumpStart draft models. * Fix JumpStartModel.AdditionalModelDataSource model access config assignment. * move the accept eula configurations into deploy flow * move the accept eula configurations into deploy flow * Use correct bucket for SM/JS draft models and minor formatting/validation updates. * Remove obsolete docstring. * remove references to accept_draft_model_eula * renaming of eula fn and error msg * fix: pin testing deps (#4925) Co-authored-by: nileshvd <[email protected]> * Revert "change: add TGI 2.4.0 image uri (#4922)" (#4926) * fix naming and messaging * ModelBuilder speculative decoding UTs and minor fixes. * Fix set union. * add UTs for JumpStart deployment * fix formatting issues * address validation comments * fix doc strings * Add TRTLLM compilation + speculative decoding validation. * address nits --------- Co-authored-by: Joseph Zhang <[email protected]> Co-authored-by: EC2 Default User <[email protected]> Co-authored-by: Gary Wang 😤 <[email protected]> Co-authored-by: Gary Wang <[email protected]> Co-authored-by: Erick Benitez-Ramos <[email protected]> Co-authored-by: nileshvd <[email protected]> Co-authored-by: Haotian An <[email protected]>
1 parent 4b5659d commit 64e138b

File tree

14 files changed

+1537
-93
lines changed

14 files changed

+1537
-93
lines changed

src/sagemaker/jumpstart/factory/model.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717

1818
from typing import Any, Dict, List, Optional, Union
19+
from sagemaker_core.shapes import ModelAccessConfig
1920
from sagemaker import environment_variables, image_uris, instance_types, model_uris, script_uris
2021
from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig
2122
from sagemaker.base_deserializers import BaseDeserializer
@@ -53,11 +54,11 @@
5354
add_hub_content_arn_tags,
5455
add_jumpstart_model_info_tags,
5556
get_default_jumpstart_session_with_user_agent_suffix,
56-
get_neo_content_bucket,
5757
get_top_ranked_config_name,
5858
update_dict_if_key_not_present,
5959
resolve_model_sagemaker_config_field,
6060
verify_model_region_and_return_specs,
61+
get_draft_model_content_bucket,
6162
)
6263

6364
from sagemaker.jumpstart.factory.utils import (
@@ -70,7 +71,12 @@
7071

7172
from sagemaker.serverless.serverless_inference_config import ServerlessInferenceConfig
7273
from sagemaker.session import Session
73-
from sagemaker.utils import camel_case_to_pascal_case, name_from_base, format_tags, Tags
74+
from sagemaker.utils import (
75+
camel_case_to_pascal_case,
76+
name_from_base,
77+
format_tags,
78+
Tags,
79+
)
7480
from sagemaker.workflow.entities import PipelineVariable
7581
from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements
7682
from sagemaker import resource_requirements
@@ -565,7 +571,9 @@ def _add_additional_model_data_sources_to_kwargs(
565571
# Append speculative decoding data source from metadata
566572
speculative_decoding_data_sources = specs.get_speculative_decoding_s3_data_sources()
567573
for data_source in speculative_decoding_data_sources:
568-
data_source.s3_data_source.set_bucket(get_neo_content_bucket(region=kwargs.region))
574+
data_source.s3_data_source.set_bucket(
575+
get_draft_model_content_bucket(provider=data_source.provider, region=kwargs.region)
576+
)
569577
api_shape_additional_model_data_sources = (
570578
[
571579
camel_case_to_pascal_case(data_source.to_json())
@@ -648,6 +656,7 @@ def get_deploy_kwargs(
648656
training_config_name: Optional[str] = None,
649657
config_name: Optional[str] = None,
650658
routing_config: Optional[Dict[str, Any]] = None,
659+
model_access_configs: Optional[Dict[str, ModelAccessConfig]] = None,
651660
) -> JumpStartModelDeployKwargs:
652661
"""Returns kwargs required to call `deploy` on `sagemaker.estimator.Model` object."""
653662

@@ -684,6 +693,7 @@ def get_deploy_kwargs(
684693
resources=resources,
685694
config_name=config_name,
686695
routing_config=routing_config,
696+
model_access_configs=model_access_configs,
687697
)
688698
deploy_kwargs, orig_session = _set_temp_sagemaker_session_if_not_set(kwargs=deploy_kwargs)
689699
deploy_kwargs.specs = verify_model_region_and_return_specs(

src/sagemaker/jumpstart/model.py

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import pandas as pd
1919
from botocore.exceptions import ClientError
2020

21+
from sagemaker_core.shapes import ModelAccessConfig
2122
from sagemaker import payloads
2223
from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig
2324
from sagemaker.base_deserializers import BaseDeserializer
@@ -51,6 +52,7 @@
5152
add_instance_rate_stats_to_benchmark_metrics,
5253
deployment_config_response_data,
5354
_deployment_config_lru_cache,
55+
_add_model_access_configs_to_model_data_sources,
5456
)
5557
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION, JUMPSTART_LOGGER
5658
from sagemaker.jumpstart.enums import JumpStartModelType
@@ -540,12 +542,16 @@ def attach(
540542
inferred_model_id = inferred_model_version = inferred_inference_component_name = None
541543

542544
if inference_component_name is None or model_id is None or model_version is None:
543-
inferred_model_id, inferred_model_version, inferred_inference_component_name, _, _ = (
544-
get_model_info_from_endpoint(
545-
endpoint_name=endpoint_name,
546-
inference_component_name=inference_component_name,
547-
sagemaker_session=sagemaker_session,
548-
)
545+
(
546+
inferred_model_id,
547+
inferred_model_version,
548+
inferred_inference_component_name,
549+
_,
550+
_,
551+
) = get_model_info_from_endpoint(
552+
endpoint_name=endpoint_name,
553+
inference_component_name=inference_component_name,
554+
sagemaker_session=sagemaker_session,
549555
)
550556

551557
model_id = model_id or inferred_model_id
@@ -659,6 +665,7 @@ def deploy(
659665
managed_instance_scaling: Optional[str] = None,
660666
endpoint_type: EndpointType = EndpointType.MODEL_BASED,
661667
routing_config: Optional[Dict[str, Any]] = None,
668+
model_access_configs: Optional[Dict[str, ModelAccessConfig]] = None,
662669
) -> PredictorBase:
663670
"""Creates endpoint by calling base ``Model`` class `deploy` method.
664671
@@ -755,6 +762,11 @@ def deploy(
755762
(Default: EndpointType.MODEL_BASED).
756763
routing_config (Optional[Dict]): Settings the control how the endpoint routes
757764
incoming traffic to the instances that the endpoint hosts.
765+
model_access_configs (Optional[Dict[str, ModelAccessConfig]]): For models that require
766+
ModelAccessConfig, provide a `{ "model_id", ModelAccessConfig(accept_eula=True) }`
767+
to indicate whether model terms of use have been accepted. The `accept_eula` value
768+
must be explicitly defined as `True` in order to accept the end-user license
769+
agreement (EULA) that some models require. (Default: None)
758770
759771
Raises:
760772
MarketplaceModelSubscriptionError: If the caller is not subscribed to the model.
@@ -795,6 +807,7 @@ def deploy(
795807
model_type=self.model_type,
796808
config_name=self.config_name,
797809
routing_config=routing_config,
810+
model_access_configs=model_access_configs,
798811
)
799812
if (
800813
self.model_type == JumpStartModelType.PROPRIETARY
@@ -804,6 +817,13 @@ def deploy(
804817
f"{EndpointType.INFERENCE_COMPONENT_BASED} is not supported for Proprietary models."
805818
)
806819

820+
self.additional_model_data_sources = _add_model_access_configs_to_model_data_sources(
821+
self.additional_model_data_sources,
822+
deploy_kwargs.model_access_configs,
823+
deploy_kwargs.model_id,
824+
deploy_kwargs.region,
825+
)
826+
807827
try:
808828
predictor = super(JumpStartModel, self).deploy(**deploy_kwargs.to_kwargs_dict())
809829
except ClientError as e:
@@ -1016,10 +1036,11 @@ def _get_deployment_configs(
10161036
)
10171037

10181038
if metadata_config.benchmark_metrics:
1019-
err, metadata_config.benchmark_metrics = (
1020-
add_instance_rate_stats_to_benchmark_metrics(
1021-
self.region, metadata_config.benchmark_metrics
1022-
)
1039+
(
1040+
err,
1041+
metadata_config.benchmark_metrics,
1042+
) = add_instance_rate_stats_to_benchmark_metrics(
1043+
self.region, metadata_config.benchmark_metrics
10231044
)
10241045

10251046
config_components = metadata_config.config_components.get(config_name)

src/sagemaker/jumpstart/types.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from copy import deepcopy
1818
from enum import Enum
1919
from typing import Any, Dict, List, Optional, Set, Union
20+
from sagemaker_core.shapes import ModelAccessConfig as CoreModelAccessConfig
2021
from sagemaker.model_card.model_card import ModelCard, ModelPackageModelCard
2122
from sagemaker.utils import (
2223
S3_PREFIX,
@@ -1081,9 +1082,9 @@ def set_bucket(self, bucket: str) -> None:
10811082
class AdditionalModelDataSource(JumpStartDataHolderType):
10821083
"""Data class of additional model data source mirrors CreateModel API."""
10831084

1084-
SERIALIZATION_EXCLUSION_SET: Set[str] = set()
1085+
SERIALIZATION_EXCLUSION_SET = {"provider"}
10851086

1086-
__slots__ = ["channel_name", "s3_data_source"]
1087+
__slots__ = ["channel_name", "s3_data_source", "hosting_eula_key"]
10871088

10881089
def __init__(self, spec: Dict[str, Any]):
10891090
"""Initializes a AdditionalModelDataSource object.
@@ -1101,6 +1102,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
11011102
"""
11021103
self.channel_name: str = json_obj["channel_name"]
11031104
self.s3_data_source: S3DataSource = S3DataSource(json_obj["s3_data_source"])
1105+
self.hosting_eula_key: str = json_obj.get("hosting_eula_key")
1106+
self.provider: Dict = json_obj.get("provider", {})
11041107

11051108
def to_json(self, exclude_keys=True) -> Dict[str, Any]:
11061109
"""Returns json representation of AdditionalModelDataSource object."""
@@ -1119,7 +1122,9 @@ def to_json(self, exclude_keys=True) -> Dict[str, Any]:
11191122
class JumpStartModelDataSource(AdditionalModelDataSource):
11201123
"""Data class JumpStart additional model data source."""
11211124

1122-
SERIALIZATION_EXCLUSION_SET = {"artifact_version"}
1125+
SERIALIZATION_EXCLUSION_SET = AdditionalModelDataSource.SERIALIZATION_EXCLUSION_SET.union(
1126+
{"artifact_version"}
1127+
)
11231128

11241129
__slots__ = list(SERIALIZATION_EXCLUSION_SET) + AdditionalModelDataSource.__slots__
11251130

@@ -2239,6 +2244,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs):
22392244
"config_name",
22402245
"routing_config",
22412246
"specs",
2247+
"model_access_configs",
22422248
]
22432249

22442250
SERIALIZATION_EXCLUSION_SET = {
@@ -2252,6 +2258,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs):
22522258
"sagemaker_session",
22532259
"training_instance_type",
22542260
"config_name",
2261+
"model_access_configs",
22552262
}
22562263

22572264
def __init__(
@@ -2290,6 +2297,7 @@ def __init__(
22902297
endpoint_type: Optional[EndpointType] = None,
22912298
config_name: Optional[str] = None,
22922299
routing_config: Optional[Dict[str, Any]] = None,
2300+
model_access_configs: Optional[Dict[str, CoreModelAccessConfig]] = None,
22932301
) -> None:
22942302
"""Instantiates JumpStartModelDeployKwargs object."""
22952303

@@ -2327,6 +2335,7 @@ def __init__(
23272335
self.endpoint_type = endpoint_type
23282336
self.config_name = config_name
23292337
self.routing_config = routing_config
2338+
self.model_access_configs = model_access_configs
23302339

23312340

23322341
class JumpStartEstimatorInitKwargs(JumpStartKwargs):

src/sagemaker/jumpstart/utils.py

Lines changed: 91 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# language governing permissions and limitations under the License.
1313
"""This module contains utilities related to SageMaker JumpStart."""
1414
from __future__ import absolute_import
15+
1516
from copy import copy
1617
import logging
1718
import os
@@ -22,6 +23,7 @@
2223
from botocore.exceptions import ClientError
2324
from packaging.version import Version
2425
import botocore
26+
from sagemaker_core.shapes import ModelAccessConfig
2527
import sagemaker
2628
from sagemaker.config.config_schema import (
2729
MODEL_ENABLE_NETWORK_ISOLATION_PATH,
@@ -55,6 +57,7 @@
5557
TagsDict,
5658
get_instance_rate_per_hour,
5759
get_domain_for_region,
60+
camel_case_to_pascal_case,
5861
)
5962
from sagemaker.workflow import is_pipeline_variable
6063
from sagemaker.user_agent import get_user_agent_extra_suffix
@@ -555,11 +558,18 @@ def get_eula_message(model_specs: JumpStartModelSpecs, region: str) -> str:
555558
"""Returns EULA message to display if one is available, else empty string."""
556559
if model_specs.hosting_eula_key is None:
557560
return ""
561+
return get_formatted_eula_message_template(
562+
model_id=model_specs.model_id, region=region, hosting_eula_key=model_specs.hosting_eula_key
563+
)
564+
565+
566+
def get_formatted_eula_message_template(model_id: str, region: str, hosting_eula_key: str) -> str:
567+
"""Returns a formatted EULA message."""
558568
return (
559-
f"Model '{model_specs.model_id}' requires accepting end-user license agreement (EULA). "
569+
f"Model '{model_id}' requires accepting end-user license agreement (EULA). "
560570
f"See https://{get_jumpstart_content_bucket(region=region)}.s3.{region}."
561571
f"{get_domain_for_region(region)}"
562-
f"/{model_specs.hosting_eula_key} for terms of use."
572+
f"/{hosting_eula_key} for terms of use."
563573
)
564574

565575

@@ -1525,3 +1535,82 @@ def wrapped_f(*args, **kwargs):
15251535
if _func is None:
15261536
return wrapper_cache
15271537
return wrapper_cache(_func)
1538+
1539+
1540+
def _add_model_access_configs_to_model_data_sources(
1541+
model_data_sources: List[Dict[str, any]],
1542+
model_access_configs: Dict[str, ModelAccessConfig],
1543+
model_id: str,
1544+
region: str,
1545+
) -> List[Dict[str, any]]:
1546+
"""Iterate over the accept EULA configs to ensure all channels are matched
1547+
1548+
Args:
1549+
model_data_sources (DeploymentConfigMetadata): Model data sources that will be updated
1550+
model_access_configs (DeploymentConfigMetadata): Config holding accept_eula field
1551+
model_id (DeploymentConfigMetadata): Jumpstart model id.
1552+
region (str): Region where the user is operating in.
1553+
Returns:
1554+
List[Dict[str, Any]]: List of model data sources with accept EULA configs applied
1555+
Raise:
1556+
ValueError if at least one channel that requires EULA acceptance as not passed.
1557+
"""
1558+
if not model_data_sources:
1559+
return model_data_sources
1560+
1561+
acked_model_data_sources = []
1562+
for model_data_source in model_data_sources:
1563+
hosting_eula_key = model_data_source.get("HostingEulaKey")
1564+
mutable_model_data_source = model_data_source.copy()
1565+
if hosting_eula_key:
1566+
if (
1567+
not model_access_configs
1568+
or not model_access_configs.get(model_id)
1569+
or not model_access_configs.get(model_id).accept_eula
1570+
):
1571+
eula_message_template = (
1572+
"{model_source}{base_eula_message}{model_access_configs_message}"
1573+
)
1574+
model_access_config_entry = (
1575+
'"{model_id}":ModelAccessConfig(accept_eula=True)'.format(model_id=model_id)
1576+
)
1577+
raise ValueError(
1578+
eula_message_template.format(
1579+
model_source="Additional " if model_data_source.get("ChannelName") else "",
1580+
base_eula_message=get_formatted_eula_message_template(
1581+
model_id=model_id, region=region, hosting_eula_key=hosting_eula_key
1582+
),
1583+
model_access_configs_message=(
1584+
"Please add a ModelAccessConfig entry:"
1585+
f" {model_access_config_entry} "
1586+
"to model_access_configs to accept the EULA."
1587+
),
1588+
)
1589+
)
1590+
mutable_model_data_source.pop(
1591+
"HostingEulaKey"
1592+
) # pop when model access config is applied
1593+
mutable_model_data_source["S3DataSource"]["ModelAccessConfig"] = (
1594+
camel_case_to_pascal_case(model_access_configs.get(model_id).model_dump())
1595+
)
1596+
acked_model_data_sources.append(mutable_model_data_source)
1597+
else:
1598+
mutable_model_data_source.pop(
1599+
"HostingEulaKey"
1600+
) # pop when model access config is not applicable
1601+
acked_model_data_sources.append(mutable_model_data_source)
1602+
return acked_model_data_sources
1603+
1604+
1605+
def get_draft_model_content_bucket(provider: Dict, region: str) -> str:
1606+
"""Returns the correct content bucket for a 1p draft model."""
1607+
neo_bucket = get_neo_content_bucket(region=region)
1608+
if not provider:
1609+
return neo_bucket
1610+
provider_name = provider.get("name", "")
1611+
if provider_name == "JumpStart":
1612+
classification = provider.get("classification", "ungated")
1613+
if classification == "gated":
1614+
return get_jumpstart_gated_content_bucket(region=region)
1615+
return get_jumpstart_content_bucket(region=region)
1616+
return neo_bucket

0 commit comments

Comments
 (0)