Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python-sdk/exospherehost/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
version = "0.0.7b9"
version = "0.0.7b10"
20 changes: 16 additions & 4 deletions python-sdk/exospherehost/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,11 @@ async def _get_secrets(self, state_id: str) -> Dict[str, str]:
logger.error(f"Failed to get secrets for state {state_id}: {res}")
return {}

return res
if "secrets" in res:
return res["secrets"]
else:
logger.error(f"'secrets' not found in response for state {state_id}")
return {}
Comment thread
NiveditJain marked this conversation as resolved.

Comment thread
NiveditJain marked this conversation as resolved.
def _validate_nodes(self):
"""
Expand Down Expand Up @@ -352,6 +356,12 @@ def _validate_nodes(self):
if len(errors) > 0:
raise ValueError("Following errors while validating nodes: " + "\n".join(errors))

def _need_secrets(self, node: type[BaseNode]) -> bool:
"""
Check if the node needs secrets.
"""
return len(node.Secrets.model_fields.keys()) > 0
Comment thread
NiveditJain marked this conversation as resolved.

Comment thread
NiveditJain marked this conversation as resolved.
async def _worker(self, idx: int):
"""
Worker task that processes states from the queue.
Expand All @@ -369,10 +379,12 @@ async def _worker(self, idx: int):
node = self._node_mapping[state["node_name"]]
logger.info(f"Executing state {state['state_id']} for node {node.__name__}")

secrets = await self._get_secrets(state["state_id"])
logger.info(f"Got secrets for state {state['state_id']} for node {node.__name__}")
secrets = {}
if self._need_secrets(node):
secrets = await self._get_secrets(state["state_id"])
logger.info(f"Got secrets for state {state['state_id']} for node {node.__name__}")

outputs = await node()._execute(node.Inputs(**state["inputs"]), node.Secrets(**secrets["secrets"])) # type: ignore
outputs = await node()._execute(node.Inputs(**state["inputs"]), node.Secrets(**secrets))
logger.info(f"Got outputs for state {state['state_id']} for node {node.__name__}")

if outputs is None:
Expand Down
10 changes: 5 additions & 5 deletions python-sdk/tests/test_runtime_comprehensive.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ async def test_worker_successful_execution(self, runtime_config):
with patch('exospherehost.runtime.Runtime._get_secrets') as mock_get_secrets, \
patch('exospherehost.runtime.Runtime._notify_executed') as mock_notify_executed:

mock_get_secrets.return_value = {"secrets": {"api_key": "test_key"}}
mock_get_secrets.return_value = {"api_key": "test_key"}
mock_notify_executed.return_value = None
Comment thread
NiveditJain marked this conversation as resolved.

runtime = Runtime(**runtime_config)
Expand Down Expand Up @@ -327,7 +327,7 @@ async def test_worker_with_list_output(self, runtime_config):
with patch('exospherehost.runtime.Runtime._get_secrets') as mock_get_secrets, \
patch('exospherehost.runtime.Runtime._notify_executed') as mock_notify_executed:

mock_get_secrets.return_value = {"secrets": {"api_key": "test_key"}}
mock_get_secrets.return_value = {"api_key": "test_key"}
mock_notify_executed.return_value = None
Comment thread
NiveditJain marked this conversation as resolved.

runtime = Runtime(**runtime_config)
Expand Down Expand Up @@ -362,7 +362,7 @@ async def test_worker_with_none_output(self, runtime_config):
with patch('exospherehost.runtime.Runtime._get_secrets') as mock_get_secrets, \
patch('exospherehost.runtime.Runtime._notify_executed') as mock_notify_executed:

mock_get_secrets.return_value = {"secrets": {"api_key": "test_key"}}
mock_get_secrets.return_value = {"api_key": "test_key"}
mock_notify_executed.return_value = None
Comment thread
NiveditJain marked this conversation as resolved.

runtime = Runtime(**runtime_config)
Expand Down Expand Up @@ -394,7 +394,7 @@ async def test_worker_execution_error(self, runtime_config):
with patch('exospherehost.runtime.Runtime._get_secrets') as mock_get_secrets, \
patch('exospherehost.runtime.Runtime._notify_errored') as mock_notify_errored:

mock_get_secrets.return_value = {"secrets": {"api_key": "test_key"}}
mock_get_secrets.return_value = {"api_key": "test_key"}
mock_notify_errored.return_value = None

Comment thread
NiveditJain marked this conversation as resolved.
runtime = Runtime(**runtime_config)
Expand Down Expand Up @@ -511,7 +511,7 @@ async def test_get_secrets_success(self, runtime_config):
runtime = Runtime(**runtime_config)
result = await runtime._get_secrets("test_state_1")

assert result == {"secrets": {"api_key": "secret_key"}}
assert result == {"api_key": "secret_key"}

@pytest.mark.asyncio
async def test_get_secrets_failure(self, runtime_config):
Expand Down
Loading