import tarfile
import tempfile
import time
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Union

from .classification import Classification, Classifications
from .embeddings import Embeddings
from .error import CohereError
-from .generation import (Generation, Generations,
-                         StreamingGenerations,
-                         TokenLikelihood)
+from .generation import Generations, StreamingGenerations
from .chat import Chat, StreamingChat
from .rerank import Reranking
from .summary import Summary
@@ -135,7 +133,7 @@ def create_endpoint(
            instance_type (str, optional): The EC2 instance type to deploy the endpoint to. Defaults to "ml.g4dn.xlarge".
            n_instances (int, optional): Number of endpoint instances. Defaults to 1.
            recreate (bool, optional): Force re-creation of endpoint if it already exists. Defaults to False.
-            rool (str, optional): The IAM role to use for the endpoint. If not provided, sagemaker.get_execution_role()
+            role (str, optional): The IAM role to use for the endpoint. If not provided, sagemaker.get_execution_role()
                will be used to get the role. This should work when one uses the client inside SageMaker. If this errors
                out, the default role "ServiceRoleSagemaker" will be used, which generally works outside of SageMaker.
        """
@@ -150,6 +148,7 @@ def create_endpoint(
        kwargs = {}
        model_data = None
        validation_params = dict()
+        useBoto = False
        if s3_models_dir is not None:
            # If s3_models_dir is given, we assume to have custom fine-tuned models -> Algorithm
            kwargs["algorithm_arn"] = arn
@@ -163,6 +162,7 @@ def create_endpoint(
                model_data_download_timeout=2400,
                container_startup_health_check_timeout=2400
            )
+            useBoto = True

        # Out of precaution, check if there is an endpoint config and delete it if that's the case
        # Otherwise it might block deployment
@@ -171,30 +171,80 @@ def create_endpoint(
        except lazy_botocore().ClientError:
            pass

+        try:
+            self._service_client.delete_model(ModelName=endpoint_name)
+        except lazy_botocore().ClientError:
+            pass
+
        if role is None:
-            try:
-                role = lazy_sagemaker().get_execution_role()
-            except ValueError:
-                print("Using default role: 'ServiceRoleSagemaker'.")
-                role = "ServiceRoleSagemaker"
+            if useBoto:
+                accountID = lazy_sagemaker().account_id()
+                role = f"arn:aws:iam::{accountID}:role/ServiceRoleSagemaker"
+            else:
+                try:
+                    role = lazy_sagemaker().get_execution_role()
+                except ValueError:
+                    print("Using default role: 'ServiceRoleSagemaker'.")
+                    role = "ServiceRoleSagemaker"

-        model = lazy_sagemaker().ModelPackage(
-            role=role,
-            model_data=model_data,
-            sagemaker_session=self._sess,  # makes sure the right region is used
-            **kwargs
-        )
+        # deploy fine-tuned model using sagemaker SDK
+        if s3_models_dir is not None:
+            model = lazy_sagemaker().ModelPackage(
+                role=role,
+                model_data=model_data,
+                sagemaker_session=self._sess,  # makes sure the right region is used
+                **kwargs
+            )

-        try:
-            model.deploy(
-                n_instances,
-                instance_type,
-                endpoint_name=endpoint_name,
-                **validation_params
+            try:
+                model.deploy(
+                    n_instances,
+                    instance_type,
+                    endpoint_name=endpoint_name,
+                    **validation_params
+                )
+            except lazy_botocore().ParamValidationError:
+                # For at least some versions of python 3.6, SageMaker SDK does not support the validation_params
+                model.deploy(n_instances, instance_type, endpoint_name=endpoint_name)
+        else:
+            # deploy pre-trained model using boto to add InferenceAmiVersion
+            self._service_client.create_model(
+                ModelName=endpoint_name,
+                ExecutionRoleArn=role,
+                EnableNetworkIsolation=True,
+                PrimaryContainer={
+                    'ModelPackageName': arn,
+                },
+            )
+            self._service_client.create_endpoint_config(
+                EndpointConfigName=endpoint_name,
+                ProductionVariants=[
+                    {
+                        'VariantName': 'AllTraffic',
+                        'ModelName': endpoint_name,
+                        'InstanceType': instance_type,
+                        'InitialInstanceCount': n_instances,
+                        'InferenceAmiVersion': 'al2-ami-sagemaker-inference-gpu-2'
+                    },
+                ],
            )
-        except lazy_botocore().ParamValidationError:
-            # For at least some versions of python 3.6, SageMaker SDK does not support the validation_params
-            model.deploy(n_instances, instance_type, endpoint_name=endpoint_name)
+            self._service_client.create_endpoint(
+                EndpointName=endpoint_name,
+                EndpointConfigName=endpoint_name,
+            )
+
+            waiter = self._service_client.get_waiter('endpoint_in_service')
+            try:
+                print(f"Waiting for endpoint {endpoint_name} to be in service...")
+                waiter.wait(
+                    EndpointName=endpoint_name,
+                    WaiterConfig={
+                        'Delay': 30,
+                        'MaxAttempts': 80
+                    }
+                )
+            except Exception as e:
+                raise CohereError(f"Failed to create endpoint: {e}")
        self.connect_to_endpoint(endpoint_name)

    def chat(
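Not part of the commit: the waiter above polls every 30 seconds for at most 80 attempts, i.e. up to roughly 40 minutes, in line with the 2400-second model-download and health-check timeouts set earlier in this method. Below is a hedged teardown sketch for endpoints created through this boto3 path; it assumes a plain boto3 "sagemaker" client (which self._service_client appears to be) and that the endpoint, endpoint config, and model all share one name, as the code above arranges.

import boto3
from botocore.exceptions import ClientError

def delete_endpoint_resources(endpoint_name: str, region_name: str = "us-east-1") -> None:
    """Best-effort removal of the endpoint, its config, and the model that share one name."""
    sm = boto3.client("sagemaker", region_name=region_name)
    for call in (
        lambda: sm.delete_endpoint(EndpointName=endpoint_name),
        lambda: sm.delete_endpoint_config(EndpointConfigName=endpoint_name),
        lambda: sm.delete_model(ModelName=endpoint_name),
    ):
        try:
            call()
        except ClientError:
            pass  # resource may not exist; mirrors the commit's delete-if-present precaution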
@@ -725,12 +775,12 @@ def create_finetune(
        s3_resource = lazy_boto3().resource("s3")

        # Copy new model to root of output_model_dir
-        bucket, old_key = parse_s3_url(current_filepath)
-        _, new_key = parse_s3_url(f"{s3_models_dir}{name}.tar.gz")
+        bucket, old_key = lazy_sagemaker().s3.parse_s3_url(current_filepath)
+        _, new_key = lazy_sagemaker().s3.parse_s3_url(f"{s3_models_dir}{name}.tar.gz")
        s3_resource.Object(bucket, new_key).copy(CopySource={"Bucket": bucket, "Key": old_key})

        # Delete old dir
-        bucket, old_short_key = parse_s3_url(s3_models_dir + job_name)
+        bucket, old_short_key = lazy_sagemaker().s3.parse_s3_url(s3_models_dir + job_name)
        s3_resource.Bucket(bucket).objects.filter(Prefix=old_short_key).delete()

    def export_finetune(
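For reference (not part of the commit): the change above routes parse_s3_url through lazy_sagemaker().s3, matching the lazy_boto3()/lazy_botocore() import pattern used elsewhere in this file. The helper splits an S3 URL into a bucket and key; the URL below is a placeholder.

from sagemaker.s3 import parse_s3_url

bucket, key = parse_s3_url("s3://example-bucket/finetunes/my-model.tar.gz")
assert bucket == "example-bucket"
assert key == "finetunes/my-model.tar.gz"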
@@ -791,12 +841,12 @@ def export_finetune(
        s3_resource = lazy_boto3().resource("s3")

        # Copy the exported TensorRT-LLM engine to the root of s3_output_dir
-        bucket, old_key = parse_s3_url(current_filepath)
-        _, new_key = parse_s3_url(f"{s3_output_dir}{name}.tar.gz")
+        bucket, old_key = lazy_sagemaker().s3.parse_s3_url(current_filepath)
+        _, new_key = lazy_sagemaker().s3.parse_s3_url(f"{s3_output_dir}{name}.tar.gz")
        s3_resource.Object(bucket, new_key).copy(CopySource={"Bucket": bucket, "Key": old_key})

        # Delete the old S3 directory
-        bucket, old_short_key = parse_s3_url(f"{s3_output_dir}{job_name}")
+        bucket, old_short_key = lazy_sagemaker().s3.parse_s3_url(f"{s3_output_dir}{job_name}")
        s3_resource.Bucket(bucket).objects.filter(Prefix=old_short_key).delete()

    def wait_for_finetune_job(self, job_id: str, timeout: int = 2 * 60 * 60) -> str:
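Both hunks above use the same copy-then-delete pattern to "move" an artifact within a bucket. A standalone sketch of that pattern, not part of the commit, with placeholder bucket and key names:

import boto3

def move_s3_artifact(bucket: str, old_key: str, new_key: str, old_prefix: str) -> None:
    """Server-side copy to the new key, then delete everything under the old prefix."""
    s3 = boto3.resource("s3")
    s3.Object(bucket, new_key).copy(CopySource={"Bucket": bucket, "Key": old_key})
    s3.Bucket(bucket).objects.filter(Prefix=old_prefix).delete()

move_s3_artifact(
    bucket="example-bucket",
    old_key="finetunes/job-123/output/model.tar.gz",
    new_key="finetunes/my-model.tar.gz",
    old_prefix="finetunes/job-123",
)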