Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion state-manager/app/config/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ class Settings(BaseModel):
state_manager_secret: str = Field(..., description="Secret key for API authentication")
secrets_encryption_key: str = Field(..., description="Key for encrypting secrets")
trigger_workers: int = Field(default=1, description="Number of workers to run the trigger cron")
node_timeout_minutes: int = Field(default=30, gt=0, description="Timeout in minutes for nodes stuck in QUEUED status")
Comment thread
agam1092005 marked this conversation as resolved.
Outdated

@classmethod
def from_env(cls) -> "Settings":
Expand All @@ -21,7 +22,8 @@ def from_env(cls) -> "Settings":
mongo_database_name=os.getenv("MONGO_DATABASE_NAME", "exosphere-state-manager"), # type: ignore
state_manager_secret=os.getenv("STATE_MANAGER_SECRET"), # type: ignore
secrets_encryption_key=os.getenv("SECRETS_ENCRYPTION_KEY"), # type: ignore
trigger_workers=int(os.getenv("TRIGGER_WORKERS", 1)) # type: ignore
trigger_workers=os.getenv("TRIGGER_WORKERS", "1"), # type: ignore
node_timeout_minutes=os.getenv("NODE_TIMEOUT_MINUTES", "30") # type: ignore
)


Expand Down
8 changes: 6 additions & 2 deletions state-manager/app/controller/enqueue_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,21 @@


async def find_state(namespace_name: str, nodes: list[str]) -> State | None:
current_time_ms = int(time.time() * 1000)
data = await State.get_pymongo_collection().find_one_and_update(
{
"namespace_name": namespace_name,
"status": StateStatusEnum.CREATED,
"node_name": {
"$in": nodes
},
"enqueue_after": {"$lte": int(time.time() * 1000)}
"enqueue_after": {"$lte": current_time_ms}
},
{
"$set": {"status": StateStatusEnum.QUEUED}
"$set": {
"status": StateStatusEnum.QUEUED,
"queued_at": current_time_ms
}
},
return_document=ReturnDocument.AFTER
)
Comment on lines 20 to 52
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The query looks correct; however, how will we re-pick timed-out states?

Expand Down
9 changes: 6 additions & 3 deletions state-manager/app/controller/register_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ async def register_nodes(namespace_name: str, body: RegisterNodesRequestModel, x
RegisteredNode.runtime_namespace: namespace_name,
RegisteredNode.inputs_schema: node_data.inputs_schema, # type: ignore
RegisteredNode.outputs_schema: node_data.outputs_schema, # type: ignore
RegisteredNode.secrets: node_data.secrets # type: ignore
RegisteredNode.secrets: node_data.secrets, # type: ignore
RegisteredNode.timeout_minutes: node_data.timeout_minutes # type: ignore
}))
logger.info(f"Updated existing node {node_data.name} in namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id)

Expand All @@ -44,7 +45,8 @@ async def register_nodes(namespace_name: str, body: RegisterNodesRequestModel, x
runtime_namespace=namespace_name,
inputs_schema=node_data.inputs_schema,
outputs_schema=node_data.outputs_schema,
secrets=node_data.secrets
secrets=node_data.secrets,
timeout_minutes=node_data.timeout_minutes
)
await new_node.insert()
logger.info(f"Created new node {node_data.name} in namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id)
Expand All @@ -54,7 +56,8 @@ async def register_nodes(namespace_name: str, body: RegisterNodesRequestModel, x
name=node_data.name,
inputs_schema=node_data.inputs_schema,
outputs_schema=node_data.outputs_schema,
secrets=node_data.secrets
secrets=node_data.secrets,
timeout_minutes=node_data.timeout_minutes
)
)

Expand Down
15 changes: 14 additions & 1 deletion state-manager/app/controller/trigger_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
from app.models.db.store import Store
from app.models.db.run import Run
from app.models.db.graph_template_model import GraphTemplate
from app.models.db.registered_node import RegisteredNode
from app.models.node_template_model import NodeTemplate
from app.models.dependent_string import DependentString
from app.config.settings import get_settings

import uuid
import time
Expand Down Expand Up @@ -91,6 +93,16 @@ async def trigger_graph(namespace_name: str, graph_name: str, body: TriggerGraph
if len(new_stores) > 0:
await Store.insert_many(new_stores)

# Get node timeout setting
registered_node = await RegisteredNode.get_by_name_and_namespace(root.node_name, root.namespace)
timeout_minutes = None
if registered_node and registered_node.timeout_minutes:
timeout_minutes = registered_node.timeout_minutes
else:
# Fall back to global setting
settings = get_settings()
timeout_minutes = settings.node_timeout_minutes

new_state = State(
node_name=root.node_name,
namespace_name=namespace_name,
Expand All @@ -101,7 +113,8 @@ async def trigger_graph(namespace_name: str, graph_name: str, body: TriggerGraph
enqueue_after=int(time.time() * 1000) + body.start_delay,
inputs=inputs,
outputs={},
error=None
error=None,
timeout_minutes=timeout_minutes
)
await new_state.insert()

Expand Down
10 changes: 10 additions & 0 deletions state-manager/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.triggers.cron import CronTrigger
from .tasks.trigger_cron import trigger_cron
from .tasks.check_node_timeout import check_node_timeout

# Define models list
DOCUMENT_MODELS = [State, GraphTemplate, RegisteredNode, Store, Run, DatabaseTriggers]
Expand Down Expand Up @@ -76,6 +77,15 @@ async def lifespan(app: FastAPI):
max_instances=1,
id="every_minute_task"
)
scheduler.add_job(
check_node_timeout,
CronTrigger.from_crontab("* * * * *"),
replace_existing=True,
misfire_grace_time=60,
coalesce=True,
max_instances=1,
id="check_node_timeout_task"
)
Comment on lines +87 to +95
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not needed if we handle this with DB queries.

scheduler.start()

# main logic of the server
Expand Down
3 changes: 2 additions & 1 deletion state-manager/app/models/db/registered_node.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .base import BaseDatabaseModel
from pydantic import Field
from typing import Any
from typing import Any, Optional
from pymongo import IndexModel
from ..node_template_model import NodeTemplate

Expand All @@ -13,6 +13,7 @@ class RegisteredNode(BaseDatabaseModel):
inputs_schema: dict[str, Any] = Field(..., description="JSON schema for node inputs")
outputs_schema: dict[str, Any] = Field(..., description="JSON schema for node outputs")
secrets: list[str] = Field(default_factory=list, description="List of secrets that the node uses")
timeout_minutes: Optional[int] = Field(None, gt=0, description="Timeout in minutes for this node. Falls back to global setting if not provided")

class Settings:
indexes = [
Expand Down
10 changes: 10 additions & 0 deletions state-manager/app/models/db/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ class State(BaseDatabaseModel):
retry_count: int = Field(default=0, description="Number of times the state has been retried")
fanout_id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Fanout ID of the state")
manual_retry_fanout_id: str = Field(default="", description="Fanout ID from a manual retry request, ensuring unique retries for unite nodes.")
queued_at: Optional[int] = Field(None, description="Unix time in milliseconds when state was queued")
timeout_minutes: Optional[int] = Field(None, gt=0, description="Timeout in minutes for this specific state, taken from node registration")
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A better model here is to store `timeout_at`: pre-compute when the timeout would occur and store that.

queued_at: Optional[int] = Field(None, description="Unix time in milliseconds when the state was queued")
Comment thread
agam1092005 marked this conversation as resolved.
Outdated

@before_event([Insert, Replace, Save])
def _generate_fingerprint(self):
Expand Down Expand Up @@ -102,5 +105,12 @@ class Settings:
("status", 1),
],
name="run_id_status_index"
),
IndexModel(
[
("status", 1),
("queued_at", 1),
],
name="timeout_query_index"
)
Comment thread
agam1092005 marked this conversation as resolved.
]
3 changes: 2 additions & 1 deletion state-manager/app/models/register_nodes_request.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from pydantic import BaseModel, Field
from typing import Any, List
from typing import Any, List, Optional


class NodeRegistrationModel(BaseModel):
    """Describe a single node submitted for registration.

    Carries the node's name, the JSON schemas of its inputs and outputs,
    the secrets it uses, and an optional per-node timeout in minutes.
    """

    name: str = Field(
        ...,
        description="Unique name of the node",
    )
    inputs_schema: dict[str, Any] = Field(
        ...,
        description="JSON schema for node inputs",
    )
    outputs_schema: dict[str, Any] = Field(
        ...,
        description="JSON schema for node outputs",
    )
    secrets: List[str] = Field(
        ...,
        description="List of secrets that the node uses",
    )
    # Optional; must be strictly positive when supplied (gt=0).
    timeout_minutes: Optional[int] = Field(
        None,
        gt=0,
        description="Timeout in minutes for this node. Falls back to global setting if not provided",
    )


class RegisterNodesRequestModel(BaseModel):
Expand Down
3 changes: 2 additions & 1 deletion state-manager/app/models/register_nodes_response.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from pydantic import BaseModel, Field
from typing import Any, List
from typing import Any, List, Optional


class RegisteredNodeModel(BaseModel):
    """Describe one registered node as returned in a registration response.

    Mirrors the registration fields: name, input/output schemas, secrets,
    and the optional per-node timeout in minutes.
    """

    name: str = Field(
        ...,
        description="Name of the registered node",
    )
    inputs_schema: dict[str, Any] = Field(
        ...,
        description="Inputs for the registered node",
    )
    outputs_schema: dict[str, Any] = Field(
        ...,
        description="Outputs for the registered node",
    )
    secrets: List[str] = Field(
        ...,
        description="List of secrets that the node uses",
    )
    # Optional; must be strictly positive when supplied (gt=0).
    timeout_minutes: Optional[int] = Field(
        None,
        gt=0,
        description="Timeout in minutes for this node. Falls back to global setting if not provided",
    )


class RegisterNodesResponseModel(BaseModel):
Expand Down
1 change: 1 addition & 0 deletions state-manager/app/models/state_status_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ class StateStatusEnum(str, Enum):
# Errored
ERRORED = 'ERRORED'
NEXT_CREATED_ERROR = 'NEXT_CREATED_ERROR'
TIMEDOUT = 'TIMEDOUT'

# Success
SUCCESS = 'SUCCESS'
Expand Down
42 changes: 42 additions & 0 deletions state-manager/app/tasks/check_node_timeout.py
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While this model of periodic jobs will work, it's unnecessary, as we can write a database query to figure out timed-out nodes. We probably do not need to set a timed-out status at all: if the status is QUEUED and current_time > timeout_at, we can infer that the state has timed out.

Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import time
from app.models.db.state import State
from app.models.state_status_enum import StateStatusEnum
from app.singletons.logs_manager import LogsManager
from app.config.settings import get_settings

logger = LogsManager().get_logger()


async def check_node_timeout():
    """Mark QUEUED states that exceeded their timeout as TIMEDOUT.

    Reads every state that is QUEUED and has ``queued_at`` set, computes a
    per-state deadline from ``state.timeout_minutes`` (falling back to the
    global ``node_timeout_minutes`` setting), and flips expired states to
    TIMEDOUT with an explanatory error message.

    Each transition is a conditional, atomic ``update_one`` that matches
    only if the document is still QUEUED with the same ``queued_at`` that
    was read. This prevents a stale in-memory snapshot from overwriting a
    state that a worker moved on between the read and the write.

    Returns:
        None. Any exception is logged and swallowed so that a failure in
        one scheduled run does not propagate to the caller.
    """
    try:
        settings = get_settings()
        current_time_ms = int(time.time() * 1000)

        logger.info(f"Checking for timed out nodes at {current_time_ms}")

        # `!= None` (rather than `is not None`) is deliberate: the comparison
        # is used to build the query filter, which `is` cannot express.
        queued_states = await State.find(
            State.status == StateStatusEnum.QUEUED,
            State.queued_at != None  # noqa: E711
        ).to_list()

        collection = State.get_pymongo_collection()
        timed_out_count = 0

        for state in queued_states:
            # Use state-specific timeout if available, otherwise fall back to global
            timeout_minutes = state.timeout_minutes or settings.node_timeout_minutes
            deadline_ms = state.queued_at + timeout_minutes * 60 * 1000

            if current_time_ms < deadline_ms:
                continue

            # Guard on status and queued_at so that a state which was picked
            # up, finished, or re-queued since the read is left untouched.
            result = await collection.update_one(
                {
                    "_id": state.id,
                    "status": StateStatusEnum.QUEUED,
                    "queued_at": state.queued_at,
                },
                {
                    "$set": {
                        "status": StateStatusEnum.TIMEDOUT,
                        "error": f"Node execution timed out after {timeout_minutes} minutes",
                    }
                },
            )
            timed_out_count += result.modified_count

        if timed_out_count:
            logger.info(f"Marked {timed_out_count} states as TIMEDOUT")

    except Exception:
        # Top-level boundary of a scheduled job: log with traceback and
        # keep the scheduler alive.
        logger.error("Error checking node timeout", exc_info=True)
14 changes: 13 additions & 1 deletion state-manager/app/tasks/create_next_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from app.models.db.store import Store
from app.models.dependent_string import DependentString
from app.models.node_template_model import UnitesStrategyEnum
from app.config.settings import get_settings
from json_schema_to_pydantic import create_model
from pydantic import BaseModel
from typing import Type
Expand Down Expand Up @@ -162,6 +163,16 @@ async def generate_next_state(next_state_input_model: Type[BaseModel], next_stat
current_state.identifier: current_state.id
}

# Get timeout for this node
registered_node = await get_registered_node(next_state_node_template)
timeout_minutes = None
if registered_node.timeout_minutes:
timeout_minutes = registered_node.timeout_minutes
else:
# Fall back to global setting
settings = get_settings()
timeout_minutes = settings.node_timeout_minutes

return State(
node_name=next_state_node_template.node_name,
identifier=next_state_node_template.identifier,
Expand All @@ -173,7 +184,8 @@ async def generate_next_state(next_state_input_model: Type[BaseModel], next_stat
outputs={},
does_unites=next_state_node_template.unites is not None,
run_id=current_state.run_id,
error=None
error=None,
timeout_minutes=timeout_minutes
)

current_states = await State.find(
Expand Down
Loading