Skip to content

Commit 0f3499f

Browse files
committed
fix: raise SchemaGenerationException so we avoid FailureMessages
1 parent 8d78599 commit 0f3499f

File tree

2 files changed

+52
-19
lines changed

2 files changed

+52
-19
lines changed

ee/hogai/graph/root/tools/create_and_query_insight.py

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99

1010
from ee.hogai.context.context import AssistantContextManager
1111
from ee.hogai.graph.insights_graph.graph import InsightsGraph
12+
from ee.hogai.graph.schema_generator.nodes import SchemaGenerationException
1213
from ee.hogai.tool import MaxTool, ToolMessagesArtifact
14+
from ee.hogai.utils.prompt import format_prompt_string
1315
from ee.hogai.utils.types.base import AssistantState
1416

1517
INSIGHT_TOOL_PROMPT = """
@@ -107,14 +109,35 @@
107109
</system_reminder>
108110
""".strip()
109111

110-
INSIGHT_TOOL_FAILURE_PROMPT = """
111-
The agent has encountered an error while creating an insight.
112+
INSIGHT_TOOL_FAILURE_SYSTEM_REMINDER_PROMPT = """
112113
<system_reminder>
113114
Inform the user that you've encountered an error during the creation of the insight. Afterwards, try to generate a new insight with a different query.
114115
Terminate if the error persists.
115116
</system_reminder>
116117
""".strip()
117118

119+
INSIGHT_TOOL_HANDLED_FAILURE_PROMPT = """
120+
The agent has encountered the error while creating an insight.
121+
122+
Generated output:
123+
```
124+
{{{output}}}
125+
```
126+
127+
Error message:
128+
```
129+
{{{error_message}}}
130+
```
131+
132+
{{{system_reminder}}}
133+
""".strip()
134+
135+
136+
INSIGHT_TOOL_UNHANDLED_FAILURE_PROMPT = """
137+
The agent has encountered an unknown error while creating an insight.
138+
{{{system_reminder}}}
139+
""".strip()
140+
118141

119142
class CreateAndQueryInsightToolArgs(BaseModel):
120143
tool_call_id: Annotated[str, InjectedToolCallId, SkipJsonSchema]
@@ -145,12 +168,23 @@ async def _arun_impl(self, query_description: str, tool_call_id: str) -> tuple[s
145168
},
146169
deep=True,
147170
)
148-
dict_state = await graph.ainvoke(new_state)
171+
try:
172+
dict_state = await graph.ainvoke(new_state)
173+
except SchemaGenerationException as e:
174+
return format_prompt_string(
175+
INSIGHT_TOOL_HANDLED_FAILURE_PROMPT,
176+
output=e.llm_output,
177+
error_message=e.validation_message,
178+
system_reminder=INSIGHT_TOOL_FAILURE_SYSTEM_REMINDER_PROMPT,
179+
), None
180+
149181
updated_state = AssistantState.model_validate(dict_state)
150182
maybe_viz_message, tool_call_message = updated_state.messages[-2:]
151183

152184
if not isinstance(tool_call_message, AssistantToolCallMessage):
153-
return INSIGHT_TOOL_FAILURE_PROMPT, None
185+
return format_prompt_string(
186+
INSIGHT_TOOL_UNHANDLED_FAILURE_PROMPT, system_reminder=INSIGHT_TOOL_FAILURE_SYSTEM_REMINDER_PROMPT
187+
), None
154188

155189
# If the previous message is not a visualization message, the agent has requested human feedback.
156190
if not isinstance(maybe_viz_message, VisualizationMessage):

ee/hogai/graph/schema_generator/nodes.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
1414
from langchain_core.runnables import RunnableConfig
1515

16-
from posthog.schema import FailureMessage, VisualizationMessage
16+
from posthog.schema import VisualizationMessage
1717

1818
from posthog.models.group_type_mapping import GroupTypeMapping
1919

@@ -36,6 +36,15 @@
3636
RETRIES_ALLOWED = 2
3737

3838

39+
class SchemaGenerationException(Exception):
40+
"""An error occurred while generating a schema in the `SchemaGeneratorNode` node."""
41+
42+
def __init__(self, llm_output: str, validation_message: str):
43+
super().__init__("Failed to generate schema")
44+
self.llm_output = llm_output
45+
self.validation_message = validation_message
46+
47+
3948
class SchemaGeneratorNode(AssistantNode, Generic[Q]):
4049
INSIGHT_NAME: str
4150
"""
@@ -87,9 +96,8 @@ async def _run_with_prompt(
8796

8897
chain = generation_prompt | merger | self._model | self._parse_output
8998

90-
result: SchemaGeneratorOutput[Q] | None = None
9199
try:
92-
result = await chain.ainvoke(
100+
result: SchemaGeneratorOutput[Q] = await chain.ainvoke(
93101
{
94102
"project_datetime": self.project_now,
95103
"project_timezone": self.project_timezone,
@@ -120,18 +128,9 @@ async def _run_with_prompt(
120128
query_generation_retry_count=len(intermediate_steps) + 1,
121129
)
122130

123-
if not result:
124-
# We've got no usable result after exhausting all iteration attempts - it's failure message time
125-
return PartialAssistantState(
126-
messages=[
127-
FailureMessage(
128-
content=f"It looks like I'm having trouble generating this {self.INSIGHT_NAME} insight."
129-
)
130-
],
131-
intermediate_steps=None,
132-
plan=None,
133-
query_generation_retry_count=len(intermediate_steps) + 1,
134-
)
131+
if isinstance(e, PydanticOutputParserException):
132+
raise SchemaGenerationException(e.llm_output, e.validation_message)
133+
raise SchemaGenerationException(e.llm_output or "No input was provided.", str(e))
135134

136135
# We've got a result that either passed the quality check or we've exhausted all attempts at iterating - return
137136
return PartialAssistantState(

0 commit comments

Comments
 (0)