28 changes: 22 additions & 6 deletions ee/hogai/graph/conversation_summarizer/nodes.py
@@ -18,11 +18,7 @@ def __init__(self, team: Team, user: User):
self._team = team

async def summarize(self, messages: Sequence[BaseMessage]) -> str:
prompt = (
ChatPromptTemplate.from_messages([("system", SYSTEM_PROMPT)])
+ messages
+ ChatPromptTemplate.from_messages([("user", USER_PROMPT)])
)
prompt = self._construct_messages(messages)
model = self._get_model()
chain = prompt | model | StrOutputParser() | self._parse_xml_tags
response: str = await chain.ainvoke({}) # Do not pass config here, so the node doesn't stream
@@ -31,6 +27,13 @@ async def summarize(self, messages: Sequence[BaseMessage]) -> str:
@abstractmethod
def _get_model(self): ...

def _construct_messages(self, messages: Sequence[BaseMessage]):
return (
ChatPromptTemplate.from_messages([("system", SYSTEM_PROMPT)])
+ messages
+ ChatPromptTemplate.from_messages([("user", USER_PROMPT)])
)

def _parse_xml_tags(self, message: str) -> str:
"""
Extract analysis and summary tags from a message.
@@ -54,11 +57,24 @@ def _parse_xml_tags(self, message: str) -> str:
class AnthropicConversationSummarizer(ConversationSummarizer):
def _get_model(self):
return MaxChatAnthropic(
model="claude-sonnet-4-5",
model="claude-haiku-4-5",
streaming=False,
stream_usage=False,
max_tokens=8192,
disable_streaming=True,
user=self._user,
team=self._team,
)

def _construct_messages(self, messages: Sequence[BaseMessage]):
"""Removes cache_control headers."""
messages_without_cache: list[BaseMessage] = []
for message in messages:
if isinstance(message.content, list):
message = message.model_copy(deep=True)
for content in message.content:
if isinstance(content, dict) and "cache_control" in content:
content.pop("cache_control")
messages_without_cache.append(message)

return super()._construct_messages(messages_without_cache)
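
For context, a minimal usage sketch of the summarizer above. This is a hypothetical call site, not code from this PR — the team/user objects and message contents are placeholders:

from langchain_core.messages import AIMessage, HumanMessage

from ee.hogai.graph.conversation_summarizer.nodes import AnthropicConversationSummarizer


async def summarize_conversation(team, user) -> str:
    # Construct the Anthropic-backed summarizer (claude-haiku-4-5, non-streaming).
    summarizer = AnthropicConversationSummarizer(team=team, user=user)
    messages = [
        HumanMessage(
            content=[{"type": "text", "text": "Show me weekly signups", "cache_control": {"type": "ephemeral"}}]
        ),
        AIMessage(content="Here is a trends query for weekly signups..."),
    ]
    # _construct_messages wraps the conversation in the system/user prompts and,
    # in the Anthropic subclass, strips cache_control blocks before the model call.
    return await summarizer.summarize(messages)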
Empty file.
177 changes: 177 additions & 0 deletions ee/hogai/graph/conversation_summarizer/test/test_nodes.py
@@ -0,0 +1,177 @@
from typing import Any, cast

from posthog.test.base import BaseTest

from langchain_core.messages import (
AIMessage as LangchainAIMessage,
HumanMessage as LangchainHumanMessage,
)
from parameterized import parameterized

from ee.hogai.graph.conversation_summarizer.nodes import AnthropicConversationSummarizer


class TestAnthropicConversationSummarizer(BaseTest):
def setUp(self):
super().setUp()
self.summarizer = AnthropicConversationSummarizer(team=self.team, user=self.user)

@parameterized.expand(
[
(
"single_message_with_cache_control",
[
LangchainHumanMessage(
content=[
{"type": "text", "text": "Hello", "cache_control": {"type": "ephemeral"}},
]
)
],
[[{"type": "text", "text": "Hello"}]],
),
(
"multiple_items_with_cache_control",
[
LangchainAIMessage(
content=[
{"type": "text", "text": "First", "cache_control": {"type": "ephemeral"}},
{"type": "text", "text": "Second", "cache_control": {"type": "ephemeral"}},
]
)
],
[[{"type": "text", "text": "First"}, {"type": "text", "text": "Second"}]],
),
(
"mixed_items_some_with_cache_control",
[
LangchainHumanMessage(
content=[
{"type": "text", "text": "With cache", "cache_control": {"type": "ephemeral"}},
{"type": "text", "text": "Without cache"},
]
)
],
[[{"type": "text", "text": "With cache"}, {"type": "text", "text": "Without cache"}]],
),
(
"multiple_messages_with_cache_control",
[
LangchainHumanMessage(
content=[
{"type": "text", "text": "Message 1", "cache_control": {"type": "ephemeral"}},
]
),
LangchainAIMessage(
content=[
{"type": "text", "text": "Message 2", "cache_control": {"type": "ephemeral"}},
]
),
],
[
[{"type": "text", "text": "Message 1"}],
[{"type": "text", "text": "Message 2"}],
],
),
]
)
def test_removes_cache_control(self, name, input_messages, expected_contents):
result = self.summarizer._construct_messages(input_messages)

# Extract the actual messages from the prompt template
messages = result.messages[1:-1] # Skip system prompt and user prompt

self.assertEqual(len(messages), len(expected_contents), f"Wrong number of messages in test case: {name}")

for i, (message, expected_content) in enumerate(zip(messages, expected_contents)):
self.assertEqual(
message.content,
expected_content,
f"Message {i} content mismatch in test case: {name}",
)

@parameterized.expand(
[
(
"string_content",
[LangchainHumanMessage(content="Simple string")],
),
(
"empty_list_content",
[LangchainHumanMessage(content=[])],
),
(
"non_dict_items_in_list",
[LangchainHumanMessage(content=["string_item", {"type": "text", "text": "dict_item"}])],
),
]
)
def test_handles_non_dict_content_without_errors(self, name, input_messages):
result = self.summarizer._construct_messages(input_messages)
self.assertIsNotNone(result)

def test_original_message_not_modified(self):
original_content: list[str | dict[Any, Any]] = [
{"type": "text", "text": "Hello", "cache_control": {"type": "ephemeral"}},
]
message = LangchainHumanMessage(content=original_content)

# Store the original cache_control to verify it's not modified
content_list = cast(list[dict[str, Any]], message.content)
self.assertIn("cache_control", content_list[0])

self.summarizer._construct_messages([message])

# Verify original message still has cache_control
content_list = cast(list[dict[str, Any]], message.content)
self.assertIn("cache_control", content_list[0])
self.assertEqual(content_list[0]["cache_control"], {"type": "ephemeral"})

def test_deep_copy_prevents_modification(self):
original_content: list[str | dict[Any, Any]] = [
{
"type": "text",
"text": "Test",
"cache_control": {"type": "ephemeral"},
"other_key": "value",
},
]
message = LangchainHumanMessage(content=original_content)

content_list = cast(list[dict[str, Any]], message.content)
original_keys = set(content_list[0].keys())

self.summarizer._construct_messages([message])

# Verify original message structure unchanged
content_list = cast(list[dict[str, Any]], message.content)
self.assertEqual(set(content_list[0].keys()), original_keys)
self.assertIn("cache_control", content_list[0])

def test_preserves_other_content_properties(self):
input_messages = [
LangchainHumanMessage(
content=[
{
"type": "text",
"text": "Hello",
"cache_control": {"type": "ephemeral"},
"custom_field": "custom_value",
"another_field": 123,
},
]
)
]

result = self.summarizer._construct_messages(input_messages)
messages = result.messages[1:-1]

# Verify other fields are preserved
content = messages[0].content[0]
self.assertEqual(content["custom_field"], "custom_value")
self.assertEqual(content["another_field"], 123)
self.assertNotIn("cache_control", content)

def test_empty_messages_list(self):
result = self.summarizer._construct_messages([])
# Should return prompt template with just system and user prompts
self.assertEqual(len(result.messages), 2)
73 changes: 63 additions & 10 deletions ee/hogai/graph/root/compaction_manager.py
@@ -2,6 +2,7 @@
from abc import ABC, abstractmethod
from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, Any, TypeVar
from uuid import uuid4

from langchain_anthropic import ChatAnthropic
from langchain_core.language_models import BaseChatModel
@@ -10,11 +11,13 @@
HumanMessage as LangchainHumanMessage,
)
from langchain_core.tools import BaseTool
from pydantic import BaseModel

from posthog.schema import AssistantMessage, AssistantToolCallMessage, HumanMessage
from posthog.schema import AssistantMessage, AssistantToolCallMessage, ContextMessage, HumanMessage

from posthog.sync import database_sync_to_async

from ee.hogai.utils.helpers import find_start_message, find_start_message_idx, insert_messages_before_start
from ee.hogai.utils.types import AssistantMessageUnion

if TYPE_CHECKING:
@@ -25,6 +28,12 @@
LangchainTools = Sequence[dict[str, Any] | type | Callable | BaseTool]


class InsertionResult(BaseModel):
messages: Sequence[AssistantMessageUnion]
updated_start_id: str
updated_window_start_id: str


class ConversationCompactionManager(ABC):
"""
Manages conversation window boundaries, message filtering, and summarization decisions.
@@ -39,27 +48,27 @@ class ConversationCompactionManager(ABC):
Determines the approximate number of characters per token.
"""

def find_window_boundary(
self, messages: list[AssistantMessageUnion], max_messages: int = 10, max_tokens: int = 1000
) -> str:
def find_window_boundary(self, messages: Sequence[T], max_messages: int = 10, max_tokens: int = 1000) -> str | None:
"""
Find the optimal window start ID based on message count and token limits.
Ensures the window starts at a human or assistant message.
"""

new_window_id: str = str(messages[-1].id)
new_window_id: str | None = None
for message in reversed(messages):
# Handle limits before assigning the window ID.
max_tokens -= self._get_estimated_tokens(message)
max_messages -= 1
if max_tokens < 0 or max_messages < 0:
break

# Assign the new window ID.
if message.id is not None:
if isinstance(message, HumanMessage):
new_window_id = message.id
if isinstance(message, AssistantMessage):
new_window_id = message.id

max_messages -= 1
max_tokens -= self._get_estimated_tokens(message)
if max_messages <= 0 or max_tokens <= 0:
break

return new_window_id

def get_messages_in_window(self, messages: Sequence[T], window_start_id: str | None = None) -> Sequence[T]:
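
A standalone sketch of the revised boundary logic above, using simplified stand-in message classes — the ids, sizes, and the inline char/4 estimate are illustrative only; the real method operates on AssistantMessageUnion and calls self._get_estimated_tokens:

from dataclasses import dataclass


@dataclass
class Msg:
    id: str | None
    chars: int
    kind: str  # "human", "assistant", or "tool"


def find_window_boundary(messages: list[Msg], max_messages: int = 10, max_tokens: int = 1000) -> str | None:
    new_window_id: str | None = None
    for message in reversed(messages):
        # Limits are decremented before the candidate is assigned, so a message
        # that exhausts the budget can no longer become the window start.
        max_tokens -= message.chars // 4
        max_messages -= 1
        if max_tokens < 0 or max_messages < 0:
            break
        # Only human/assistant messages with an id may start the window.
        if message.id is not None and message.kind in ("human", "assistant"):
            new_window_id = message.id
    return new_window_id


history = [Msg("m1", 40, "human"), Msg("m2", 40, "assistant"), Msg("m3", 40, "tool"), Msg("m4", 40, "human")]
# With a 2-message budget only m4 and m3 are examined, and m3 is a tool message.
assert find_window_boundary(history, max_messages=2) == "m4"
# With the default budget the boundary moves back to the oldest qualifying message.
assert find_window_boundary(history) == "m1"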
@@ -84,6 +93,50 @@ async def should_compact_conversation(
token_count = await self._get_token_count(model, messages, tools, **kwargs)
return token_count > self.CONVERSATION_WINDOW_SIZE

def update_window(
self, messages: Sequence[T], summary_message: ContextMessage, start_id: str | None = None
) -> InsertionResult:
"""Finds the optimal position to insert the summary message in the conversation window."""
window_start_id_candidate = self.find_window_boundary(messages, max_messages=16, max_tokens=2048)
start_message = find_start_message(messages, start_id)
if not start_message:
raise ValueError("Start message not found")

start_message_copy = start_message.model_copy(deep=True)
start_message_copy.id = str(uuid4())

# The last messages were too large to fit into the window. Copy the last human message to the start of the window.
if not window_start_id_candidate:
return InsertionResult(
messages=[*messages, summary_message, start_message_copy],
updated_start_id=start_message_copy.id,
updated_window_start_id=summary_message.id,
)

# Find the updated window
start_message_idx = find_start_message_idx(messages, window_start_id_candidate)
new_window = messages[start_message_idx:]

# If the start human message is in the window, insert the summary message before it
# and update the window start.
if next((m for m in new_window if m.id == start_id), None):
updated_messages = insert_messages_before_start(messages, [summary_message], start_id=start_id)
return InsertionResult(
messages=updated_messages,
updated_start_id=start_id,
updated_window_start_id=window_start_id_candidate,
)

# If the start message is not in the window, insert the summary message and human message at the start of the window.
updated_messages = insert_messages_before_start(
new_window, [summary_message, start_message_copy], start_id=window_start_id_candidate
)
return InsertionResult(
messages=updated_messages,
updated_start_id=start_message_copy.id,
updated_window_start_id=window_start_id_candidate,
)
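
To make the three branches above concrete, a schematic of the returned InsertionResult in each case — ids are placeholders, S is the summary ContextMessage, H is the original start human message:

# Schematic only — placeholder ids, not runnable against the real helpers.
#
# Case 1: no window boundary fits the 16-message / 2048-token budget.
#   messages                -> [*previous_messages, S, copy_of_H]
#   updated_start_id        -> copy_of_H.id
#   updated_window_start_id -> S.id
#
# Case 2: the original start message H is still inside the new window.
#   messages                -> original messages with S inserted immediately before H
#   updated_start_id        -> start_id (unchanged)
#   updated_window_start_id -> the boundary candidate's id
#
# Case 3: H falls outside the new window.
#   messages                -> the new window with [S, copy_of_H] inserted before its first message
#   updated_start_id        -> copy_of_H.id
#   updated_window_start_id -> the boundary candidate's id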

def _get_estimated_tokens(self, message: AssistantMessageUnion) -> int:
"""
Estimate token count for a message using character/4 heuristic.
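
The body of _get_estimated_tokens is collapsed in this diff; per the docstring, a plausible sketch of the character/4 heuristic is below (an assumption — the real implementation may differ):

# Hypothetical sketch — the actual method body is not shown in this diff.
def estimate_tokens(content: str | list) -> int:
    text = content if isinstance(content, str) else str(content)
    return len(text) // 4


assert estimate_tokens("a" * 400) == 100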