2020Description:
2121 Model Factory Class
2222"""
23- from typing import Dict , List , Optional , Text , Tuple , Union
23+ from typing import Callable , Dict , List , Optional , Text , Tuple , Union
2424import json
2525import logging
2626from aixplain .modules .model import Model
27- from aixplain .modules .model .llm_model import LLM
27+ from aixplain .modules .model .utility_model import UtilityModel , UtilityModelInput
2828from aixplain .enums import Function , Language , OwnershipType , Supplier , SortBy , SortOrder
2929from aixplain .utils import config
3030from aixplain .utils .file_utils import _request_with_retry
3131from urllib .parse import urljoin
32- from warnings import warn
33- from aixplain .enums .function import FunctionInputOutput
34- from datetime import datetime
3532
3633
3734class ModelFactory :
@@ -44,53 +41,58 @@ class ModelFactory:
4441 backend_url = config .BACKEND_URL
4542
4643 @classmethod
47- def _create_model_from_response (cls , response : Dict ) -> Model :
48- """Converts response Json to 'Model' object
44+ def create_utility_model (
45+ cls ,
46+ name : Text ,
47+ code : Union [Text , Callable ],
48+ inputs : List [UtilityModelInput ] = [],
49+ description : Optional [Text ] = None ,
50+ output_examples : Text = "" ,
51+ ) -> UtilityModel :
52+ """Create a utility model
4953
5054 Args:
51- response (Dict): Json from API
55+ name (Text): name of the model
56+ code (Union[Text, Callable]): code of the model
57+ description (Text, optional): description of the model
58+ inputs (List[UtilityModelInput], optional): inputs of the model
59+ output_examples (Text, optional): output examples
5260
5361 Returns:
54- Model: Coverted 'Model' object
62+ UtilityModel: created utility model
5563 """
56- if "api_key" not in response :
57- response ["api_key" ] = config .TEAM_API_KEY
58-
59- parameters = {}
60- if "params" in response :
61- for param in response ["params" ]:
62- if "language" in param ["name" ]:
63- parameters [param ["name" ]] = [w ["value" ] for w in param ["values" ]]
64-
65- function = Function (response ["function" ]["id" ])
66- ModelClass = Model
67- if function == Function .TEXT_GENERATION :
68- ModelClass = LLM
69-
70- created_at = None
71- if "createdAt" in response and response ["createdAt" ]:
72- created_at = datetime .fromisoformat (response ["createdAt" ].replace ("Z" , "+00:00" ))
73- function_id = response ["function" ]["id" ]
74- function = Function (function_id )
75- function_io = FunctionInputOutput .get (function_id , None )
76- input_params = {param ["code" ]: param for param in function_io ["spec" ]["params" ]}
77- output_params = {param ["code" ]: param for param in function_io ["spec" ]["output" ]}
78-
79- return ModelClass (
80- response ["id" ],
81- response ["name" ],
82- description = response .get ("description" , "" ),
83- supplier = response ["supplier" ],
84- api_key = response ["api_key" ],
85- cost = response ["pricing" ],
86- function = function ,
87- created_at = created_at ,
88- parameters = parameters ,
89- input_params = input_params ,
90- output_params = output_params ,
91- is_subscribed = True if "subscription" in response else False ,
92- version = response ["version" ]["id" ],
64+ utility_model = UtilityModel (
65+ id = "" ,
66+ name = name ,
67+ description = description ,
68+ inputs = inputs ,
69+ code = code ,
70+ function = Function .UTILITIES ,
71+ api_key = config .TEAM_API_KEY ,
72+ output_examples = output_examples ,
9373 )
74+ utility_model .validate ()
75+ payload = utility_model .to_dict ()
76+ url = urljoin (cls .backend_url , "sdk/utilities" )
77+ headers = {"x-api-key" : f"{ config .TEAM_API_KEY } " , "Content-Type" : "application/json" }
78+ try :
79+ logging .info (f"Start service for POST Utility Model - { url } - { headers } - { payload } " )
80+ r = _request_with_retry ("post" , url , headers = headers , json = payload )
81+ resp = r .json ()
82+ except Exception as e :
83+ logging .error (f"Error creating utility model: { e } " )
84+ raise e
85+
86+ if 200 <= r .status_code < 300 :
87+ utility_model .id = resp ["id" ]
88+ logging .info (f"Utility Model Creation: Model { utility_model .id } instantiated." )
89+ return utility_model
90+ else :
91+ error_message = (
92+ f"Utility Model Creation: Failed to create utility model. Status Code: { r .status_code } . Error: { resp } "
93+ )
94+ logging .error (error_message )
95+ raise Exception (error_message )
9496
9597 @classmethod
9698 def get (cls , model_id : Text , api_key : Optional [Text ] = None ) -> Model :
@@ -125,95 +127,16 @@ def get(cls, model_id: Text, api_key: Optional[Text] = None) -> Model:
125127 resp ["api_key" ] = config .TEAM_API_KEY
126128 if api_key is not None :
127129 resp ["api_key" ] = api_key
128- model = cls ._create_model_from_response (resp )
130+ from aixplain .factories .model_factory .utils import create_model_from_response
131+
132+ model = create_model_from_response (resp )
129133 logging .info (f"Model Creation: Model { model_id } instantiated." )
130134 return model
131135 else :
132136 error_message = f"Model GET Error: Failed to retrieve model { model_id } . Status Code: { r .status_code } . Error: { resp } "
133137 logging .error (error_message )
134138 raise Exception (error_message )
135139
136- @classmethod
137- def create_asset_from_id (cls , model_id : Text ) -> Model :
138- warn (
139- 'This method will be deprecated in the next versions of the SDK. Use "get" instead.' ,
140- DeprecationWarning ,
141- stacklevel = 2 ,
142- )
143- return cls .get (model_id )
144-
145- @classmethod
146- def _get_assets_from_page (
147- cls ,
148- query ,
149- page_number : int ,
150- page_size : int ,
151- function : Function ,
152- suppliers : Union [Supplier , List [Supplier ]],
153- source_languages : Union [Language , List [Language ]],
154- target_languages : Union [Language , List [Language ]],
155- is_finetunable : bool = None ,
156- ownership : Optional [Tuple [OwnershipType , List [OwnershipType ]]] = None ,
157- sort_by : Optional [SortBy ] = None ,
158- sort_order : SortOrder = SortOrder .ASCENDING ,
159- ) -> List [Model ]:
160- try :
161- url = urljoin (cls .backend_url , "sdk/models/paginate" )
162- filter_params = {"q" : query , "pageNumber" : page_number , "pageSize" : page_size }
163- if is_finetunable is not None :
164- filter_params ["isFineTunable" ] = is_finetunable
165- if function is not None :
166- filter_params ["functions" ] = [function .value ]
167- if suppliers is not None :
168- if isinstance (suppliers , Supplier ) is True :
169- suppliers = [suppliers ]
170- filter_params ["suppliers" ] = [supplier .value ["id" ] for supplier in suppliers ]
171- if ownership is not None :
172- if isinstance (ownership , OwnershipType ) is True :
173- ownership = [ownership ]
174- filter_params ["ownership" ] = [ownership_ .value for ownership_ in ownership ]
175-
176- lang_filter_params = []
177- if source_languages is not None :
178- if isinstance (source_languages , Language ):
179- source_languages = [source_languages ]
180- if function == Function .TRANSLATION :
181- lang_filter_params .append ({"code" : "sourcelanguage" , "value" : source_languages [0 ].value ["language" ]})
182- else :
183- lang_filter_params .append ({"code" : "language" , "value" : source_languages [0 ].value ["language" ]})
184- if source_languages [0 ].value ["dialect" ] != "" :
185- lang_filter_params .append ({"code" : "dialect" , "value" : source_languages [0 ].value ["dialect" ]})
186- if target_languages is not None :
187- if isinstance (target_languages , Language ):
188- target_languages = [target_languages ]
189- if function == Function .TRANSLATION :
190- code = "targetlanguage"
191- lang_filter_params .append ({"code" : code , "value" : target_languages [0 ].value ["language" ]})
192- if sort_by is not None :
193- filter_params ["sort" ] = [{"dir" : sort_order .value , "field" : sort_by .value }]
194- if len (lang_filter_params ) != 0 :
195- filter_params ["ioFilter" ] = lang_filter_params
196-
197- headers = {"Authorization" : f"Token { config .TEAM_API_KEY } " , "Content-Type" : "application/json" }
198-
199- logging .info (f"Start service for POST Models Paginate - { url } - { headers } - { json .dumps (filter_params )} " )
200- r = _request_with_retry ("post" , url , headers = headers , json = filter_params )
201- resp = r .json ()
202-
203- except Exception as e :
204- error_message = f"Listing Models: Error in getting Models on Page { page_number } : { e } "
205- logging .error (error_message , exc_info = True )
206- return []
207- if 200 <= r .status_code < 300 :
208- logging .info (f"Listing Models: Status of getting Models on Page { page_number } : { r .status_code } " )
209- all_models = resp ["items" ]
210- model_list = [cls ._create_model_from_response (model_info_json ) for model_info_json in all_models ]
211- return model_list , resp ["total" ]
212- else :
213- error_message = f"Listing Models Error: Failed to retrieve models. Status Code: { r .status_code } . Error: { resp } "
214- logging .error (error_message )
215- raise Exception (error_message )
216-
217140 @classmethod
218141 def list (
219142 cls ,
@@ -244,7 +167,9 @@ def list(
244167 Returns:
245168 List[Model]: List of models based on given filters
246169 """
247- models , total = cls ._get_assets_from_page (
170+ from aixplain .factories .model_factory .utils import get_assets_from_page
171+
172+ models , total = get_assets_from_page (
248173 query ,
249174 page_number ,
250175 page_size ,
0 commit comments