def _construct_messages(self, messages: Sequence[BaseMessage]):
    """Assemble the summarization prompt: system prompt, then the conversation, then the user prompt."""
    system_part = ChatPromptTemplate.from_messages([("system", SYSTEM_PROMPT)])
    user_part = ChatPromptTemplate.from_messages([("user", USER_PROMPT)])
    return system_part + messages + user_part
def _construct_messages(self, messages: Sequence[BaseMessage]):
    """Removes cache_control headers."""
    sanitized: list[BaseMessage] = []
    for original in messages:
        msg = original
        if isinstance(msg.content, list):
            # Deep-copy so the caller's message is never mutated.
            msg = msg.model_copy(deep=True)
            for part in msg.content:
                if isinstance(part, dict):
                    part.pop("cache_control", None)
        sanitized.append(msg)

    return super()._construct_messages(sanitized)
{"type": "ephemeral"}}, + ] + ) + ], + [[{"type": "text", "text": "Hello"}]], + ), + ( + "multiple_items_with_cache_control", + [ + LangchainAIMessage( + content=[ + {"type": "text", "text": "First", "cache_control": {"type": "ephemeral"}}, + {"type": "text", "text": "Second", "cache_control": {"type": "ephemeral"}}, + ] + ) + ], + [[{"type": "text", "text": "First"}, {"type": "text", "text": "Second"}]], + ), + ( + "mixed_items_some_with_cache_control", + [ + LangchainHumanMessage( + content=[ + {"type": "text", "text": "With cache", "cache_control": {"type": "ephemeral"}}, + {"type": "text", "text": "Without cache"}, + ] + ) + ], + [[{"type": "text", "text": "With cache"}, {"type": "text", "text": "Without cache"}]], + ), + ( + "multiple_messages_with_cache_control", + [ + LangchainHumanMessage( + content=[ + {"type": "text", "text": "Message 1", "cache_control": {"type": "ephemeral"}}, + ] + ), + LangchainAIMessage( + content=[ + {"type": "text", "text": "Message 2", "cache_control": {"type": "ephemeral"}}, + ] + ), + ], + [ + [{"type": "text", "text": "Message 1"}], + [{"type": "text", "text": "Message 2"}], + ], + ), + ] + ) + def test_removes_cache_control(self, name, input_messages, expected_contents): + result = self.summarizer._construct_messages(input_messages) + + # Extract the actual messages from the prompt template + messages = result.messages[1:-1] # Skip system prompt and user prompt + + self.assertEqual(len(messages), len(expected_contents), f"Wrong number of messages in test case: {name}") + + for i, (message, expected_content) in enumerate(zip(messages, expected_contents)): + self.assertEqual( + message.content, + expected_content, + f"Message {i} content mismatch in test case: {name}", + ) + + @parameterized.expand( + [ + ( + "string_content", + [LangchainHumanMessage(content="Simple string")], + ), + ( + "empty_list_content", + [LangchainHumanMessage(content=[])], + ), + ( + "non_dict_items_in_list", + 
[LangchainHumanMessage(content=["string_item", {"type": "text", "text": "dict_item"}])], + ), + ] + ) + def test_handles_non_dict_content_without_errors(self, name, input_messages): + result = self.summarizer._construct_messages(input_messages) + self.assertIsNotNone(result) + + def test_original_message_not_modified(self): + original_content: list[str | dict[Any, Any]] = [ + {"type": "text", "text": "Hello", "cache_control": {"type": "ephemeral"}}, + ] + message = LangchainHumanMessage(content=original_content) + + # Store the original cache_control to verify it's not modified + content_list = cast(list[dict[str, Any]], message.content) + self.assertIn("cache_control", content_list[0]) + + self.summarizer._construct_messages([message]) + + # Verify original message still has cache_control + content_list = cast(list[dict[str, Any]], message.content) + self.assertIn("cache_control", content_list[0]) + self.assertEqual(content_list[0]["cache_control"], {"type": "ephemeral"}) + + def test_deep_copy_prevents_modification(self): + original_content: list[str | dict[Any, Any]] = [ + { + "type": "text", + "text": "Test", + "cache_control": {"type": "ephemeral"}, + "other_key": "value", + }, + ] + message = LangchainHumanMessage(content=original_content) + + content_list = cast(list[dict[str, Any]], message.content) + original_keys = set(content_list[0].keys()) + + self.summarizer._construct_messages([message]) + + # Verify original message structure unchanged + content_list = cast(list[dict[str, Any]], message.content) + self.assertEqual(set(content_list[0].keys()), original_keys) + self.assertIn("cache_control", content_list[0]) + + def test_preserves_other_content_properties(self): + input_messages = [ + LangchainHumanMessage( + content=[ + { + "type": "text", + "text": "Hello", + "cache_control": {"type": "ephemeral"}, + "custom_field": "custom_value", + "another_field": 123, + }, + ] + ) + ] + + result = self.summarizer._construct_messages(input_messages) + messages = 
result.messages[1:-1] + + # Verify other fields are preserved + content = messages[0].content[0] + self.assertEqual(content["custom_field"], "custom_value") + self.assertEqual(content["another_field"], 123) + self.assertNotIn("cache_control", content) + + def test_empty_messages_list(self): + result = self.summarizer._construct_messages([]) + # Should return prompt template with just system and user prompts + self.assertEqual(len(result.messages), 2) diff --git a/ee/hogai/graph/root/compaction_manager.py b/ee/hogai/graph/root/compaction_manager.py index 14d50772c643c..0970871caecbd 100644 --- a/ee/hogai/graph/root/compaction_manager.py +++ b/ee/hogai/graph/root/compaction_manager.py @@ -2,6 +2,7 @@ from abc import ABC, abstractmethod from collections.abc import Callable, Sequence from typing import TYPE_CHECKING, Any, TypeVar +from uuid import uuid4 from langchain_anthropic import ChatAnthropic from langchain_core.language_models import BaseChatModel @@ -10,11 +11,13 @@ HumanMessage as LangchainHumanMessage, ) from langchain_core.tools import BaseTool +from pydantic import BaseModel -from posthog.schema import AssistantMessage, AssistantToolCallMessage, HumanMessage +from posthog.schema import AssistantMessage, AssistantToolCallMessage, ContextMessage, HumanMessage from posthog.sync import database_sync_to_async +from ee.hogai.utils.helpers import find_start_message, find_start_message_idx, insert_messages_before_start from ee.hogai.utils.types import AssistantMessageUnion if TYPE_CHECKING: @@ -25,6 +28,12 @@ LangchainTools = Sequence[dict[str, Any] | type | Callable | BaseTool] +class InsertionResult(BaseModel): + messages: Sequence[AssistantMessageUnion] + updated_start_id: str + updated_window_start_id: str + + class ConversationCompactionManager(ABC): """ Manages conversation window boundaries, message filtering, and summarization decisions. @@ -39,27 +48,27 @@ class ConversationCompactionManager(ABC): Determines the approximate number of characters per token. 
def find_window_boundary(self, messages: Sequence[T], max_messages: int = 10, max_tokens: int = 1000) -> str | None:
    """
    Find the optimal window start ID based on message count and token limits.

    Walks the conversation backwards, spending the message/token budget before
    assigning a boundary, so a message that exceeds the budget is never chosen.
    Ensures the window starts at a human or assistant message.

    Returns:
        The ID of the earliest human/assistant message that fits the budget,
        or None when even the most recent messages exceed it.
    """
    new_window_id: str | None = None
    for message in reversed(messages):
        # Handle limits before assigning the window ID.
        max_tokens -= self._get_estimated_tokens(message)
        max_messages -= 1
        if max_tokens < 0 or max_messages < 0:
            break

        # Assign the new window ID. Only human/assistant messages with an ID
        # may start the window.
        if message.id is not None and isinstance(message, (HumanMessage, AssistantMessage)):
            new_window_id = message.id

    return new_window_id

def update_window(
    self, messages: Sequence[T], summary_message: ContextMessage, start_id: str | None = None
) -> InsertionResult:
    """
    Find the optimal position to insert the summary message in the conversation window.

    Returns:
        An InsertionResult with the updated message sequence plus the new
        start ID and window-start ID.

    Raises:
        ValueError: if no start (human) message can be resolved from `start_id`.
    """
    window_start_id_candidate = self.find_window_boundary(messages, max_messages=16, max_tokens=2048)
    start_message = find_start_message(messages, start_id)
    if not start_message:
        raise ValueError("Start message not found")

    # Re-ID the copy so it can coexist with the original message in state.
    start_message_copy = start_message.model_copy(deep=True)
    start_message_copy.id = str(uuid4())

    # The last messages were too large to fit into the window. Copy the last
    # human message to the start of the window.
    if not window_start_id_candidate:
        return InsertionResult(
            messages=[*messages, summary_message, start_message_copy],
            updated_start_id=start_message_copy.id,
            updated_window_start_id=summary_message.id,
        )

    # Find the updated window.
    # NOTE(review): find_start_message_idx only matches HumanMessage IDs; when
    # the boundary candidate is an AssistantMessage ID this falls back to
    # index 0 (whole conversation) — confirm this is intended.
    start_message_idx = find_start_message_idx(messages, window_start_id_candidate)
    new_window = messages[start_message_idx:]

    # If the start human message is in the window, insert the summary message
    # before it and update the window start.
    if next((m for m in new_window if m.id == start_id), None):
        updated_messages = insert_messages_before_start(messages, [summary_message], start_id=start_id)
        return InsertionResult(
            messages=updated_messages,
            updated_start_id=start_id,
            updated_window_start_id=window_start_id_candidate,
        )

    # If the start message is not in the window, insert the summary message and
    # the copied human message at the start of the window.
    updated_messages = insert_messages_before_start(
        new_window, [summary_message, start_message_copy], start_id=window_start_id_candidate
    )
    return InsertionResult(
        messages=updated_messages,
        updated_start_id=start_message_copy.id,
        updated_window_start_id=window_start_id_candidate,
    )
diff --git a/ee/hogai/graph/root/nodes.py b/ee/hogai/graph/root/nodes.py index 8b8a0369bcd52..c166f772db0e8 100644 --- a/ee/hogai/graph/root/nodes.py +++ b/ee/hogai/graph/root/nodes.py @@ -35,7 +35,7 @@ from ee.hogai.llm import MaxChatAnthropic from ee.hogai.tool import CONTEXTUAL_TOOL_NAME_TO_TOOL, ToolMessagesArtifact from ee.hogai.utils.anthropic import add_cache_control, convert_to_anthropic_messages, normalize_ai_anthropic_message -from ee.hogai.utils.helpers import convert_tool_messages_to_dict, insert_messages_before_start +from ee.hogai.utils.helpers import convert_tool_messages_to_dict from ee.hogai.utils.prompt import format_prompt_string from ee.hogai.utils.types import ( AssistantMessageUnion, @@ -130,6 +130,7 @@ async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAs messages_to_replace or state.messages, state.root_conversation_start_id, state.root_tool_calls_count ) window_id = state.root_conversation_start_id + start_id = state.start_id # Summarize the conversation if it's too long. 
if await self._window_manager.should_compact_conversation( @@ -144,12 +145,14 @@ async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAs ) # Insert the summary message before the last human message - messages_to_replace = insert_messages_before_start( - messages_to_replace or state.messages, [summary_message], start_id=state.start_id + insertion_result = self._window_manager.update_window( + messages_to_replace or state.messages, summary_message, start_id=start_id ) + window_id = insertion_result.updated_window_start_id + start_id = insertion_result.updated_start_id + messages_to_replace = insertion_result.messages - # Update window - window_id = self._window_manager.find_window_boundary(messages_to_replace) + # Update the window langchain_messages = self._construct_messages(messages_to_replace, window_id, state.root_tool_calls_count) system_prompts = ChatPromptTemplate.from_messages( @@ -174,7 +177,11 @@ async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAs if messages_to_replace: new_messages = ReplaceMessages([*messages_to_replace, assistant_message]) - return PartialAssistantState(root_conversation_start_id=window_id, messages=new_messages) + return PartialAssistantState( + messages=new_messages, + root_conversation_start_id=window_id, + start_id=start_id, + ) async def get_reasoning_message( self, input: BaseState, default_message: Optional[str] = None diff --git a/ee/hogai/graph/root/test/test_compaction_manager.py b/ee/hogai/graph/root/test/test_compaction_manager.py index dedb05a97f56b..b14dfe3aa814c 100644 --- a/ee/hogai/graph/root/test/test_compaction_manager.py +++ b/ee/hogai/graph/root/test/test_compaction_manager.py @@ -1,3 +1,5 @@ +from uuid import uuid4 + from posthog.test.base import BaseTest from unittest.mock import AsyncMock, MagicMock, patch @@ -8,7 +10,7 @@ ) from parameterized import parameterized -from posthog.schema import AssistantMessage, AssistantToolCall, AssistantToolCallMessage, 
HumanMessage +from posthog.schema import AssistantMessage, AssistantToolCall, AssistantToolCallMessage, ContextMessage, HumanMessage from ee.hogai.graph.root.compaction_manager import AnthropicConversationCompactionManager from ee.hogai.utils.types.base import AssistantMessageUnion @@ -58,7 +60,7 @@ def test_find_window_boundary_token_limit(self): # Set token limit that forces early stop # Works backwards: processes message 3 (~2 tokens), then message 2 (~250 tokens) which breaks the limit result = self.window_manager.find_window_boundary(messages, max_messages=10, max_tokens=100) - self.assertEqual(result, "2") + self.assertEqual(result, "3") def test_find_window_boundary_stops_at_human_or_assistant(self): """Test window boundary ensures it starts at human or assistant message""" @@ -182,3 +184,378 @@ async def test_get_token_count_calls_model(self): self.assertEqual(result, 1234) mock_model.get_num_tokens_from_messages.assert_called_once_with(messages, thinking=thinking_config, tools=None) + + def test_update_window_with_large_last_tool_call_message(self): + """ + Test that update_window handles a large (128k) final AssistantToolCallMessage. + When the last messages are too large to fit into the window, the start human message + should be copied to the start of the window along with the summary message. 
+ """ + # Create a very large tool call message (128k characters) + large_content = "x" * (128 * 1024) + start_id = str(uuid4()) + summary_id = str(uuid4()) + + messages: list[AssistantMessageUnion] = [ + HumanMessage(content="Initial question", id=start_id), + AssistantMessage( + content="Let me analyze that", + tool_calls=[ + AssistantToolCall( + id="tool-1", + name="create_and_query_insight", + args={"query_description": "test"}, + ) + ], + ), + AssistantToolCallMessage( + content=large_content, + tool_call_id="tool-1", + ), + ] + + summary_message = ContextMessage(content="Summary of previous conversation", id=summary_id) + + result = self.window_manager.update_window(messages, summary_message, start_id=start_id) + + # When the window boundary is None (messages too large), we expect: + # - Original messages preserved + # - Summary message appended + # - Start message copied + # - Window start should be the summary message + self.assertEqual(len(result.messages), 5) + self.assertEqual(result.messages[0].id, start_id) + self.assertEqual(result.messages[-2].id, summary_id) + last_msg = result.messages[-1] + assert isinstance(last_msg, HumanMessage) # Type narrowing + self.assertEqual(last_msg.content, "Initial question") + self.assertNotEqual(last_msg.id, start_id) + self.assertEqual(result.updated_start_id, last_msg.id) + self.assertEqual(result.updated_window_start_id, summary_id) + + def test_update_window_initiator_in_window(self): + """ + Test update_window when the initiator (start) human message is within the new window boundary. + In this case, the summary should be inserted before the start message, + and the window start should be updated to the found boundary. 
+ """ + start_id = str(uuid4()) + summary_id = str(uuid4()) + + # Create a conversation where the start message will be in the window + messages: list[AssistantMessageUnion] = [ + HumanMessage(content="Old question 1", id=str(uuid4())), + AssistantMessage(content="Old response 1"), + HumanMessage(content="Old question 2", id=str(uuid4())), + AssistantMessage(content="Old response 2"), + HumanMessage(content="Recent question", id=start_id), + AssistantMessage(content="Recent response"), + ] + + summary_message = ContextMessage(content="Summary of conversation", id=summary_id) + + result = self.window_manager.update_window(messages, summary_message, start_id=start_id) + + # The start message is in the window, so summary should be inserted before it + # Find where the summary was inserted + summary_idx = next(i for i, msg in enumerate(result.messages) if msg.id == summary_id) + start_idx = next(i for i, msg in enumerate(result.messages) if msg.id == start_id) + + # Summary should come before start message + self.assertLess(summary_idx, start_idx) + # Start ID should remain the same + self.assertEqual(result.updated_start_id, start_id) + # Window start should be set to a boundary candidate + self.assertIsNotNone(result.updated_window_start_id) + + def test_update_window_initiator_not_in_window(self): + """ + Test update_window when the initiator (start) human message is NOT in the new window boundary. + In this case, a copy of the start message should be inserted at the window start, + along with the summary message. 
+ """ + start_id = str(uuid4()) + summary_id = str(uuid4()) + + # Create many messages to push the start message outside the window + # The window boundary is determined by max_messages=16 and max_tokens=2048 + messages: list[AssistantMessageUnion] = [ + HumanMessage(content="Initial question", id=start_id), + AssistantMessage(content="Initial response"), + ] + + # Add enough messages to push start_id out of the window + for i in range(20): + messages.append(HumanMessage(content=f"Question {i}", id=str(uuid4()))) + messages.append(AssistantMessage(content=f"Response {i}" * 50)) # Make messages larger + + summary_message = ContextMessage(content="Summary", id=summary_id) + + result = self.window_manager.update_window(messages, summary_message, start_id=start_id) + + # The start message should NOT be in the result (only its copy) + start_messages = [msg for msg in result.messages if msg.id == start_id] + self.assertEqual(len(start_messages), 0, "Original start message should not be in result") + + # Find the copied start message (same content, different ID) + copied_start = next( + (msg for msg in result.messages if isinstance(msg, HumanMessage) and msg.content == "Initial question"), + None, + ) + self.assertIsNotNone(copied_start, "Copied start message should exist") + assert copied_start is not None # Type narrowing + self.assertNotEqual(copied_start.id, start_id, "Copied message should have new ID") + + # The copied start message should have a new ID returned + self.assertEqual(result.updated_start_id, copied_start.id) + + # Summary and copied start should be at the beginning of the window + summary_idx = next(i for i, msg in enumerate(result.messages) if msg.id == summary_id) + self.assertEqual(result.messages[summary_idx + 1].id, copied_start.id, "Copied start should follow summary") + + def test_tool_call_complete_sequence_in_window(self): + """ + Test that complete tool call sequences within the window boundary are preserved. 
+ When both AssistantMessage with tool_calls and AssistantToolCallMessage are in + the window, they should both be preserved. + """ + start_id = str(uuid4()) + messages: list[AssistantMessageUnion] = [ + HumanMessage(content="Old message", id=str(uuid4())), + AssistantMessage(content="Old response"), + HumanMessage(content="Recent question", id=start_id), + AssistantMessage( + content="Let me check", + tool_calls=[ + AssistantToolCall( + id="tool-1", + name="create_and_query_insight", + args={"query": "test"}, + ) + ], + ), + AssistantToolCallMessage(content="Tool result", tool_call_id="tool-1"), + AssistantMessage(content="Final response"), + ] + + summary_message = ContextMessage(content="Summary", id=str(uuid4())) + + result = self.window_manager.update_window(messages, summary_message, start_id=start_id) + + # Count tool calls in output + tool_call_count = 0 + tool_result_count = 0 + for msg in result.messages: + if isinstance(msg, AssistantMessage) and msg.tool_calls: + tool_call_count += len(msg.tool_calls) + elif isinstance(msg, AssistantToolCallMessage): + tool_result_count += 1 + + # All tool calls and results should be preserved + self.assertEqual(tool_call_count, tool_result_count, "Tool calls and results should match in output") + self.assertGreater(tool_call_count, 0, "Should preserve at least some tool calls") + + def test_tool_call_incomplete_at_window_boundary(self): + """ + Test that incomplete tool call sequences at the window boundary are handled correctly. + When tool call sequences are split by the window boundary, the system should maintain + consistency (either preserve both parts or remove both). 
+ """ + start_id = str(uuid4()) + messages: list[AssistantMessageUnion] = [ + HumanMessage(content="Question 1", id=str(uuid4())), + AssistantMessage( + content="", + tool_calls=[ + AssistantToolCall( + id="tool-old", + name="create_and_query_insight", + args={"query": "old"}, + ) + ], + ), + AssistantToolCallMessage(content="Result", tool_call_id="tool-old"), + # Add many messages to push above out of window + HumanMessage(content="Q2", id=str(uuid4())), + AssistantMessage(content="R2" * 100), + HumanMessage(content="Q3", id=str(uuid4())), + AssistantMessage(content="R3" * 100), + HumanMessage(content="Q4", id=str(uuid4())), + AssistantMessage(content="R4" * 100), + HumanMessage(content="Q5", id=str(uuid4())), + AssistantMessage(content="R5" * 100), + HumanMessage(content="Q6", id=start_id), + AssistantMessage( + content="", + tool_calls=[ + AssistantToolCall( + id="tool-new", + name="create_and_query_insight", + args={"query": "new"}, + ) + ], + ), + AssistantToolCallMessage(content="New result", tool_call_id="tool-new"), + ] + + summary_message = ContextMessage(content="Summary", id=str(uuid4())) + + result = self.window_manager.update_window(messages, summary_message, start_id=start_id) + + # Count tool calls in output + tool_call_count = 0 + tool_result_count = 0 + for msg in result.messages: + if isinstance(msg, AssistantMessage) and msg.tool_calls: + tool_call_count += len(msg.tool_calls) + elif isinstance(msg, AssistantToolCallMessage): + tool_result_count += 1 + + # Even when removing incomplete sequences, remaining should be complete + self.assertEqual( + tool_call_count, + tool_result_count, + "Even when removing incomplete sequences, remaining should be complete", + ) + + def test_tool_call_multiple_complete_sequences(self): + """ + Test that multiple complete tool call sequences are all preserved. + When there are multiple consecutive tool calls, all complete sequences + should be maintained in the output. 
+ """ + start_id = str(uuid4()) + messages: list[AssistantMessageUnion] = [ + HumanMessage(content="Question", id=start_id), + AssistantMessage( + content="", + tool_calls=[ + AssistantToolCall( + id="tool-1", + name="create_and_query_insight", + args={"query": "first"}, + ) + ], + ), + AssistantToolCallMessage(content="First result", tool_call_id="tool-1"), + AssistantMessage( + content="", + tool_calls=[ + AssistantToolCall( + id="tool-2", + name="create_and_query_insight", + args={"query": "second"}, + ) + ], + ), + AssistantToolCallMessage(content="Second result", tool_call_id="tool-2"), + AssistantMessage(content="Done"), + ] + + summary_message = ContextMessage(content="Summary", id=str(uuid4())) + + result = self.window_manager.update_window(messages, summary_message, start_id=start_id) + + # Count tool calls in output + tool_call_count = 0 + tool_result_count = 0 + for msg in result.messages: + if isinstance(msg, AssistantMessage) and msg.tool_calls: + tool_call_count += len(msg.tool_calls) + elif isinstance(msg, AssistantToolCallMessage): + tool_result_count += 1 + + # All complete sequences should be preserved + self.assertEqual(tool_call_count, tool_result_count, "Tool calls and results should match in output") + self.assertEqual(tool_call_count, 2, "Should preserve both tool calls") + + def test_update_window_with_empty_messages(self): + """Test that update_window handles edge case of empty messages list""" + summary_message = ContextMessage(content="Summary", id=str(uuid4())) + + # This should raise ValueError because there's no start message + with self.assertRaises(ValueError) as context: + self.window_manager.update_window([], summary_message, start_id="nonexistent") + + self.assertIn("Start message not found", str(context.exception)) + + def test_update_window_with_nonexistent_start_id(self): + """ + Test that update_window handles a nonexistent start_id. + When start_id doesn't exist, find_start_message falls back to the first HumanMessage. 
+ """ + actual_id = str(uuid4()) + messages: list[AssistantMessageUnion] = [ + HumanMessage(content="Question", id=actual_id), + AssistantMessage(content="Response"), + ] + + summary_message = ContextMessage(content="Summary", id=str(uuid4())) + + # When start_id doesn't exist, it falls back to the first human message + result = self.window_manager.update_window(messages, summary_message, start_id="nonexistent-id") + + # The first human message should be used as the start message + self.assertIsNotNone(result) + # The actual_id message should be found and used + found_actual_id = any(msg.id == actual_id for msg in result.messages) + self.assertTrue(found_actual_id, "Should fall back to first human message when start_id not found") + + def test_update_window_preserves_message_ids(self): + """Test that all messages in the result have valid IDs""" + start_id = str(uuid4()) + summary_id = str(uuid4()) + + messages: list[AssistantMessageUnion] = [ + HumanMessage(content="Question", id=start_id), + AssistantMessage(content="Response", id=str(uuid4())), + ] + + summary_message = ContextMessage(content="Summary", id=summary_id) + + result = self.window_manager.update_window(messages, summary_message, start_id=start_id) + + # Verify all messages have IDs + for msg in result.messages: + self.assertIsNotNone(msg.id, f"Message should have an ID: {msg}") + self.assertIsInstance(msg.id, str, f"Message ID should be a string: {msg.id}") + + def test_update_window_with_no_window_boundary(self): + """Test update_window when messages are too large to fit in window""" + start_id = str(uuid4()) + summary_id = str(uuid4()) + + # Create messages with large content that will exceed the window + messages: list[AssistantMessageUnion] = [ + HumanMessage(content="Question", id=start_id), + AssistantMessage(content="x" * 10000), # Large message + ] + + summary_message = ContextMessage(content="Summary", id=summary_id) + + result = self.window_manager.update_window(messages, summary_message, 
start_id=start_id) + + # When there's no window boundary, the summary and copied start message are appended + self.assertEqual(len(result.messages), 4) # original 2 + summary + copied start + self.assertEqual(result.messages[-2].id, summary_id) + self.assertEqual(result.updated_window_start_id, summary_id) + # Updated start ID should be the copied message + self.assertNotEqual(result.updated_start_id, start_id) + + def test_update_window_single_message_conversation(self): + """Test update_window with a minimal single-message conversation""" + start_id = str(uuid4()) + summary_id = str(uuid4()) + + messages: list[AssistantMessageUnion] = [ + HumanMessage(content="Question", id=start_id), + ] + + summary_message = ContextMessage(content="Summary", id=summary_id) + + result = self.window_manager.update_window(messages, summary_message, start_id=start_id) + + # Should insert summary before the start message + self.assertGreater(len(result.messages), 1) + summary_idx = next(i for i, msg in enumerate(result.messages) if msg.id == summary_id) + self.assertIsNotNone(summary_idx) diff --git a/ee/hogai/test/test_assistant.py b/ee/hogai/test/test_assistant.py index cec697e21bf89..1c6b0bf8c57b3 100644 --- a/ee/hogai/test/test_assistant.py +++ b/ee/hogai/test/test_assistant.py @@ -2287,3 +2287,76 @@ def create_messages_with_ids(_): assistant_output = [(event_type, msg) for event_type, msg in output if isinstance(msg, AssistantMessage)] self.assertEqual(len(assistant_output), 1) self.assertEqual(cast(AssistantMessage, assistant_output[0][1]).id, message_id_2) + + @patch( + "ee.hogai.graph.conversation_summarizer.nodes.AnthropicConversationSummarizer.summarize", + new=AsyncMock(return_value="Summary"), + ) + @patch("ee.hogai.graph.root.compaction_manager.AnthropicConversationCompactionManager.should_compact_conversation") + @patch("ee.hogai.graph.root.tools.read_taxonomy.ReadTaxonomyTool._run_impl") + @patch("ee.hogai.graph.root.nodes.RootNode._get_model") + async def 
test_compacting_conversation_on_the_second_turn(self, mock_model, mock_tool, mock_should_compact): + mock_model.side_effect = cycle( # Changed from return_value to side_effect + [ + FakeChatAnthropic( + responses=[ + messages.AIMessage( + content=[{"text": "Let me think about that", "type": "text"}], + tool_calls=[{"id": "1", "name": "read_taxonomy", "args": {"query": {"kind": "events"}}}], + ) + ] + ), + FakeChatAnthropic( + responses=[ + messages.AIMessage( + content=[{"text": "After summary", "type": "text"}], + ) + ] + ), + ] + ) + mock_tool.return_value = ("Event list" * 128000, None) + mock_should_compact.side_effect = cycle([False, True]) # Also changed this + + graph = ( + AssistantGraph(self.team, self.user) + .add_root( + path_map={ + "insights": AssistantNodeName.END, + "search_documentation": AssistantNodeName.END, + "root": AssistantNodeName.ROOT, + "end": AssistantNodeName.END, + "insights_search": AssistantNodeName.END, + "session_summarization": AssistantNodeName.END, + "create_dashboard": AssistantNodeName.END, + } + ) + .add_memory_onboarding() + .compile() + ) + + expected_output = [ + ("message", HumanMessage(content="First")), + ( + "message", + AssistantMessage( + content="Let me think about that", + tool_calls=[{"id": "1", "name": "read_taxonomy", "args": {"query": {"kind": "events"}}}], + ), + ), + ("message", ReasoningMessage(content="Searching the taxonomy")), + ("message", AssistantToolCallMessage(tool_call_id="1", content="Event list" * 128000)), + ("message", HumanMessage(content="First")), # Should copy this message + ("message", AssistantMessage(content="After summary")), + ] + + output, _ = await self._run_assistant_graph(graph, message="First", conversation=self.conversation) + self.assertConversationEqual(output, expected_output) + + snapshot = await graph.aget_state({"configurable": {"thread_id": str(self.conversation.id)}}) + state = AssistantState.model_validate(snapshot.values) + # should be equal to the copied human message + 
def find_start_message_idx(messages: Sequence[AssistantMessageUnion], start_id: str | None = None) -> int:
    """Return the index of the most recent HumanMessage whose id equals start_id.

    Scans from the end of the conversation so that, if several human messages
    carried the target id, the last occurrence wins. Returns 0 when nothing
    matches (including when start_id is None or messages is empty).
    """
    for position in reversed(range(len(messages))):
        candidate = messages[position]
        if isinstance(candidate, HumanMessage) and candidate.id == start_id:
            return position
    return 0
"single_matching_message", + [HumanMessage(content="test", id="target-id")], + "target-id", + 0, + ), + ( + "matching_message_at_end", + [ + AssistantMessage(content="response", type="ai"), + HumanMessage(content="question", id="target-id"), + ], + "target-id", + 1, + ), + ( + "matching_message_in_middle", + [ + HumanMessage(content="first", id="first-id"), + HumanMessage(content="second", id="target-id"), + AssistantMessage(content="response", type="ai"), + ], + "target-id", + 1, + ), + ( + "multiple_human_messages_match_first_from_end", + [ + HumanMessage(content="first", id="other-id"), + HumanMessage(content="second", id="target-id"), + AssistantMessage(content="response", type="ai"), + HumanMessage(content="third", id="another-id"), + ], + "target-id", + 1, + ), + ( + "non_human_message_with_matching_id_ignored", + [ + AssistantMessage(content="response", type="ai", id="target-id"), + HumanMessage(content="question", id="other-id"), + ], + "target-id", + 0, + ), + ( + "mixed_messages_finds_correct_human_message", + [ + HumanMessage(content="first", id="first-id"), + AssistantMessage(content="response 1", type="ai"), + VisualizationMessage(answer=AssistantTrendsQuery(series=[])), + HumanMessage(content="second", id="target-id"), + AssistantMessage(content="response 2", type="ai"), + ], + "target-id", + 3, + ), + ] + ) + def test_find_start_message_idx(self, _name, messages, start_id, expected_idx): + result = find_start_message_idx(messages, start_id) + self.assertEqual(result, expected_idx) + + @parameterized.expand( + [ + ("empty_messages", [], None, None), + ( + "returns_first_message_when_no_start_id", + [ + HumanMessage(content="first", id="first-id"), + AssistantMessage(content="response", type="ai"), + ], + None, + HumanMessage(content="first", id="first-id"), + ), + ( + "returns_matching_message", + [ + HumanMessage(content="first", id="first-id"), + HumanMessage(content="second", id="target-id"), + AssistantMessage(content="response", type="ai"), + ], 
+ "target-id", + HumanMessage(content="second", id="target-id"), + ), + ( + "returns_first_when_id_not_found", + [ + HumanMessage(content="first", id="first-id"), + AssistantMessage(content="response", type="ai"), + ], + "nonexistent-id", + HumanMessage(content="first", id="first-id"), + ), + ] + ) + def test_find_start_message(self, _name, messages, start_id, expected_message): + result = find_start_message(messages, start_id) + self.assertEqual(result, expected_message)