|
| 1 | +from collections.abc import Callable, Coroutine |
| 2 | +from typing import Any, Generic, Literal, Protocol, runtime_checkable |
| 3 | + |
| 4 | +from langgraph.graph.state import CompiledStateGraph, StateGraph |
| 5 | + |
| 6 | +from posthog.schema import ReasoningMessage |
| 7 | + |
| 8 | +from posthog.models import Team, User |
| 9 | + |
| 10 | +from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer |
| 11 | +from ee.hogai.utils.types import AssistantNodeName, StateType |
| 12 | +from ee.hogai.utils.types.base import BaseState |
| 13 | +from ee.hogai.utils.types.composed import MaxNodeName |
| 14 | + |
| 15 | +# Base checkpointer for all graphs |
| 16 | +global_checkpointer = DjangoCheckpointer() |
| 17 | + |
| 18 | + |
| 19 | +# Type alias for async reasoning message function, takes a state and an optional default message content and returns an optional reasoning message |
| 20 | +GetReasoningMessageAfunc = Callable[[BaseState, str | None], Coroutine[Any, Any, ReasoningMessage | None]] |
| 21 | +GetReasoningMessageMapType = dict[MaxNodeName, GetReasoningMessageAfunc] |
| 22 | + |
| 23 | + |
| 24 | +# Protocol to check if a node has a reasoning message function at runtime |
| 25 | +@runtime_checkable |
| 26 | +class HasReasoningMessage(Protocol): |
| 27 | + get_reasoning_message: GetReasoningMessageAfunc |
| 28 | + |
| 29 | + |
| 30 | +class AssistantCompiledStateGraph(CompiledStateGraph): |
| 31 | + """Wrapper around CompiledStateGraph that preserves reasoning message information. |
| 32 | +
|
| 33 | + Note: This uses __dict__ copying as a workaround since CompiledStateGraph |
| 34 | + doesn't support standard inheritance. This is brittle and may break with |
| 35 | + library updates. |
| 36 | + """ |
| 37 | + |
| 38 | + def __init__( |
| 39 | + self, compiled_graph: CompiledStateGraph, aget_reasoning_message_by_node_name: GetReasoningMessageMapType |
| 40 | + ): |
| 41 | + # Copy the internal state from the compiled graph without calling super().__init__ |
| 42 | + # This is a workaround since CompiledStateGraph doesn't support standard inheritance |
| 43 | + self.__dict__.update(compiled_graph.__dict__) |
| 44 | + self.aget_reasoning_message_by_node_name = aget_reasoning_message_by_node_name |
| 45 | + |
| 46 | + |
| 47 | +class BaseAssistantGraph(Generic[StateType]): |
| 48 | + _team: Team |
| 49 | + _user: User |
| 50 | + _graph: StateGraph |
| 51 | + aget_reasoning_message_by_node_name: GetReasoningMessageMapType |
| 52 | + |
| 53 | + def __init__(self, team: Team, user: User, state_type: type[StateType]): |
| 54 | + self._team = team |
| 55 | + self._user = user |
| 56 | + self._graph = StateGraph(state_type) |
| 57 | + self._has_start_node = False |
| 58 | + self.aget_reasoning_message_by_node_name = {} |
| 59 | + |
| 60 | + def add_edge(self, from_node: MaxNodeName, to_node: MaxNodeName): |
| 61 | + if from_node == AssistantNodeName.START: |
| 62 | + self._has_start_node = True |
| 63 | + self._graph.add_edge(from_node, to_node) |
| 64 | + return self |
| 65 | + |
| 66 | + def add_node(self, node: MaxNodeName, action: Any): |
| 67 | + self._graph.add_node(node, action) |
| 68 | + if isinstance(action, HasReasoningMessage): |
| 69 | + self.aget_reasoning_message_by_node_name[node] = action.get_reasoning_message |
| 70 | + return self |
| 71 | + |
| 72 | + def add_subgraph(self, node_name: MaxNodeName, subgraph: AssistantCompiledStateGraph): |
| 73 | + self._graph.add_node(node_name, subgraph) |
| 74 | + self.aget_reasoning_message_by_node_name.update(subgraph.aget_reasoning_message_by_node_name) |
| 75 | + return self |
| 76 | + |
| 77 | + def compile(self, checkpointer: DjangoCheckpointer | None | Literal[False] = None): |
| 78 | + if not self._has_start_node: |
| 79 | + raise ValueError("Start node not added to the graph") |
| 80 | + # TRICKY: We check `is not None` because False has a special meaning of "no checkpointer", which we want to pass on |
| 81 | + compiled_graph = self._graph.compile( |
| 82 | + checkpointer=checkpointer if checkpointer is not None else global_checkpointer |
| 83 | + ) |
| 84 | + return AssistantCompiledStateGraph( |
| 85 | + compiled_graph, aget_reasoning_message_by_node_name=self.aget_reasoning_message_by_node_name |
| 86 | + ) |
0 commit comments