Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions ee/hogai/assistant/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@
from posthog.models import Team, User
from posthog.sync import database_sync_to_async

from ee.hogai.graph.base import BaseAssistantNode
from ee.hogai.graph.graph import AssistantCompiledStateGraph
from ee.hogai.graph.base import AssistantCompiledStateGraph, BaseAssistantNode
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI, AssistantCompiledStateGraph is gone on my loop-executor-ui-full branch, so there will be conflicts (I can fix them once we get there)

from ee.hogai.utils.exceptions import GenerationCanceled
from ee.hogai.utils.helpers import (
extract_content_from_ai_message,
Expand Down
4 changes: 2 additions & 2 deletions ee/hogai/assistant/insights_assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from ee.hogai.assistant.base import BaseAssistant
from ee.hogai.graph import FunnelGeneratorNode, RetentionGeneratorNode, SQLGeneratorNode, TrendsGeneratorNode
from ee.hogai.graph.base import BaseAssistantNode
from ee.hogai.graph.graph import InsightsAssistantGraph
from ee.hogai.graph.insights_graph.graph import InsightsGraph
from ee.hogai.graph.query_executor.nodes import QueryExecutorNode
from ee.hogai.graph.taxonomy.types import TaxonomyNodeName
from ee.hogai.utils.state import GraphValueUpdateTuple, validate_value_update
Expand Down Expand Up @@ -56,7 +56,7 @@ def __init__(
conversation,
new_message=new_message,
user=user,
graph=InsightsAssistantGraph(team, user).compile_full_graph(),
graph=InsightsGraph(team, user).compile_full_graph(),
state_type=AssistantState,
partial_state_type=PartialAssistantState,
mode=AssistantMode.INSIGHTS_TOOL,
Expand Down
12 changes: 1 addition & 11 deletions ee/hogai/eval/ci/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
# We want the PostHog set_up_evals fixture here
from ee.hogai.eval.conftest import set_up_evals # noqa: F401
from ee.hogai.eval.scorers import PlanAndQueryOutput
from ee.hogai.graph.graph import AssistantGraph, InsightsAssistantGraph
from ee.hogai.graph.graph import AssistantGraph
from ee.hogai.utils.types import AssistantNodeName, AssistantState
from ee.models.assistant import Conversation, CoreMemory

Expand All @@ -34,27 +34,17 @@
@pytest.fixture
def call_root_for_insight_generation(demo_org_team_user):
# This graph structure will first get a plan, then generate the SQL query.

insights_subgraph = (
# Insights subgraph without query execution, so we only create the queries
InsightsAssistantGraph(demo_org_team_user[1], demo_org_team_user[2])
.add_query_creation_flow(next_node=AssistantNodeName.END)
.compile()
)
graph = (
AssistantGraph(demo_org_team_user[1], demo_org_team_user[2])
.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
.add_root(
path_map={
"insights": AssistantNodeName.INSIGHTS_SUBGRAPH,
"insights_search": AssistantNodeName.INSIGHTS_SEARCH,
"root": AssistantNodeName.ROOT,
"search_documentation": AssistantNodeName.END,
"end": AssistantNodeName.END,
}
)
.add_node(AssistantNodeName.INSIGHTS_SUBGRAPH, insights_subgraph)
.add_edge(AssistantNodeName.INSIGHTS_SUBGRAPH, AssistantNodeName.END)
.add_insights_search()
# TRICKY: We need to set a checkpointer here because async tests create a new event loop.
.compile(checkpointer=DjangoCheckpointer())
Expand Down
5 changes: 3 additions & 2 deletions ee/hogai/graph/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from .deep_research.graph import DeepResearchAssistantGraph
from .funnels.nodes import FunnelGeneratorNode
from .graph import AssistantGraph, InsightsAssistantGraph
from .graph import AssistantGraph
from .inkeep_docs.nodes import InkeepDocsNode
from .insights.nodes import InsightSearchNode
from .insights_graph.graph import InsightsGraph
from .memory.nodes import MemoryInitializerNode
from .query_executor.nodes import QueryExecutorNode
from .query_planner.nodes import QueryPlannerNode
Expand All @@ -27,7 +28,7 @@
"QueryPlannerNode",
"TrendsGeneratorNode",
"AssistantGraph",
"InsightsAssistantGraph",
"InsightsGraph",
"InsightSearchNode",
"DeepResearchAssistantGraph",
]
10 changes: 10 additions & 0 deletions ee/hogai/graph/base/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from .graph import AssistantCompiledStateGraph, BaseAssistantGraph, global_checkpointer
from .node import AssistantNode, BaseAssistantNode

# Public API of ee.hogai.graph.base — re-exports from .graph and .node.
__all__ = [
    "BaseAssistantNode",
    "AssistantNode",
    "BaseAssistantGraph",
    "AssistantCompiledStateGraph",
    "global_checkpointer",
]
86 changes: 86 additions & 0 deletions ee/hogai/graph/base/graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from collections.abc import Callable, Coroutine
from typing import Any, Generic, Literal, Protocol, runtime_checkable

from langgraph.graph.state import CompiledStateGraph, StateGraph

from posthog.schema import ReasoningMessage

from posthog.models import Team, User

from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer
from ee.hogai.utils.types import AssistantNodeName, StateType
from ee.hogai.utils.types.base import BaseState
from ee.hogai.utils.types.composed import MaxNodeName

# Base checkpointer for all graphs
global_checkpointer = DjangoCheckpointer()


# Type alias for async reasoning message function, takes a state and an optional default message content and returns an optional reasoning message
GetReasoningMessageAfunc = Callable[[BaseState, str | None], Coroutine[Any, Any, ReasoningMessage | None]]
GetReasoningMessageMapType = dict[MaxNodeName, GetReasoningMessageAfunc]


# Protocol to check, at runtime, whether a node exposes a reasoning-message hook.
@runtime_checkable
class HasReasoningMessage(Protocol):
    """Structural type for nodes providing an async reasoning-message callback.

    NOTE(review): isinstance() against a runtime_checkable Protocol only checks
    attribute presence, not the callable's signature.
    """

    get_reasoning_message: GetReasoningMessageAfunc


class AssistantCompiledStateGraph(CompiledStateGraph):
    """Wrapper around CompiledStateGraph that preserves reasoning message information.

    Note: This uses __dict__ copying as a workaround since CompiledStateGraph
    doesn't support standard inheritance. This is brittle and may break with
    library updates.
    """

    def __init__(
        self, compiled_graph: CompiledStateGraph, aget_reasoning_message_by_node_name: GetReasoningMessageMapType
    ):
        # Copy the internal state from the compiled graph without calling super().__init__().
        # This is a workaround since CompiledStateGraph doesn't support standard inheritance.
        # NOTE: the assignment below must stay AFTER the update() call, or a same-named key
        # in compiled_graph.__dict__ could clobber our mapping.
        self.__dict__.update(compiled_graph.__dict__)
        self.aget_reasoning_message_by_node_name = aget_reasoning_message_by_node_name


class BaseAssistantGraph(Generic[StateType]):
    """Builder around LangGraph's StateGraph that tracks per-node reasoning-message hooks.

    Nodes (and mounted subgraphs) that expose a `get_reasoning_message` callback are
    collected into `aget_reasoning_message_by_node_name`, which is carried over onto
    the compiled graph so streaming code can surface reasoning messages per node.
    """

    _team: Team
    _user: User
    _graph: StateGraph
    # True once an edge originating at START has been added; compile() requires it.
    _has_start_node: bool
    # Maps node name -> async callback producing an optional ReasoningMessage for that node.
    aget_reasoning_message_by_node_name: GetReasoningMessageMapType

    def __init__(self, team: Team, user: User, state_type: type[StateType]):
        self._team = team
        self._user = user
        self._graph = StateGraph(state_type)
        self._has_start_node = False
        self.aget_reasoning_message_by_node_name = {}

    def add_edge(self, from_node: MaxNodeName, to_node: MaxNodeName):
        """Add an edge, remembering whether the graph has been wired from START."""
        if from_node == AssistantNodeName.START:
            self._has_start_node = True
        self._graph.add_edge(from_node, to_node)
        return self

    def add_node(self, node: MaxNodeName, action: Any):
        """Add a node; if the action exposes `get_reasoning_message`, register the hook."""
        self._graph.add_node(node, action)
        # runtime_checkable Protocol isinstance() only verifies attribute presence, not the
        # signature — acceptable here since this is a one-off check per added node.
        if isinstance(action, HasReasoningMessage):
            self.aget_reasoning_message_by_node_name[node] = action.get_reasoning_message
        return self

    def add_subgraph(self, node_name: MaxNodeName, subgraph: AssistantCompiledStateGraph):
        """Mount a compiled subgraph as a single node, inheriting its reasoning-message hooks."""
        self._graph.add_node(node_name, subgraph)
        self.aget_reasoning_message_by_node_name.update(subgraph.aget_reasoning_message_by_node_name)
        return self

    def compile(self, checkpointer: DjangoCheckpointer | None | Literal[False] = None):
        """Compile the builder into an AssistantCompiledStateGraph.

        Raises:
            ValueError: if no edge originating at START was ever added.
        """
        if not self._has_start_node:
            raise ValueError("Start node not added to the graph")
        # TRICKY: We check `is not None` because False has a special meaning of "no checkpointer",
        # which we want to pass through to LangGraph as-is.
        compiled_graph = self._graph.compile(
            checkpointer=checkpointer if checkpointer is not None else global_checkpointer
        )
        return AssistantCompiledStateGraph(
            compiled_graph, aget_reasoning_message_by_node_name=self.aget_reasoning_message_by_node_name
        )
3 changes: 1 addition & 2 deletions ee/hogai/graph/base.py → ee/hogai/graph/base/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@

from posthog.schema import AssistantMessage, AssistantToolCall, HumanMessage, ReasoningMessage

from posthog.models import Team
from posthog.models.user import User
from posthog.models import Team, User
from posthog.sync import database_sync_to_async

from ee.hogai.context import AssistantContextManager
Expand Down
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from langchain_core.runnables import RunnableLambda
from langgraph.checkpoint.memory import InMemorySaver

from ee.hogai.graph.graph import BaseAssistantGraph
from ee.hogai.graph.base.graph import BaseAssistantGraph
from ee.hogai.utils.types import AssistantNodeName, AssistantState, PartialAssistantState


Expand Down
12 changes: 11 additions & 1 deletion ee/hogai/graph/deep_research/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
from posthog.models.user import User

from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer
from ee.hogai.graph.base import BaseAssistantGraph
from ee.hogai.graph.deep_research.notebook.nodes import DeepResearchNotebookPlanningNode
from ee.hogai.graph.deep_research.onboarding.nodes import DeepResearchOnboardingNode
from ee.hogai.graph.deep_research.planner.nodes import DeepResearchPlannerNode, DeepResearchPlannerToolsNode
from ee.hogai.graph.deep_research.report.nodes import DeepResearchReportNode
from ee.hogai.graph.deep_research.task_executor.nodes import DeepResearchTaskExecutorNode
from ee.hogai.graph.deep_research.types import DeepResearchNodeName, DeepResearchState
from ee.hogai.graph.graph import BaseAssistantGraph
from ee.hogai.graph.title_generator.nodes import TitleGeneratorNode


class DeepResearchAssistantGraph(BaseAssistantGraph[DeepResearchState]):
Expand Down Expand Up @@ -80,6 +81,15 @@ def add_report_node(self, next_node: DeepResearchNodeName = DeepResearchNodeName
builder.add_edge(DeepResearchNodeName.REPORT, next_node)
return self

def add_title_generator(self, end_node: DeepResearchNodeName = DeepResearchNodeName.END):
    """Wire a title-generator node between START and `end_node`.

    Marks the builder as having a START edge, since the edges are added on the
    underlying StateGraph directly rather than through `add_edge`.
    """
    self._has_start_node = True
    title_node = DeepResearchNodeName.TITLE_GENERATOR
    self._graph.add_node(title_node, TitleGeneratorNode(self._team, self._user))
    self._graph.add_edge(DeepResearchNodeName.START, title_node)
    self._graph.add_edge(title_node, end_node)
    return self

def compile_full_graph(self, checkpointer: DjangoCheckpointer | None = None):
return (
self.add_onboarding_node()
Expand Down
8 changes: 4 additions & 4 deletions ee/hogai/graph/deep_research/task_executor/test/test_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ async def test_arun_handles_empty_task_list(self, mock_logger):

class TestTaskExecutorInsightsExecution(TestTaskExecutorNode):
@patch("ee.hogai.graph.deep_research.task_executor.nodes.DeepResearchTaskExecutorNode._write_message")
@patch("ee.hogai.graph.deep_research.task_executor.nodes.InsightsAssistantGraph")
@patch("ee.hogai.graph.deep_research.task_executor.nodes.InsightsGraph")
async def test_execute_task_with_insights_successful(self, mock_insights_graph_class, mock_write_message):
"""Test successful task execution through insights pipeline."""
task = self._create_task_execution_item(task_id="task_1")
Expand Down Expand Up @@ -230,7 +230,7 @@ async def mock_astream(*args, **kwargs):
self.assertEqual(len(result.artifacts), 1)

@patch("ee.hogai.graph.deep_research.task_executor.nodes.DeepResearchTaskExecutorNode._write_message")
@patch("ee.hogai.graph.deep_research.task_executor.nodes.InsightsAssistantGraph")
@patch("ee.hogai.graph.deep_research.task_executor.nodes.InsightsGraph")
async def test_execute_task_with_insights_no_artifacts(self, mock_insights_graph_class, mock_write_message):
"""Test task execution that produces no artifacts."""
task = self._create_task_execution_item(task_id="task_1")
Expand Down Expand Up @@ -268,7 +268,7 @@ async def mock_astream(*args, **kwargs):

@patch("ee.hogai.graph.deep_research.task_executor.nodes.capture_exception")
@patch("ee.hogai.graph.deep_research.task_executor.nodes.DeepResearchTaskExecutorNode._write_message")
@patch("ee.hogai.graph.deep_research.task_executor.nodes.InsightsAssistantGraph")
@patch("ee.hogai.graph.deep_research.task_executor.nodes.InsightsGraph")
async def test_execute_task_with_exception(self, mock_insights_graph_class, mock_write_message, mock_capture):
"""Test task execution that encounters an exception."""
task = self._create_task_execution_item(task_id="task_1")
Expand Down Expand Up @@ -343,7 +343,7 @@ async def mock_execute_task3(input_dict):

class TestReasoningCallback(TestTaskExecutorNode):
@patch("ee.hogai.graph.deep_research.task_executor.nodes.DeepResearchTaskExecutorNode._write_message")
@patch("ee.hogai.graph.deep_research.task_executor.nodes.InsightsAssistantGraph")
@patch("ee.hogai.graph.deep_research.task_executor.nodes.InsightsGraph")
async def test_reasoning_messages_sent_during_execution(self, mock_insights_graph_class, mock_write_message):
"""Test that reasoning messages are properly sent during task execution."""
task = self._create_task_execution_item(task_id="task_1")
Expand Down
1 change: 1 addition & 0 deletions ee/hogai/graph/deep_research/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,4 @@ class DeepResearchNodeName(StrEnum):
PLANNER_TOOLS = "planner_tools"
TASK_EXECUTOR = "task_executor"
REPORT = "report"
TITLE_GENERATOR = "title_generator"
Loading
Loading