Skip to content

Commit 756515a

Browse files
CoderHam and hemant-co
authored
create endpoint with InferenceAmiVersion (#602)
* x * pickup changes from cohere-ai/cohere-aws#196 * misc cleanup and import fixes --------- Co-authored-by: CoderHam <[email protected]>
1 parent 044344f commit 756515a

File tree

2 files changed

+82
-31
lines changed

2 files changed

+82
-31
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@ dist/
33
__pycache__/
44
poetry.toml
55
.ruff_cache/
6+
.venv/

src/cohere/manually_maintained/cohere_aws/client.py

Lines changed: 81 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,12 @@
33
import tarfile
44
import tempfile
55
import time
6-
from typing import Any, Dict, List, Optional, Tuple, Union
6+
from typing import Any, Dict, List, Optional, Union
77

88
from .classification import Classification, Classifications
99
from .embeddings import Embeddings
1010
from .error import CohereError
11-
from .generation import (Generation, Generations,
12-
StreamingGenerations,
13-
TokenLikelihood)
11+
from .generation import Generations, StreamingGenerations
1412
from .chat import Chat, StreamingChat
1513
from .rerank import Reranking
1614
from .summary import Summary
@@ -135,7 +133,7 @@ def create_endpoint(
135133
instance_type (str, optional): The EC2 instance type to deploy the endpoint to. Defaults to "ml.g4dn.xlarge".
136134
n_instances (int, optional): Number of endpoint instances. Defaults to 1.
137135
recreate (bool, optional): Force re-creation of endpoint if it already exists. Defaults to False.
138-
rool (str, optional): The IAM role to use for the endpoint. If not provided, sagemaker.get_execution_role()
136+
role (str, optional): The IAM role to use for the endpoint. If not provided, sagemaker.get_execution_role()
139137
will be used to get the role. This should work when one uses the client inside SageMaker. If this errors
140138
out, the default role "ServiceRoleSagemaker" will be used, which generally works outside of SageMaker.
141139
"""
@@ -150,6 +148,7 @@ def create_endpoint(
150148
kwargs = {}
151149
model_data = None
152150
validation_params = dict()
151+
useBoto = False
153152
if s3_models_dir is not None:
154153
# If s3_models_dir is given, we assume to have custom fine-tuned models -> Algorithm
155154
kwargs["algorithm_arn"] = arn
@@ -163,6 +162,7 @@ def create_endpoint(
163162
model_data_download_timeout=2400,
164163
container_startup_health_check_timeout=2400
165164
)
165+
useBoto = True
166166

167167
# Out of precaution, check if there is an endpoint config and delete it if that's the case
168168
# Otherwise it might block deployment
@@ -171,30 +171,80 @@ def create_endpoint(
171171
except lazy_botocore().ClientError:
172172
pass
173173

174+
try:
175+
self._service_client.delete_model(ModelName=endpoint_name)
176+
except lazy_botocore().ClientError:
177+
pass
178+
174179
if role is None:
175-
try:
176-
role = lazy_sagemaker().get_execution_role()
177-
except ValueError:
178-
print("Using default role: 'ServiceRoleSagemaker'.")
179-
role = "ServiceRoleSagemaker"
180+
if useBoto:
181+
accountID = lazy_sagemaker().account_id()
182+
role = f"arn:aws:iam::{accountID}:role/ServiceRoleSagemaker"
183+
else:
184+
try:
185+
role = lazy_sagemaker().get_execution_role()
186+
except ValueError:
187+
print("Using default role: 'ServiceRoleSagemaker'.")
188+
role = "ServiceRoleSagemaker"
180189

181-
model = lazy_sagemaker().ModelPackage(
182-
role=role,
183-
model_data=model_data,
184-
sagemaker_session=self._sess, # makes sure the right region is used
185-
**kwargs
186-
)
190+
# deploy fine-tuned model using sagemaker SDK
191+
if s3_models_dir is not None:
192+
model = lazy_sagemaker().ModelPackage(
193+
role=role,
194+
model_data=model_data,
195+
sagemaker_session=self._sess, # makes sure the right region is used
196+
**kwargs
197+
)
187198

188-
try:
189-
model.deploy(
190-
n_instances,
191-
instance_type,
192-
endpoint_name=endpoint_name,
193-
**validation_params
199+
try:
200+
model.deploy(
201+
n_instances,
202+
instance_type,
203+
endpoint_name=endpoint_name,
204+
**validation_params
205+
)
206+
except lazy_botocore().ParamValidationError:
207+
# For at least some versions of python 3.6, SageMaker SDK does not support the validation_params
208+
model.deploy(n_instances, instance_type, endpoint_name=endpoint_name)
209+
else:
210+
# deploy pre-trained model using boto to add InferenceAmiVersion
211+
self._service_client.create_model(
212+
ModelName=endpoint_name,
213+
ExecutionRoleArn=role,
214+
EnableNetworkIsolation=True,
215+
PrimaryContainer={
216+
'ModelPackageName': arn,
217+
},
218+
)
219+
self._service_client.create_endpoint_config(
220+
EndpointConfigName=endpoint_name,
221+
ProductionVariants=[
222+
{
223+
'VariantName': 'AllTraffic',
224+
'ModelName': endpoint_name,
225+
'InstanceType': instance_type,
226+
'InitialInstanceCount': n_instances,
227+
'InferenceAmiVersion': 'al2-ami-sagemaker-inference-gpu-2'
228+
},
229+
],
194230
)
195-
except lazy_botocore().ParamValidationError:
196-
# For at least some versions of python 3.6, SageMaker SDK does not support the validation_params
197-
model.deploy(n_instances, instance_type, endpoint_name=endpoint_name)
231+
self._service_client.create_endpoint(
232+
EndpointName=endpoint_name,
233+
EndpointConfigName=endpoint_name,
234+
)
235+
236+
waiter = self._service_client.get_waiter('endpoint_in_service')
237+
try:
238+
print(f"Waiting for endpoint {endpoint_name} to be in service...")
239+
waiter.wait(
240+
EndpointName=endpoint_name,
241+
WaiterConfig={
242+
'Delay': 30,
243+
'MaxAttempts': 80
244+
}
245+
)
246+
except Exception as e:
247+
raise CohereError(f"Failed to create endpoint: {e}")
198248
self.connect_to_endpoint(endpoint_name)
199249

200250
def chat(
@@ -725,12 +775,12 @@ def create_finetune(
725775
s3_resource = lazy_boto3().resource("s3")
726776

727777
# Copy new model to root of output_model_dir
728-
bucket, old_key = parse_s3_url(current_filepath)
729-
_, new_key = parse_s3_url(f"{s3_models_dir}{name}.tar.gz")
778+
bucket, old_key = lazy_sagemaker().s3.parse_s3_url(current_filepath)
779+
_, new_key = lazy_sagemaker().s3.parse_s3_url(f"{s3_models_dir}{name}.tar.gz")
730780
s3_resource.Object(bucket, new_key).copy(CopySource={"Bucket": bucket, "Key": old_key})
731781

732782
# Delete old dir
733-
bucket, old_short_key = parse_s3_url(s3_models_dir + job_name)
783+
bucket, old_short_key = lazy_sagemaker().s3.parse_s3_url(s3_models_dir + job_name)
734784
s3_resource.Bucket(bucket).objects.filter(Prefix=old_short_key).delete()
735785

736786
def export_finetune(
@@ -791,12 +841,12 @@ def export_finetune(
791841
s3_resource = lazy_boto3().resource("s3")
792842

793843
# Copy the exported TensorRT-LLM engine to the root of s3_output_dir
794-
bucket, old_key = parse_s3_url(current_filepath)
795-
_, new_key = parse_s3_url(f"{s3_output_dir}{name}.tar.gz")
844+
bucket, old_key = lazy_sagemaker().s3.parse_s3_url(current_filepath)
845+
_, new_key = lazy_sagemaker().s3.parse_s3_url(f"{s3_output_dir}{name}.tar.gz")
796846
s3_resource.Object(bucket, new_key).copy(CopySource={"Bucket": bucket, "Key": old_key})
797847

798848
# Delete the old S3 directory
799-
bucket, old_short_key = parse_s3_url(f"{s3_output_dir}{job_name}")
849+
bucket, old_short_key = lazy_sagemaker().s3.parse_s3_url(f"{s3_output_dir}{job_name}")
800850
s3_resource.Bucket(bucket).objects.filter(Prefix=old_short_key).delete()
801851

802852
def wait_for_finetune_job(self, job_id: str, timeout: int = 2*60*60) -> str:

0 commit comments

Comments (0)