2424import json
2525import logging
2626from aixplain .modules .model import Model
27+ from aixplain .modules .model .llm_model import LLM
2728from aixplain .enums import Function , Language , OwnershipType , Supplier , SortBy , SortOrder
2829from aixplain .utils import config
2930from aixplain .utils .file_utils import _request_with_retry
@@ -60,13 +61,18 @@ def _create_model_from_response(cls, response: Dict) -> Model:
6061 if "language" in param ["name" ]:
6162 parameters [param ["name" ]] = [w ["value" ] for w in param ["values" ]]
6263
63- return Model (
64+ function = Function (response ["function" ]["id" ])
65+ ModelClass = Model
66+ if function == Function .TEXT_GENERATION :
67+ ModelClass = LLM
68+
69+ return ModelClass (
6470 response ["id" ],
6571 response ["name" ],
6672 supplier = response ["supplier" ],
6773 api_key = response ["api_key" ],
6874 cost = response ["pricing" ],
69- function = Function ( response [ " function" ][ "id" ]) ,
75+ function = function ,
7076 parameters = parameters ,
7177 is_subscribed = True if "subscription" in response else False ,
7278 version = response ["version" ]["id" ],
@@ -100,7 +106,7 @@ def get(cls, model_id: Text, api_key: Optional[Text] = None) -> Model:
100106 model = cls ._create_model_from_response (resp )
101107 logging .info (f"Model Creation: Model { model_id } instantiated." )
102108 return model
103- except Exception as e :
109+ except Exception :
104110 if resp is not None and "statusCode" in resp :
105111 status_code = resp ["statusCode" ]
106112 message = resp ["message" ]
@@ -135,7 +141,7 @@ def _get_assets_from_page(
135141 sort_order : SortOrder = SortOrder .ASCENDING ,
136142 ) -> List [Model ]:
137143 try :
138- url = urljoin (cls .backend_url , f "sdk/models/paginate" )
144+ url = urljoin (cls .backend_url , "sdk/models/paginate" )
139145 filter_params = {"q" : query , "pageNumber" : page_number , "pageSize" : page_size }
140146 if is_finetunable is not None :
141147 filter_params ["isFineTunable" ] = is_finetunable
@@ -253,7 +259,7 @@ def list_host_machines(cls, api_key: Optional[Text] = None) -> List[Dict]:
253259 List[Dict]: List of dictionaries containing information about
254260 each hosting machine.
255261 """
256- machines_url = urljoin (config .BACKEND_URL , f "sdk/hosting-machines" )
262+ machines_url = urljoin (config .BACKEND_URL , "sdk/hosting-machines" )
257263 logging .debug (f"URL: { machines_url } " )
258264 if api_key :
259265 headers = {"x-api-key" : f"{ api_key } " , "Content-Type" : "application/json" }
@@ -264,6 +270,25 @@ def list_host_machines(cls, api_key: Optional[Text] = None) -> List[Dict]:
264270 for dictionary in response_dicts :
265271 del dictionary ["id" ]
266272 return response_dicts
273+
@classmethod
def list_gpus(cls, api_key: Optional[Text] = None) -> List[List[Text]]:
    """List GPU names on which you can host your language model.

    Args:
        api_key (Text, optional): Team API key. Defaults to None, in which
            case ``config.TEAM_API_KEY`` is used instead.

    Returns:
        List[List[Text]]: List of all available GPUs and their prices, as
            decoded from the JSON body of the backend response.
    """
    gpu_url = urljoin(config.BACKEND_URL, "sdk/model-onboarding/gpus")
    # Log the target URL like the sibling listing helpers do.
    logging.debug(f"URL: {gpu_url}")
    # An empty or missing key falls back to the configured team key; the
    # header value is built once so it carries no stray surrounding spaces.
    token = api_key if api_key else config.TEAM_API_KEY
    headers = {"Authorization": f"Token {token}", "Content-Type": "application/json"}
    response = _request_with_retry("get", gpu_url, headers=headers)
    return json.loads(response.text)
267292
268293 @classmethod
269294 def list_functions (cls , verbose : Optional [bool ] = False , api_key : Optional [Text ] = None ) -> List [Dict ]:
@@ -278,7 +303,7 @@ def list_functions(cls, verbose: Optional[bool] = False, api_key: Optional[Text]
278303 List[Dict]: List of dictionaries containing information about
279304 each supported function.
280305 """
281- functions_url = urljoin (config .BACKEND_URL , f "sdk/functions" )
306+ functions_url = urljoin (config .BACKEND_URL , "sdk/functions" )
282307 logging .debug (f"URL: { functions_url } " )
283308 if api_key :
284309 headers = {"x-api-key" : f"{ api_key } " , "Content-Type" : "application/json" }
@@ -304,12 +329,13 @@ def list_functions(cls, verbose: Optional[bool] = False, api_key: Optional[Text]
304329 def create_asset_repo (
305330 cls ,
306331 name : Text ,
307- hosting_machine : Text ,
308- version : Text ,
309332 description : Text ,
310333 function : Text ,
311334 source_language : Text ,
312- api_key : Optional [Text ] = None ,
335+ input_modality : Text ,
336+ output_modality : Text ,
337+ documentation_url : Optional [Text ] = "" ,
338+ api_key : Optional [Text ] = None
313339 ) -> Dict :
314340 """Creates an image repository for this model and registers it in the
315341 platform backend.
@@ -336,27 +362,36 @@ def create_asset_repo(
336362 function_id = function_dict ["id" ]
337363 if function_id is None :
338364 raise Exception ("Invalid function name" )
339- create_url = urljoin (config .BACKEND_URL , f"sdk/models/register " )
365+ create_url = urljoin (config .BACKEND_URL , f"sdk/models/onboard " )
340366 logging .debug (f"URL: { create_url } " )
341367 if api_key :
342368 headers = {"x-api-key" : f"{ api_key } " , "Content-Type" : "application/json" }
343369 else :
344370 headers = {"x-api-key" : f"{ config .TEAM_API_KEY } " , "Content-Type" : "application/json" }
345- always_on = False
346- is_async = False # Hard-coded to False for first release
371+
347372 payload = {
348- "name" : name ,
349- "hostingMachine" : hosting_machine ,
350- "alwaysOn" : always_on ,
351- "version" : version ,
352- "description" : description ,
353- "function" : function_id ,
354- "isAsync" : is_async ,
355- "sourceLanguage" : source_language ,
373+ "model" : {
374+ "name" : name ,
375+ "description" : description ,
376+ "connectionType" : [
377+ "synchronous"
378+ ],
379+ "function" : function_id ,
380+ "modalities" : [
381+ f"{ input_modality } -{ output_modality } "
382+ ],
383+ "documentationUrl" : documentation_url ,
384+ "sourceLanguage" : source_language
385+ },
386+ "source" : "aixplain-ecr" ,
387+ "onboardingParams" : {
388+ }
356389 }
357- payload = json .dumps (payload )
358390 logging .debug (f"Body: { str (payload )} " )
359- response = _request_with_retry ("post" , create_url , headers = headers , data = payload )
391+ response = _request_with_retry ("post" , create_url , headers = headers , json = payload )
392+
393+ assert response .status_code == 201
394+
360395 return response .json ()
361396
362397 @classmethod
@@ -370,23 +405,26 @@ def asset_repo_login(cls, api_key: Optional[Text] = None) -> Dict:
370405 Returns:
371406 Dict: Backend response
372407 """
373- login_url = urljoin (config .BACKEND_URL , f "sdk/ecr/login" )
408+ login_url = urljoin (config .BACKEND_URL , "sdk/ecr/login" )
374409 logging .debug (f"URL: { login_url } " )
375410 if api_key :
376- headers = {"x-api-key " : f"{ api_key } " , "Content-Type" : "application/json" }
411+ headers = {"Authorization " : f"Token { api_key } " , "Content-Type" : "application/json" }
377412 else :
378- headers = {"x-api-key " : f"{ config .TEAM_API_KEY } " , "Content-Type" : "application/json" }
413+ headers = {"Authorization " : f"Token { config .TEAM_API_KEY } " , "Content-Type" : "application/json" }
379414 response = _request_with_retry ("post" , login_url , headers = headers )
415+ print (f"Response: { response } " )
380416 response_dict = json .loads (response .text )
381417 return response_dict
382418
383419 @classmethod
384- def onboard_model (cls , model_id : Text , image_tag : Text , image_hash : Text , api_key : Optional [Text ] = None ) -> Dict :
420+ def onboard_model (cls , model_id : Text , image_tag : Text , image_hash : Text , host_machine : Optional [ Text ] = "" , api_key : Optional [Text ] = None ) -> Dict :
385421 """Onboard a model after its image has been pushed to ECR.
386422
387423 Args:
388424 model_id (Text): Model ID obtained from CREATE_ASSET_REPO.
389425 image_tag (Text): Image tag to be onboarded.
426+ image_hash (Text): Image digest.
427+ host_machine (Text, optional): Machine on which to host model.
390428 api_key (Text, optional): Team API key. Defaults to None.
391429 Returns:
392430 Dict: Backend response
@@ -397,18 +435,18 @@ def onboard_model(cls, model_id: Text, image_tag: Text, image_hash: Text, api_ke
397435 headers = {"x-api-key" : f"{ api_key } " , "Content-Type" : "application/json" }
398436 else :
399437 headers = {"x-api-key" : f"{ config .TEAM_API_KEY } " , "Content-Type" : "application/json" }
400- payload = {"image" : image_tag , "sha" : image_hash }
401- payload = json .dumps (payload )
438+ payload = {"image" : image_tag , "sha" : image_hash , "hostMachine" : host_machine }
402439 logging .debug (f"Body: { str (payload )} " )
403- response = _request_with_retry ("post" , onboard_url , headers = headers , data = payload )
404- message = "Your onboarding request has been submitted to an aiXplain specialist for finalization. We will notify you when the process is completed."
405- logging .info (message )
440+ response = _request_with_retry ("post" , onboard_url , headers = headers , json = payload )
441+ if response .status_code == 201 :
442+ message = "Your onboarding request has been submitted to an aiXplain specialist for finalization. We will notify you when the process is completed."
443+ logging .info (message )
444+ else :
445+ message = "An error has occurred. Please make sure your model_id is valid and your host_machine, if set, is a valid option from the LIST_GPUS function."
406446 return response
407447
408448 @classmethod
409- def deploy_huggingface_model (
410- cls , name : Text , hf_repo_id : Text , hf_token : Optional [Text ] = "" , api_key : Optional [Text ] = None
411- ) -> Dict :
449+ def deploy_huggingface_model (cls , name : Text , hf_repo_id : Text , revision : Optional [Text ] = "" , hf_token : Optional [Text ] = "" , api_key : Optional [Text ] = None ) -> Dict :
412450 """Onboards and deploys a Hugging Face large language model.
413451
414452 Args:
@@ -420,7 +458,7 @@ def deploy_huggingface_model(
420458 Dict: Backend response
421459 """
422460 supplier , model_name = hf_repo_id .split ("/" )
423- deploy_url = urljoin (config .BACKEND_URL , f "sdk/model-onboarding/onboard" )
461+ deploy_url = urljoin (config .BACKEND_URL , "sdk/model-onboarding/onboard" )
424462 if api_key :
425463 headers = {"Authorization" : f"Token { api_key } " , "Content-Type" : "application/json" }
426464 else :
@@ -435,7 +473,12 @@ def deploy_huggingface_model(
435473 "sourceLanguage" : "en" ,
436474 },
437475 "source" : "huggingface" ,
438- "onboardingParams" : {"hf_model_name" : model_name , "hf_supplier" : supplier , "hf_token" : hf_token },
476+ "onboardingParams" : {
477+ "hf_supplier" : supplier ,
478+ "hf_model_name" : model_name ,
479+ "hf_token" : hf_token ,
480+ "revision" : revision
481+ }
439482 }
440483 response = _request_with_retry ("post" , deploy_url , headers = headers , json = body )
441484 logging .debug (response .text )
0 commit comments