Skip to content

Commit 6bbf8e9

Browse files
Parameters as None by default (#299)
1 parent e7eff8b commit 6bbf8e9

File tree

3 files changed

+29
-22
lines changed

3 files changed

+29
-22
lines changed

aixplain/modules/model/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ def run(
188188
data: Union[Text, Dict],
189189
name: Text = "model_process",
190190
timeout: float = 300,
191-
parameters: Optional[Dict] = {},
191+
parameters: Optional[Dict] = None,
192192
wait_time: float = 0.5,
193193
) -> Dict:
194194
"""Runs a model call.
@@ -197,7 +197,7 @@ def run(
197197
data (Union[Text, Dict]): link to the input data
198198
name (Text, optional): ID given to a call. Defaults to "model_process".
199199
timeout (float, optional): total polling time. Defaults to 300.
200-
parameters (Dict, optional): optional parameters to the model. Defaults to "{}".
200+
parameters (Dict, optional): optional parameters to the model. Defaults to None.
201201
wait_time (float, optional): wait time in seconds between polling calls. Defaults to 0.5.
202202
203203
Returns:
@@ -220,13 +220,13 @@ def run(
220220
response = {"status": "FAILED", "error": msg, "elapsed_time": end - start}
221221
return response
222222

223-
def run_async(self, data: Union[Text, Dict], name: Text = "model_process", parameters: Optional[Dict] = {}) -> Dict:
223+
def run_async(self, data: Union[Text, Dict], name: Text = "model_process", parameters: Optional[Dict] = None) -> Dict:
224224
"""Runs asynchronously a model call.
225225
226226
Args:
227227
data (Union[Text, Dict]): link to the input data
228228
name (Text, optional): ID given to a call. Defaults to "model_process".
229-
parameters (Dict, optional): optional parameters to the model. Defaults to "{}".
229+
parameters (Dict, optional): optional parameters to the model. Defaults to None.
230230
231231
Returns:
232232
dict: polling URL in response

aixplain/modules/model/llm_model.py

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def run(
102102
top_p: float = 1.0,
103103
name: Text = "model_process",
104104
timeout: float = 300,
105-
parameters: Optional[Dict] = {},
105+
parameters: Optional[Dict] = None,
106106
wait_time: float = 0.5,
107107
) -> Dict:
108108
"""Synchronously running a Large Language Model (LLM) model.
@@ -117,21 +117,23 @@ def run(
117117
top_p (float, optional): Top P. Defaults to 1.0.
118118
name (Text, optional): ID given to a call. Defaults to "model_process".
119119
timeout (float, optional): total polling time. Defaults to 300.
120-
parameters (Dict, optional): optional parameters to the model. Defaults to "{}".
120+
parameters (Dict, optional): optional parameters to the model. Defaults to None.
121121
wait_time (float, optional): wait time in seconds between polling calls. Defaults to 0.5.
122122
123123
Returns:
124124
Dict: parsed output from model
125125
"""
126126
start = time.time()
127+
if parameters is None:
128+
parameters = {}
127129
parameters.update(
128130
{
129-
"context": parameters["context"] if "context" in parameters else context,
130-
"prompt": parameters["prompt"] if "prompt" in parameters else prompt,
131-
"history": parameters["history"] if "history" in parameters else history,
132-
"temperature": parameters["temperature"] if "temperature" in parameters else temperature,
133-
"max_tokens": parameters["max_tokens"] if "max_tokens" in parameters else max_tokens,
134-
"top_p": parameters["top_p"] if "top_p" in parameters else top_p,
131+
"context": parameters.get("context", context),
132+
"prompt": parameters.get("prompt", prompt),
133+
"history": parameters.get("history", history),
134+
"temperature": parameters.get("temperature", temperature),
135+
"max_tokens": parameters.get("max_tokens", max_tokens),
136+
"top_p": parameters.get("top_p", top_p),
135137
}
136138
)
137139
payload = build_payload(data=data, parameters=parameters)
@@ -160,7 +162,7 @@ def run_async(
160162
max_tokens: int = 128,
161163
top_p: float = 1.0,
162164
name: Text = "model_process",
163-
parameters: Optional[Dict] = {},
165+
parameters: Optional[Dict] = None,
164166
) -> Dict:
165167
"""Runs asynchronously a model call.
166168
@@ -173,21 +175,23 @@ def run_async(
173175
max_tokens (int, optional): Maximum Generation Tokens. Defaults to 128.
174176
top_p (float, optional): Top P. Defaults to 1.0.
175177
name (Text, optional): ID given to a call. Defaults to "model_process".
176-
parameters (Dict, optional): optional parameters to the model. Defaults to "{}".
178+
parameters (Dict, optional): optional parameters to the model. Defaults to None.
177179
178180
Returns:
179181
dict: polling URL in response
180182
"""
181183
url = f"{self.url}/{self.id}"
182184
logging.debug(f"Model Run Async: Start service for {name} - {url}")
185+
if parameters is None:
186+
parameters = {}
183187
parameters.update(
184188
{
185-
"context": parameters["context"] if "context" in parameters else context,
186-
"prompt": parameters["prompt"] if "prompt" in parameters else prompt,
187-
"history": parameters["history"] if "history" in parameters else history,
188-
"temperature": parameters["temperature"] if "temperature" in parameters else temperature,
189-
"max_tokens": parameters["max_tokens"] if "max_tokens" in parameters else max_tokens,
190-
"top_p": parameters["top_p"] if "top_p" in parameters else top_p,
189+
"context": parameters.get("context", context),
190+
"prompt": parameters.get("prompt", prompt),
191+
"history": parameters.get("history", history),
192+
"temperature": parameters.get("temperature", temperature),
193+
"max_tokens": parameters.get("max_tokens", max_tokens),
194+
"top_p": parameters.get("top_p", top_p),
191195
}
192196
)
193197
payload = build_payload(data=data, parameters=parameters)

aixplain/modules/model/utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,15 @@
33
import json
44
import logging
55
from aixplain.utils.file_utils import _request_with_retry
6-
from typing import Dict, Text, Union
6+
from typing import Dict, Text, Union, Optional
77

88

9-
def build_payload(data: Union[Text, Dict], parameters: Dict = {}):
9+
def build_payload(data: Union[Text, Dict], parameters: Optional[Dict] = None):
1010
from aixplain.factories import FileFactory
1111

12+
if parameters is None:
13+
parameters = {}
14+
1215
data = FileFactory.to_link(data)
1316
if isinstance(data, dict):
1417
payload = data

0 commit comments

Comments
 (0)