Skip to content

Commit bed9312

Browse files
authored
Update pyproject.toml (#659) (#662)
* added get by name method for agents * use cache set to true by default
1 parent efb0e90 commit bed9312

File tree

9 files changed

+100
-42
lines changed

9 files changed

+100
-42
lines changed

aixplain/factories/agent_factory/__init__.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -98,10 +98,10 @@ def create(
9898
from aixplain.utils.llm_utils import get_llm_instance
9999

100100
if llm is None and llm_id is not None:
101-
llm = get_llm_instance(llm_id, api_key=api_key)
101+
llm = get_llm_instance(llm_id, api_key=api_key, use_cache=True)
102102
elif llm is None:
103103
# Use default GPT-4o if no LLM specified
104-
llm = get_llm_instance("669a63646eb56306647e1091", api_key=api_key)
104+
llm = get_llm_instance("669a63646eb56306647e1091", api_key=api_key, use_cache=True)
105105

106106
if output_format == OutputFormat.JSON:
107107
assert expected_output is not None and (
@@ -152,7 +152,7 @@ def create(
152152
}
153153

154154
if llm is not None:
155-
llm = get_llm_instance(llm, api_key=api_key) if isinstance(llm, str) else llm
155+
llm = get_llm_instance(llm, api_key=api_key, use_cache=True) if isinstance(llm, str) else llm
156156
payload["tools"].append(
157157
{
158158
"type": "llm",
@@ -519,11 +519,12 @@ def list(cls) -> Dict:
519519
raise Exception(error_msg)
520520

521521
@classmethod
522-
def get(cls, agent_id: Text, api_key: Optional[Text] = None) -> Agent:
523-
"""Retrieve an agent by its ID.
522+
def get(cls, agent_id: Optional[Text] = None, name: Optional[Text] = None, api_key: Optional[Text] = None) -> Agent:
523+
"""Retrieve an agent by its ID or name.
524524
525525
Args:
526-
agent_id (Text): ID of the agent to retrieve.
526+
agent_id (Optional[Text], optional): ID of the agent to retrieve.
527+
name (Optional[Text], optional): Name of the agent to retrieve.
527528
api_key (Optional[Text], optional): API key for authentication.
528529
Defaults to None, using the configured TEAM_API_KEY.
529530
@@ -532,14 +533,23 @@ def get(cls, agent_id: Text, api_key: Optional[Text] = None) -> Agent:
532533
533534
Raises:
534535
Exception: If the agent cannot be retrieved or doesn't exist.
536+
ValueError: If neither agent_id nor name is provided, or if both are provided.
535537
"""
536538
from aixplain.factories.agent_factory.utils import build_agent
537539

538-
url = urljoin(config.BACKEND_URL, f"sdk/agents/{agent_id}")
540+
# Validate that exactly one parameter is provided
541+
if not (agent_id or name) or (agent_id and name):
542+
raise ValueError("Must provide exactly one of 'agent_id' or 'name'")
543+
544+
# Construct URL based on parameter type
545+
if agent_id:
546+
url = urljoin(config.BACKEND_URL, f"sdk/agents/{agent_id}")
547+
else: # name is provided
548+
url = urljoin(config.BACKEND_URL, f"sdk/agents/by-name/{name}")
539549

540550
api_key = api_key if api_key is not None else config.TEAM_API_KEY
541551
headers = {"x-api-key": api_key, "Content-Type": "application/json"}
542-
logging.info(f"Start service for GET Agent - {url} - {headers}")
552+
logging.info(f"Start service for GET Agent - {url} - {headers}")
543553
r = _request_with_retry("get", url, headers=headers)
544554
resp = r.json()
545555
if 200 <= r.status_code < 300:

aixplain/factories/agent_factory/utils.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def build_llm(payload: Dict, api_key: Text = config.TEAM_API_KEY) -> LLM:
140140
for tool in payload["tools"]:
141141
if tool["type"] == "llm" and tool["description"] == "main":
142142

143-
llm = get_llm_instance(payload["llmId"], api_key=api_key)
143+
llm = get_llm_instance(payload["llmId"], api_key=api_key, use_cache=True)
144144
# Set parameters from the tool
145145
if "parameters" in tool:
146146
# Apply all parameters directly to the LLM properties
@@ -191,18 +191,34 @@ def build_agent(payload: Dict, tools: List[Tool] = None, api_key: Text = config.
191191
payload_tools = tools
192192
if payload_tools is None:
193193
payload_tools = []
194-
for tool in tools_dict:
194+
# Use parallel tool building with ThreadPoolExecutor for better performance
195+
from concurrent.futures import ThreadPoolExecutor, as_completed
196+
197+
def build_tool_safe(tool_data):
198+
"""Build a single tool with error handling"""
195199
try:
196-
payload_tools.append(build_tool(tool))
200+
return build_tool(tool_data)
197201
except (ValueError, AssertionError) as e:
198202
logging.warning(str(e))
199-
continue
203+
return None
200204
except Exception:
201205
logging.warning(
202-
f"Tool {tool['assetId']} is not available. Make sure it exists or you have access to it. "
206+
f"Tool {tool_data['assetId']} is not available. Make sure it exists or you have access to it. "
203207
"If you think this is an error, please contact the administrators."
204208
)
205-
continue
209+
return None
210+
211+
# Build all tools in parallel (only if there are tools to build)
212+
if len(tools_dict) > 0:
213+
with ThreadPoolExecutor(max_workers=min(len(tools_dict), 10)) as executor:
214+
# Submit all tool build tasks
215+
future_to_tool = {executor.submit(build_tool_safe, tool): tool for tool in tools_dict}
216+
217+
# Collect results as they complete
218+
for future in as_completed(future_to_tool):
219+
tool_result = future.result()
220+
if tool_result is not None:
221+
payload_tools.append(tool_result)
206222

207223
llm = build_llm(payload, api_key)
208224

aixplain/factories/team_agent_factory/__init__.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -334,14 +334,15 @@ def list(cls) -> Dict:
334334
raise Exception(error_msg)
335335

336336
@classmethod
337-
def get(cls, agent_id: Text, api_key: Optional[Text] = None) -> TeamAgent:
338-
"""Retrieve a team agent by its ID.
337+
def get(cls, agent_id: Optional[Text] = None, name: Optional[Text] = None, api_key: Optional[Text] = None) -> TeamAgent:
338+
"""Retrieve a team agent by its ID or name.
339339
340340
This method fetches a specific team agent from the platform using its
341-
unique identifier.
341+
unique identifier or name.
342342
343343
Args:
344-
agent_id (Text): Unique identifier of the team agent to retrieve.
344+
agent_id (Optional[Text], optional): Unique identifier of the team agent to retrieve.
345+
name (Optional[Text], optional): Name of the team agent to retrieve.
345346
api_key (Optional[Text], optional): API key for authentication.
346347
Defaults to None, using the configured TEAM_API_KEY.
347348
@@ -350,15 +351,25 @@ def get(cls, agent_id: Text, api_key: Optional[Text] = None) -> TeamAgent:
350351
351352
Raises:
352353
Exception: If:
353-
- Team agent ID is invalid
354+
- Team agent ID/name is invalid
354355
- Authentication fails
355356
- Service is unavailable
356357
- Other API errors occur
358+
ValueError: If neither agent_id nor name is provided, or if both are provided.
357359
"""
358-
url = urljoin(config.BACKEND_URL, f"sdk/agent-communities/{agent_id}")
360+
# Validate that exactly one parameter is provided
361+
if not (agent_id or name) or (agent_id and name):
362+
raise ValueError("Must provide exactly one of 'agent_id' or 'name'")
363+
364+
# Construct URL based on parameter type
365+
if agent_id:
366+
url = urljoin(config.BACKEND_URL, f"sdk/agent-communities/{agent_id}")
367+
else: # name is provided
368+
url = urljoin(config.BACKEND_URL, f"sdk/agent-communities/by-name/{name}")
369+
359370
api_key = api_key if api_key is not None else config.TEAM_API_KEY
360371
headers = {"x-api-key": api_key, "Content-Type": "application/json"}
361-
logging.info(f"Start service for GET Team Agent - {url} - {headers}")
372+
logging.info(f"Start service for GET Team Agent - {url} - {headers}")
362373
try:
363374
r = _request_with_retry("get", url, headers=headers)
364375
resp = r.json()

aixplain/factories/team_agent_factory/utils.py

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -59,15 +59,31 @@ def build_team_agent(payload: Dict, agents: List[Agent] = None, api_key: Text =
5959
payload_agents = agents
6060
if payload_agents is None:
6161
payload_agents = []
62-
for i, agent in enumerate(agents_dict):
62+
# Use parallel agent fetching with ThreadPoolExecutor for better performance
63+
from concurrent.futures import ThreadPoolExecutor, as_completed
64+
65+
def fetch_agent(agent_data):
66+
"""Fetch a single agent by ID with error handling"""
6367
try:
64-
payload_agents.append(AgentFactory.get(agent["assetId"]))
65-
except Exception:
68+
return AgentFactory.get(agent_data["assetId"])
69+
except Exception as e:
6670
logging.warning(
67-
f"Agent {agent['assetId']} not found. Make sure it exists or you have access to it. "
68-
"If you think this is an error, please contact the administrators."
71+
f"Agent {agent_data['assetId']} not found. Make sure it exists or you have access to it. "
72+
"If you think this is an error, please contact the administrators. Error: {e}"
6973
)
70-
continue
74+
return None
75+
76+
# Fetch all agents in parallel (only if there are agents to fetch)
77+
if len(agents_dict) > 0:
78+
with ThreadPoolExecutor(max_workers=min(len(agents_dict), 10)) as executor:
79+
# Submit all agent fetch tasks
80+
future_to_agent = {executor.submit(fetch_agent, agent): agent for agent in agents_dict}
81+
82+
# Collect results as they complete
83+
for future in as_completed(future_to_agent):
84+
agent_result = future.result()
85+
if agent_result is not None:
86+
payload_agents.append(agent_result)
7187

7288
# Ensure custom classes are instantiated: for compatibility with backend return format
7389
inspectors = []
@@ -90,6 +106,15 @@ def build_team_agent(payload: Dict, agents: List[Agent] = None, api_key: Text =
90106
# Get LLMs from tools if present
91107
supervisor_llm = None
92108
mentalist_llm = None
109+
110+
# Cache for models to avoid duplicate fetching of the same model ID
111+
model_cache = {}
112+
113+
def get_cached_model(model_id: str) -> any:
114+
"""Get model from cache or fetch if not cached"""
115+
if model_id not in model_cache:
116+
model_cache[model_id] = ModelFactory.get(model_id, api_key=api_key, use_cache=True)
117+
return model_cache[model_id]
93118

94119
# First check if we have direct LLM objects in the payload
95120
if "supervisor_llm" in payload:
@@ -100,14 +125,8 @@ def build_team_agent(payload: Dict, agents: List[Agent] = None, api_key: Text =
100125
elif "tools" in payload:
101126
for tool in payload["tools"]:
102127
if tool["type"] == "llm":
103-
try:
104-
llm = ModelFactory.get(payload["llmId"], api_key=api_key)
105-
except Exception:
106-
logging.warning(
107-
f"LLM {payload['llmId']} not found. Make sure it exists or you have access to it. "
108-
"If you think this is an error, please contact the administrators."
109-
)
110-
continue
128+
# Use cached model fetching to avoid duplicate API calls
129+
llm = get_cached_model(payload["llmId"])
111130
# Set parameters from the tool
112131
if "parameters" in tool:
113132
# Apply all parameters directly to the LLM properties
@@ -258,7 +277,7 @@ def build_team_agent_from_yaml(yaml_code: str, llm_id: str, api_key: str, team_i
258277
team_name = system_data.get("name", "")
259278
team_description = system_data.get("description", "")
260279
team_instructions = system_data.get("instructions", "")
261-
llm = ModelFactory.get(llm_id)
280+
llm = ModelFactory.get(llm_id, use_cache=True)
262281
# Create agent mapping by name for easier task assignment
263282
agents_mapping = {}
264283
agent_objs = []

aixplain/modules/agent/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ def _validate(self) -> None:
175175
re.match(r"^[a-zA-Z0-9 \-\(\)]*$", self.name) is not None
176176
), "Agent Creation Error: Agent name contains invalid characters. Only alphanumeric characters, spaces, hyphens, and brackets are allowed."
177177

178-
llm = get_llm_instance(self.llm_id, api_key=self.api_key)
178+
llm = get_llm_instance(self.llm_id, api_key=self.api_key, use_cache=True)
179179

180180
assert llm.function == Function.TEXT_GENERATION, "Large Language Model must be a text generation model."
181181

aixplain/modules/agent/tool/model_tool.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ def _get_model(self, model_id: Text = None):
247247
from aixplain.factories.model_factory import ModelFactory
248248

249249
model_id = model_id or self.model
250-
return ModelFactory.get(model_id, api_key=self.api_key)
250+
return ModelFactory.get(model_id, api_key=self.api_key, use_cache=True)
251251

252252
def validate_parameters(self, received_parameters: Optional[List[Dict]] = None) -> Optional[List[Dict]]:
253253
"""Validates and formats the parameters for the tool.

aixplain/modules/team_agent/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -693,7 +693,7 @@ def _validate(self) -> None:
693693
), "Team Agent Creation Error: Team name contains invalid characters. Only alphanumeric characters, spaces, hyphens, and brackets are allowed."
694694

695695
try:
696-
llm = get_llm_instance(self.llm_id)
696+
llm = get_llm_instance(self.llm_id, use_cache=True)
697697
assert llm.function == Function.TEXT_GENERATION, "Large Language Model must be a text generation model."
698698
except Exception:
699699
raise Exception(f"Large Language Model with ID '{self.llm_id}' not found.")

aixplain/utils/llm_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@
66
def get_llm_instance(
77
llm_id: Text,
88
api_key: Optional[Text] = None,
9+
use_cache: bool = True,
910
) -> LLM:
1011
"""Get an LLM instance with specific configuration.
1112
1213
Args:
1314
llm_id (Text): ID of the LLM model to use.
1415
api_key (Optional[Text], optional): API key to use. Defaults to None.
16+
use_cache (bool, optional): Whether to use caching for model retrieval. Defaults to True.
1517
1618
Returns:
1719
LLM: Configured LLM instance.
@@ -20,7 +22,7 @@ def get_llm_instance(
2022
Exception: If the LLM model with the given ID is not found.
2123
"""
2224
try:
23-
llm = ModelFactory.get(llm_id, api_key=api_key)
25+
llm = ModelFactory.get(llm_id, api_key=api_key, use_cache=use_cache)
2426
return llm
2527
except Exception:
2628
raise Exception(f"Large Language Model with ID '{llm_id}' not found.")

aixplain/v2/agent.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,10 @@ def list(cls, **kwargs: Unpack[BareListParams]) -> Page["Agent"]:
7171
return AgentFactory.list(**kwargs)
7272

7373
@classmethod
74-
def get(cls, id: str, **kwargs: Unpack[BareGetParams]) -> "Agent":
74+
def get(cls, id: Optional[str] = None, name: Optional[str] = None, **kwargs: Unpack[BareGetParams]) -> "Agent":
7575
from aixplain.factories import AgentFactory
7676

77-
return AgentFactory.get(agent_id=id)
77+
return AgentFactory.get(agent_id=id, name=name)
7878

7979
@classmethod
8080
def create(cls, *args, **kwargs: Unpack[AgentCreateParams]) -> "Agent":

0 commit comments

Comments
 (0)