Skip to content

Commit 6403cd2

Browse files
authored
feat(max): convert the insights graph to a Max tool (#39893)
1 parent 8b99783 commit 6403cd2

File tree

30 files changed

+873
-634
lines changed

30 files changed

+873
-634
lines changed

ee/hogai/assistant/base.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,7 @@
3232
from posthog.models import Team, User
3333
from posthog.sync import database_sync_to_async
3434

35-
from ee.hogai.graph.base import BaseAssistantNode
36-
from ee.hogai.graph.graph import AssistantCompiledStateGraph
35+
from ee.hogai.graph.base import AssistantCompiledStateGraph, BaseAssistantNode
3736
from ee.hogai.utils.exceptions import GenerationCanceled
3837
from ee.hogai.utils.helpers import (
3938
extract_content_from_ai_message,

ee/hogai/assistant/insights_assistant.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from ee.hogai.assistant.base import BaseAssistant
1919
from ee.hogai.graph import FunnelGeneratorNode, RetentionGeneratorNode, SQLGeneratorNode, TrendsGeneratorNode
2020
from ee.hogai.graph.base import BaseAssistantNode
21-
from ee.hogai.graph.graph import InsightsAssistantGraph
21+
from ee.hogai.graph.insights_graph.graph import InsightsGraph
2222
from ee.hogai.graph.query_executor.nodes import QueryExecutorNode
2323
from ee.hogai.graph.taxonomy.types import TaxonomyNodeName
2424
from ee.hogai.utils.state import GraphValueUpdateTuple, validate_value_update
@@ -56,7 +56,7 @@ def __init__(
5656
conversation,
5757
new_message=new_message,
5858
user=user,
59-
graph=InsightsAssistantGraph(team, user).compile_full_graph(),
59+
graph=InsightsGraph(team, user).compile_full_graph(),
6060
state_type=AssistantState,
6161
partial_state_type=PartialAssistantState,
6262
mode=AssistantMode.INSIGHTS_TOOL,

ee/hogai/eval/ci/conftest.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
# We want the PostHog set_up_evals fixture here
2020
from ee.hogai.eval.conftest import set_up_evals # noqa: F401
2121
from ee.hogai.eval.scorers import PlanAndQueryOutput
22-
from ee.hogai.graph.graph import AssistantGraph, InsightsAssistantGraph
22+
from ee.hogai.graph.graph import AssistantGraph
2323
from ee.hogai.utils.types import AssistantNodeName, AssistantState
2424
from ee.models.assistant import Conversation, CoreMemory
2525

@@ -34,27 +34,17 @@
3434
@pytest.fixture
3535
def call_root_for_insight_generation(demo_org_team_user):
3636
# This graph structure will first get a plan, then generate the SQL query.
37-
38-
insights_subgraph = (
39-
# Insights subgraph without query execution, so we only create the queries
40-
InsightsAssistantGraph(demo_org_team_user[1], demo_org_team_user[2])
41-
.add_query_creation_flow(next_node=AssistantNodeName.END)
42-
.compile()
43-
)
4437
graph = (
4538
AssistantGraph(demo_org_team_user[1], demo_org_team_user[2])
4639
.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
4740
.add_root(
4841
path_map={
49-
"insights": AssistantNodeName.INSIGHTS_SUBGRAPH,
5042
"insights_search": AssistantNodeName.INSIGHTS_SEARCH,
5143
"root": AssistantNodeName.ROOT,
5244
"search_documentation": AssistantNodeName.END,
5345
"end": AssistantNodeName.END,
5446
}
5547
)
56-
.add_node(AssistantNodeName.INSIGHTS_SUBGRAPH, insights_subgraph)
57-
.add_edge(AssistantNodeName.INSIGHTS_SUBGRAPH, AssistantNodeName.END)
5848
.add_insights_search()
5949
# TRICKY: We need to set a checkpointer here because async tests create a new event loop.
6050
.compile(checkpointer=DjangoCheckpointer())

ee/hogai/graph/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from .deep_research.graph import DeepResearchAssistantGraph
22
from .funnels.nodes import FunnelGeneratorNode
3-
from .graph import AssistantGraph, InsightsAssistantGraph
3+
from .graph import AssistantGraph
44
from .inkeep_docs.nodes import InkeepDocsNode
55
from .insights.nodes import InsightSearchNode
6+
from .insights_graph.graph import InsightsGraph
67
from .memory.nodes import MemoryInitializerNode
78
from .query_executor.nodes import QueryExecutorNode
89
from .query_planner.nodes import QueryPlannerNode
@@ -27,7 +28,7 @@
2728
"QueryPlannerNode",
2829
"TrendsGeneratorNode",
2930
"AssistantGraph",
30-
"InsightsAssistantGraph",
31+
"InsightsGraph",
3132
"InsightSearchNode",
3233
"DeepResearchAssistantGraph",
3334
]

ee/hogai/graph/base/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from .graph import AssistantCompiledStateGraph, BaseAssistantGraph, global_checkpointer
2+
from .node import AssistantNode, BaseAssistantNode
3+
4+
__all__ = [
5+
"BaseAssistantNode",
6+
"AssistantNode",
7+
"BaseAssistantGraph",
8+
"AssistantCompiledStateGraph",
9+
"global_checkpointer",
10+
]

ee/hogai/graph/base/graph.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
from collections.abc import Callable, Coroutine
2+
from typing import Any, Generic, Literal, Protocol, runtime_checkable
3+
4+
from langgraph.graph.state import CompiledStateGraph, StateGraph
5+
6+
from posthog.schema import ReasoningMessage
7+
8+
from posthog.models import Team, User
9+
10+
from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer
11+
from ee.hogai.utils.types import AssistantNodeName, StateType
12+
from ee.hogai.utils.types.base import BaseState
13+
from ee.hogai.utils.types.composed import MaxNodeName
14+
15+
# Base checkpointer for all graphs
16+
global_checkpointer = DjangoCheckpointer()
17+
18+
19+
# Type alias for async reasoning message function, takes a state and an optional default message content and returns an optional reasoning message
20+
GetReasoningMessageAfunc = Callable[[BaseState, str | None], Coroutine[Any, Any, ReasoningMessage | None]]
21+
GetReasoningMessageMapType = dict[MaxNodeName, GetReasoningMessageAfunc]
22+
23+
24+
# Protocol to check if a node has a reasoning message function at runtime
25+
@runtime_checkable
26+
class HasReasoningMessage(Protocol):
27+
get_reasoning_message: GetReasoningMessageAfunc
28+
29+
30+
class AssistantCompiledStateGraph(CompiledStateGraph):
31+
"""Wrapper around CompiledStateGraph that preserves reasoning message information.
32+
33+
Note: This uses __dict__ copying as a workaround since CompiledStateGraph
34+
doesn't support standard inheritance. This is brittle and may break with
35+
library updates.
36+
"""
37+
38+
def __init__(
39+
self, compiled_graph: CompiledStateGraph, aget_reasoning_message_by_node_name: GetReasoningMessageMapType
40+
):
41+
# Copy the internal state from the compiled graph without calling super().__init__
42+
# This is a workaround since CompiledStateGraph doesn't support standard inheritance
43+
self.__dict__.update(compiled_graph.__dict__)
44+
self.aget_reasoning_message_by_node_name = aget_reasoning_message_by_node_name
45+
46+
47+
class BaseAssistantGraph(Generic[StateType]):
48+
_team: Team
49+
_user: User
50+
_graph: StateGraph
51+
aget_reasoning_message_by_node_name: GetReasoningMessageMapType
52+
53+
def __init__(self, team: Team, user: User, state_type: type[StateType]):
54+
self._team = team
55+
self._user = user
56+
self._graph = StateGraph(state_type)
57+
self._has_start_node = False
58+
self.aget_reasoning_message_by_node_name = {}
59+
60+
def add_edge(self, from_node: MaxNodeName, to_node: MaxNodeName):
61+
if from_node == AssistantNodeName.START:
62+
self._has_start_node = True
63+
self._graph.add_edge(from_node, to_node)
64+
return self
65+
66+
def add_node(self, node: MaxNodeName, action: Any):
67+
self._graph.add_node(node, action)
68+
if isinstance(action, HasReasoningMessage):
69+
self.aget_reasoning_message_by_node_name[node] = action.get_reasoning_message
70+
return self
71+
72+
def add_subgraph(self, node_name: MaxNodeName, subgraph: AssistantCompiledStateGraph):
73+
self._graph.add_node(node_name, subgraph)
74+
self.aget_reasoning_message_by_node_name.update(subgraph.aget_reasoning_message_by_node_name)
75+
return self
76+
77+
def compile(self, checkpointer: DjangoCheckpointer | None | Literal[False] = None):
78+
if not self._has_start_node:
79+
raise ValueError("Start node not added to the graph")
80+
# TRICKY: We check `is not None` because False has a special meaning of "no checkpointer", which we want to pass on
81+
compiled_graph = self._graph.compile(
82+
checkpointer=checkpointer if checkpointer is not None else global_checkpointer
83+
)
84+
return AssistantCompiledStateGraph(
85+
compiled_graph, aget_reasoning_message_by_node_name=self.aget_reasoning_message_by_node_name
86+
)

ee/hogai/graph/base.py renamed to ee/hogai/graph/base/node.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,7 @@
1111

1212
from posthog.schema import AssistantMessage, AssistantToolCall, HumanMessage, ReasoningMessage
1313

14-
from posthog.models import Team
15-
from posthog.models.user import User
14+
from posthog.models import Team, User
1615
from posthog.sync import database_sync_to_async
1716

1817
from ee.hogai.context import AssistantContextManager

ee/hogai/graph/base/test/__init__.py

Whitespace-only changes.

ee/hogai/graph/test/test_assistant_graph.py renamed to ee/hogai/graph/base/test/test_assistant_graph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from langchain_core.runnables import RunnableLambda
44
from langgraph.checkpoint.memory import InMemorySaver
55

6-
from ee.hogai.graph.graph import BaseAssistantGraph
6+
from ee.hogai.graph.base.graph import BaseAssistantGraph
77
from ee.hogai.utils.types import AssistantNodeName, AssistantState, PartialAssistantState
88

99

ee/hogai/graph/deep_research/graph.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,14 @@
44
from posthog.models.user import User
55

66
from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer
7+
from ee.hogai.graph.base import BaseAssistantGraph
78
from ee.hogai.graph.deep_research.notebook.nodes import DeepResearchNotebookPlanningNode
89
from ee.hogai.graph.deep_research.onboarding.nodes import DeepResearchOnboardingNode
910
from ee.hogai.graph.deep_research.planner.nodes import DeepResearchPlannerNode, DeepResearchPlannerToolsNode
1011
from ee.hogai.graph.deep_research.report.nodes import DeepResearchReportNode
1112
from ee.hogai.graph.deep_research.task_executor.nodes import DeepResearchTaskExecutorNode
1213
from ee.hogai.graph.deep_research.types import DeepResearchNodeName, DeepResearchState
13-
from ee.hogai.graph.graph import BaseAssistantGraph
14+
from ee.hogai.graph.title_generator.nodes import TitleGeneratorNode
1415

1516

1617
class DeepResearchAssistantGraph(BaseAssistantGraph[DeepResearchState]):
@@ -80,6 +81,15 @@ def add_report_node(self, next_node: DeepResearchNodeName = DeepResearchNodeName
8081
builder.add_edge(DeepResearchNodeName.REPORT, next_node)
8182
return self
8283

84+
def add_title_generator(self, end_node: DeepResearchNodeName = DeepResearchNodeName.END):
85+
self._has_start_node = True
86+
87+
title_generator = TitleGeneratorNode(self._team, self._user)
88+
self._graph.add_node(DeepResearchNodeName.TITLE_GENERATOR, title_generator)
89+
self._graph.add_edge(DeepResearchNodeName.START, DeepResearchNodeName.TITLE_GENERATOR)
90+
self._graph.add_edge(DeepResearchNodeName.TITLE_GENERATOR, end_node)
91+
return self
92+
8393
def compile_full_graph(self, checkpointer: DjangoCheckpointer | None = None):
8494
return (
8595
self.add_onboarding_node()

0 commit comments

Comments
 (0)