Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 63 additions & 6 deletions python-sdk/exospherehost/runtime.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,46 @@
import asyncio
import os
import logging
import traceback

from asyncio import Queue, sleep
from typing import List, Dict

from pydantic import BaseModel
from .node.BaseNode import BaseNode
from aiohttp import ClientSession
from logging import getLogger

logger = getLogger(__name__)
logger = logging.getLogger(__name__)

def _setup_default_logging():
    """
    Install a default root-logger configuration when none exists.

    Respects the user's existing logging configuration: the function does
    nothing if the root logger already has at least one handler, or if the
    EXOSPHERE_DISABLE_DEFAULT_LOGGING environment variable is set to any
    non-empty value.

    Otherwise ``logging.basicConfig`` is called with a level read from the
    EXOSPHERE_LOG_LEVEL environment variable (case-insensitive, surrounding
    whitespace ignored, default ``INFO``). A value that does not name an
    integer level constant of the ``logging`` module falls back to ``INFO``
    instead of raising.
    """
    root_logger = logging.getLogger()

    # Don't interfere if user has already configured logging
    if root_logger.handlers:
        return

    # Allow users to disable default logging
    if os.environ.get('EXOSPHERE_DISABLE_DEFAULT_LOGGING'):
        return

    # Get log level from environment or default to INFO
    requested_name = os.environ.get('EXOSPHERE_LOG_LEVEL', 'INFO').strip().upper()
    log_level = getattr(logging, requested_name, None)

    # getattr() can resolve to non-level attributes of the logging module
    # (e.g. the BASIC_FORMAT string); passing such a value to basicConfig
    # raises ValueError, so only genuine integer levels are accepted.
    if isinstance(log_level, bool) or not isinstance(log_level, int):
        log_level = logging.INFO

    # Setup basic configuration with clean formatting
    logging.basicConfig(
        level=log_level,
        format='%(asctime)s | %(levelname)s | %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )

    # Log that we're using default configuration, reporting the level that
    # was actually applied rather than the raw requested name.
    logging.getLogger(__name__).debug(
        "ExosphereHost: Using default logging configuration (level: %s)",
        logging.getLevelName(log_level),
    )

Comment thread
NiveditJain marked this conversation as resolved.

class Runtime:
"""
Expand Down Expand Up @@ -48,6 +80,9 @@ class Runtime:
"""

def __init__(self, namespace: str, name: str, nodes: List[type[BaseNode]], state_manager_uri: str | None = None, key: str | None = None, batch_size: int = 16, workers: int = 4, state_manage_version: str = "v0", poll_interval: int = 1):

_setup_default_logging()

self._name = name
self._namespace = namespace
self._key = key
Expand All @@ -72,8 +107,10 @@ def _set_config_from_env(self):
Set configuration from environment variables if not provided.
"""
if self._state_manager_uri is None:
logger.info("State manager URI not provided, using environment variable EXOSPHERE_STATE_MANAGER_URI")
self._state_manager_uri = os.environ.get("EXOSPHERE_STATE_MANAGER_URI")
if self._key is None:
logger.info("API key not provided, using environment variable EXOSPHERE_API_KEY")
self._key = os.environ.get("EXOSPHERE_API_KEY")
Comment thread
NiveditJain marked this conversation as resolved.
Outdated

def _validate_runtime(self):
Expand Down Expand Up @@ -130,6 +167,7 @@ async def _register(self):
Raises:
RuntimeError: If registration fails.
"""
logger.info(f"Registering nodes: {[f"{self._namespace}/{node.__name__}" for node in self._nodes]}")
async with ClientSession() as session:
endpoint = self._get_register_endpoint()
body = {
Expand All @@ -153,8 +191,10 @@ async def _register(self):
res = await response.json()

if response.status != 200:
logger.error(f"Failed to register nodes: {res}")
raise RuntimeError(f"Failed to register nodes: {res}")

logger.info(f"Registered nodes: {[f"{self._namespace}/{node.__name__}" for node in self._nodes]}")
return res

async def _enqueue_call(self):
Expand All @@ -174,6 +214,7 @@ async def _enqueue_call(self):

if response.status != 200:
logger.error(f"Failed to enqueue states: {res}")
raise RuntimeError(f"Failed to enqueue states: {res}")

return res

Expand All @@ -189,8 +230,10 @@ async def _enqueue(self):
data = await self._enqueue_call()
for state in data.get("states", []):
await self._state_queue.put(state)
logger.info(f"Enqueued states: {len(data.get('states', []))}")
except Exception as e:
logger.error(f"Error enqueuing states: {e}")
raise

Comment thread
NiveditJain marked this conversation as resolved.
Outdated
await sleep(self._poll_interval)

Expand All @@ -212,6 +255,8 @@ async def _notify_executed(self, state_id: str, outputs: List[BaseNode.Outputs])

if response.status != 200:
logger.error(f"Failed to notify executed state {state_id}: {res}")

logger.info(f"Notified executed state {state_id} with outputs: {outputs} for node {self._node_mapping[state_id].__name__}")
Comment thread
NiveditJain marked this conversation as resolved.
Outdated

Comment thread
NiveditJain marked this conversation as resolved.
async def _notify_errored(self, state_id: str, error: str):
"""
Expand All @@ -232,6 +277,8 @@ async def _notify_errored(self, state_id: str, error: str):
if response.status != 200:
logger.error(f"Failed to notify errored state {state_id}: {res}")

logger.info(f"Notified errored state {state_id} with error: {error} for node {self._node_mapping[state_id].__name__}")
Comment thread
NiveditJain marked this conversation as resolved.
Outdated

Comment thread
NiveditJain marked this conversation as resolved.
Outdated
async def _get_secrets(self, state_id: str) -> Dict[str, str]:
"""
Get secrets for a state.
Expand Down Expand Up @@ -306,21 +353,28 @@ def _validate_nodes(self):
if len(errors) > 0:
raise ValueError("Following errors while validating nodes: " + "\n".join(errors))

async def _worker(self):
async def _worker(self, idx: int):
"""
Worker task that processes states from the queue.

Continuously fetches states from the queue, executes the corresponding node,
and notifies the state manager of the result.
"""
logger.info(f"Starting worker thread {idx} for nodes: {[f"{self._namespace}/{node.__name__}" for node in self._nodes]}")

while True:
state = await self._state_queue.get()

try:
node = self._node_mapping[state["node_name"]]
logger.info(f"Executing state {state['state_id']} for node {node.__name__}")

secrets = await self._get_secrets(state["state_id"])
outputs = await node()._execute(node.Inputs(**state["inputs"]), node.Secrets(**secrets["secrets"]))
logger.info(f"Got secrets for state {state['state_id']} for node {node.__name__}")

outputs = await node()._execute(node.Inputs(**state["inputs"]), node.Secrets(**secrets["secrets"])) # type: ignore
logger.info(f"Got outputs for state {state['state_id']} for node {node.__name__}")

Comment thread
NiveditJain marked this conversation as resolved.
Comment thread
NiveditJain marked this conversation as resolved.
if outputs is None:
outputs = []

Expand All @@ -330,6 +384,9 @@ async def _worker(self):
await self._notify_executed(state["state_id"], outputs)

except Exception as e:
logger.error(f"Error executing state {state['state_id']} for node {node.__name__}: {e}")
Comment thread
NiveditJain marked this conversation as resolved.
Outdated
logger.error(traceback.format_exc())

await self._notify_errored(state["state_id"], str(e))

Comment thread
NiveditJain marked this conversation as resolved.
self._state_queue.task_done() # type: ignore
Expand All @@ -346,7 +403,7 @@ async def _start(self):
await self._register()

poller = asyncio.create_task(self._enqueue())
worker_tasks = [asyncio.create_task(self._worker()) for _ in range(self._workers)]
worker_tasks = [asyncio.create_task(self._worker(idx)) for idx in range(self._workers)]

await asyncio.gather(poller, *worker_tasks)

Expand Down
Loading
Loading