Skip to content

Commit f961244

Browse files
jtblackjtblack-aws
authored andcommitted
feat: add a test multi-container endpoint
1 parent 183175e commit f961244

File tree

2 files changed

+145
-22
lines changed

2 files changed

+145
-22
lines changed

lib/osml/model_endpoint/me_sm_endpoint.ts

Lines changed: 85 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,33 @@ import { Construct } from "constructs";
1111

1212
import { BaseConfig, ConfigType } from "../utils/base_config";
1313

14+
/**
15+
* Configuration for a container in a SageMaker model
16+
*/
17+
export interface ContainerDefinition {
18+
/**
19+
* The URI of the container image
20+
*/
21+
imageUri: string;
22+
23+
/**
24+
* Environment variables for the container
25+
*/
26+
environment?: Record<string, unknown>;
27+
28+
/**
29+
* Repository access mode for the container
30+
*/
31+
repositoryAccessMode?: string;
32+
33+
/**
34+
* Hostname for the container (required for multi-container Direct mode)
35+
* If required and not specified, will default to an index-based name
36+
* example: "container-0"
37+
*/
38+
containerHostname?: string;
39+
}
40+
1441
/**
1542
* Configuration class for MESMEndpoint Construct.
1643
*/
@@ -38,13 +65,20 @@ export class MESMEndpointConfig extends BaseConfig {
3865
*/
3966
public SECURITY_GROUP_ID: string;
4067

68+
/**
69+
* List of container definitions for the model
70+
*/
71+
public CONTAINERS: ContainerDefinition[];
72+
4173
/**
4274
* A JSON object which includes ENV variables to be put into the model container.
75+
* @deprecated Use CONTAINERS (ContainerDefinition[]) instead
4376
*/
4477
public CONTAINER_ENV: Record<string, unknown>;
4578

4679
/**
4780
* The repository access mode to use for the SageMaker endpoint container.
81+
* @deprecated Use CONTAINERS (ContainerDefinition) instead
4882
*/
4983
public REPOSITORY_ACCESS_MODE: string;
5084
/**
@@ -57,9 +91,30 @@ export class MESMEndpointConfig extends BaseConfig {
5791
INITIAL_VARIANT_WEIGHT: 1,
5892
INITIAL_INSTANCE_COUNT: 1,
5993
VARIANT_NAME: "AllTraffic",
60-
REPOSITORY_ACCESS_MODE: "Platform",
94+
CONTAINERS: [],
6195
...config
6296
});
97+
98+
// Convert deprecated interface to container list if needed
99+
if (this.CONTAINERS.length === 0 && config.CONTAINER_ENV !== undefined) {
100+
this.CONTAINERS = [
101+
{
102+
imageUri: "", // Populated later with props.containerImageUri
103+
environment: config.CONTAINER_ENV as Record<string, unknown>,
104+
repositoryAccessMode: (config.REPOSITORY_ACCESS_MODE ||
105+
"Platform") as string
106+
}
107+
];
108+
} else if (this.CONTAINERS.length === 0) {
109+
// Ensure we always have a CONTAINERS array - default to an empty container definition
110+
this.CONTAINERS = [
111+
{
112+
imageUri: "",
113+
environment: {} as Record<string, unknown>,
114+
repositoryAccessMode: "Platform"
115+
}
116+
];
117+
}
63118
}
64119
}
65120

@@ -151,26 +206,35 @@ export class MESMEndpoint extends Construct {
151206
this.config = [props.config];
152207
}
153208

154-
const models = this.config.map(
155-
(config) =>
156-
new CfnModel(this, `${id}-${config.VARIANT_NAME}`, {
157-
executionRoleArn: props.roleArn,
158-
containers: [
159-
{
160-
image: props.containerImageUri,
161-
environment: config.CONTAINER_ENV,
162-
imageConfig: {
163-
repositoryAccessMode:
164-
config.REPOSITORY_ACCESS_MODE || "Platform"
165-
}
166-
}
167-
],
168-
vpcConfig: {
169-
subnets: props.subnetIds,
170-
securityGroupIds: [config.SECURITY_GROUP_ID]
171-
}
172-
})
173-
);
209+
const models = this.config.map((config) => {
210+
// Set the imageUri for containers that don't have one specified. This
211+
// handles the legacy conversion case where imageUri was initially empty.
212+
config.CONTAINERS = config.CONTAINERS.map((container) => ({
213+
...container,
214+
imageUri: container.imageUri || props.containerImageUri
215+
}));
216+
217+
// Map to the SageMaker container format
218+
const containers = config.CONTAINERS.map((container, index) => ({
219+
image: container.imageUri,
220+
environment: container.environment || {},
221+
imageConfig: {
222+
repositoryAccessMode: container.repositoryAccessMode || "Platform"
223+
},
224+
containerHostname: container.containerHostname || `container-${index}`
225+
}));
226+
227+
return new CfnModel(this, `${id}-${config.VARIANT_NAME}`, {
228+
executionRoleArn: props.roleArn,
229+
containers: containers,
230+
inferenceExecutionConfig:
231+
containers.length > 1 ? { mode: "Direct" } : undefined,
232+
vpcConfig: {
233+
subnets: props.subnetIds,
234+
securityGroupIds: [config.SECURITY_GROUP_ID]
235+
}
236+
});
237+
});
174238

175239
this.endpointConfig = new CfnEndpointConfig(this, `${id}-EndpointConfig`, {
176240
productionVariants: this.config.map((config, i) => ({

lib/osml/model_endpoint/me_test_endpoints.ts

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,12 @@ export class METestEndpointsConfig extends BaseConfig {
7373
*/
7474
public DEPLOY_SM_FLOOD_ENDPOINT: boolean;
7575

76+
/**
77+
* Whether to deploy the SageMaker multi-container model endpoint.
78+
* @default true
79+
*/
80+
public DEPLOY_MULTI_CONTAINER_ENDPOINT: boolean;
81+
7682
/**
7783
* The CPU allocation for the HTTP endpoint.
7884
* @default 4096
@@ -138,6 +144,12 @@ export class METestEndpointsConfig extends BaseConfig {
138144
*/
139145
public SM_CENTER_POINT_MODEL: string;
140146

147+
/**
148+
* The name of the multi-container SageMaker endpoint.
149+
* @default "multi-container"
150+
*/
151+
public SM_MULTI_CONTAINER_ENDPOINT: string;
152+
141153
/**
142154
* The SageMaker CPU instance type.
143155
* @default "ml.m5.xlarge"
@@ -176,6 +188,7 @@ export class METestEndpointsConfig extends BaseConfig {
176188
DEPLOY_SM_AIRCRAFT_ENDPOINT: true,
177189
DEPLOY_SM_CENTERPOINT_ENDPOINT: true,
178190
DEPLOY_SM_FLOOD_ENDPOINT: true,
191+
DEPLOY_MULTI_CONTAINER_ENDPOINT: true,
179192
HTTP_ENDPOINT_CPU: 4096,
180193
HTTP_ENDPOINT_CONTAINER_PORT: 8080,
181194
HTTP_ENDPOINT_DOMAIN_NAME: "test-http-model-endpoint",
@@ -185,8 +198,9 @@ export class METestEndpointsConfig extends BaseConfig {
185198
HTTP_ENDPOINT_MEMORY: 16384,
186199
SM_AIRCRAFT_MODEL: "aircraft",
187200
SM_CENTER_POINT_MODEL: "centerpoint",
188-
SM_CPU_INSTANCE_TYPE: "ml.m5.xlarge",
189201
SM_FLOOD_MODEL: "flood",
202+
SM_MULTI_CONTAINER_ENDPOINT: "multi-container",
203+
SM_CPU_INSTANCE_TYPE: "ml.m5.xlarge",
190204
...config
191205
});
192206
}
@@ -271,6 +285,11 @@ export class METestEndpoints extends Construct {
271285
*/
272286
public aircraftModelEndpoint?: MESMEndpoint;
273287

288+
/**
289+
* SM endpoint for testing a multi-container configuration.
290+
*/
291+
public multiContainerModelEndpoint?: MESMEndpoint;
292+
274293
/**
275294
* Security Group ID associated with the endpoints.
276295
*/
@@ -474,5 +493,45 @@ export class METestEndpoints extends Construct {
474493
);
475494
this.aircraftModelEndpoint.node.addDependency(this.modelContainer);
476495
}
496+
497+
// Build a multi-container endpoint
498+
if (this.config.DEPLOY_MULTI_CONTAINER_ENDPOINT) {
499+
this.multiContainerModelEndpoint = new MESMEndpoint(
500+
this,
501+
"OSMLMultiContainerModelEndpoint",
502+
{
503+
containerImageUri: this.modelContainer.containerUri,
504+
modelName: this.config.SM_MULTI_CONTAINER_ENDPOINT,
505+
roleArn: this.smRole.roleArn,
506+
instanceType: this.config.SM_CPU_INSTANCE_TYPE,
507+
subnetIds: props.osmlVpc.selectedSubnets.subnetIds,
508+
config: [
509+
new MESMEndpointConfig({
510+
SECURITY_GROUP_ID: this.securityGroupId,
511+
CONTAINERS: [
512+
{
513+
imageUri: this.modelContainer.containerUri,
514+
environment: {
515+
MODEL_SELECTION: this.config.SM_CENTER_POINT_MODEL
516+
},
517+
repositoryAccessMode:
518+
this.modelContainer.repositoryAccessMode,
519+
containerHostname: "centerpoint-container"
520+
},
521+
{
522+
imageUri: this.modelContainer.containerUri,
523+
environment: {
524+
MODEL_SELECTION: this.config.SM_AIRCRAFT_MODEL
525+
},
526+
repositoryAccessMode:
527+
this.modelContainer.repositoryAccessMode,
528+
containerHostname: "aircraft-container"
529+
}
530+
]
531+
})
532+
]
533+
}
534+
);
535+
}
477536
}
478537
}

0 commit comments

Comments
 (0)