Commit d8b4be1

fix(max): improve compaction (#39929)

1 parent 695e1d0 · commit d8b4be1

File tree: 9 files changed, +844 −27 lines

ee/hogai/graph/conversation_summarizer/nodes.py

Lines changed: 22 additions & 6 deletions
@@ -18,11 +18,7 @@ def __init__(self, team: Team, user: User):
         self._team = team
 
     async def summarize(self, messages: Sequence[BaseMessage]) -> str:
-        prompt = (
-            ChatPromptTemplate.from_messages([("system", SYSTEM_PROMPT)])
-            + messages
-            + ChatPromptTemplate.from_messages([("user", USER_PROMPT)])
-        )
+        prompt = self._construct_messages(messages)
         model = self._get_model()
         chain = prompt | model | StrOutputParser() | self._parse_xml_tags
         response: str = await chain.ainvoke({})  # Do not pass config here, so the node doesn't stream
@@ -31,6 +27,13 @@ async def summarize(self, messages: Sequence[BaseMessage]) -> str:
     @abstractmethod
     def _get_model(self): ...
 
+    def _construct_messages(self, messages: Sequence[BaseMessage]):
+        return (
+            ChatPromptTemplate.from_messages([("system", SYSTEM_PROMPT)])
+            + messages
+            + ChatPromptTemplate.from_messages([("user", USER_PROMPT)])
+        )
+
     def _parse_xml_tags(self, message: str) -> str:
         """
         Extract analysis and summary tags from a message.
@@ -54,11 +57,24 @@ def _parse_xml_tags(self, message: str) -> str:
 class AnthropicConversationSummarizer(ConversationSummarizer):
     def _get_model(self):
         return MaxChatAnthropic(
-            model="claude-sonnet-4-5",
+            model="claude-haiku-4-5",
             streaming=False,
             stream_usage=False,
             max_tokens=8192,
             disable_streaming=True,
             user=self._user,
             team=self._team,
         )
+
+    def _construct_messages(self, messages: Sequence[BaseMessage]):
+        """Removes cache_control headers."""
+        messages_without_cache: list[BaseMessage] = []
+        for message in messages:
+            if isinstance(message.content, list):
+                message = message.model_copy(deep=True)
+                for content in message.content:
+                    if isinstance(content, dict) and "cache_control" in content:
+                        content.pop("cache_control")
+            messages_without_cache.append(message)
+
+        return super()._construct_messages(messages_without_cache)
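
For orientation: the Anthropic override strips cache_control blocks (Anthropic prompt-caching markers) from message content before the transcript is re-sent to the summarizer model, presumably because cache markers from the main conversation are not meaningful in the rebuilt prompt. It works on deep copies, so the caller's messages are left intact. Below is a self-contained sketch of the same stripping step; strip_cache_control is an illustrative name, not part of this PR:

from collections.abc import Sequence

from langchain_core.messages import BaseMessage, HumanMessage


def strip_cache_control(messages: Sequence[BaseMessage]) -> list[BaseMessage]:
    """Return copies of the messages with Anthropic cache_control markers removed."""
    cleaned: list[BaseMessage] = []
    for message in messages:
        if isinstance(message.content, list):
            # Deep-copy so the caller's message objects are left untouched.
            message = message.model_copy(deep=True)
            for block in message.content:
                if isinstance(block, dict):
                    block.pop("cache_control", None)
        cleaned.append(message)
    return cleaned


msg = HumanMessage(content=[{"type": "text", "text": "Hi", "cache_control": {"type": "ephemeral"}}])
print(strip_cache_control([msg])[0].content)  # [{'type': 'text', 'text': 'Hi'}]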

ee/hogai/graph/conversation_summarizer/test/__init__.py

Whitespace-only changes.
Lines changed: 177 additions & 0 deletions
@@ -0,0 +1,177 @@
+from typing import Any, cast
+
+from posthog.test.base import BaseTest
+
+from langchain_core.messages import (
+    AIMessage as LangchainAIMessage,
+    HumanMessage as LangchainHumanMessage,
+)
+from parameterized import parameterized
+
+from ee.hogai.graph.conversation_summarizer.nodes import AnthropicConversationSummarizer
+
+
+class TestAnthropicConversationSummarizer(BaseTest):
+    def setUp(self):
+        super().setUp()
+        self.summarizer = AnthropicConversationSummarizer(team=self.team, user=self.user)
+
+    @parameterized.expand(
+        [
+            (
+                "single_message_with_cache_control",
+                [
+                    LangchainHumanMessage(
+                        content=[
+                            {"type": "text", "text": "Hello", "cache_control": {"type": "ephemeral"}},
+                        ]
+                    )
+                ],
+                [[{"type": "text", "text": "Hello"}]],
+            ),
+            (
+                "multiple_items_with_cache_control",
+                [
+                    LangchainAIMessage(
+                        content=[
+                            {"type": "text", "text": "First", "cache_control": {"type": "ephemeral"}},
+                            {"type": "text", "text": "Second", "cache_control": {"type": "ephemeral"}},
+                        ]
+                    )
+                ],
+                [[{"type": "text", "text": "First"}, {"type": "text", "text": "Second"}]],
+            ),
+            (
+                "mixed_items_some_with_cache_control",
+                [
+                    LangchainHumanMessage(
+                        content=[
+                            {"type": "text", "text": "With cache", "cache_control": {"type": "ephemeral"}},
+                            {"type": "text", "text": "Without cache"},
+                        ]
+                    )
+                ],
+                [[{"type": "text", "text": "With cache"}, {"type": "text", "text": "Without cache"}]],
+            ),
+            (
+                "multiple_messages_with_cache_control",
+                [
+                    LangchainHumanMessage(
+                        content=[
+                            {"type": "text", "text": "Message 1", "cache_control": {"type": "ephemeral"}},
+                        ]
+                    ),
+                    LangchainAIMessage(
+                        content=[
+                            {"type": "text", "text": "Message 2", "cache_control": {"type": "ephemeral"}},
+                        ]
+                    ),
+                ],
+                [
+                    [{"type": "text", "text": "Message 1"}],
+                    [{"type": "text", "text": "Message 2"}],
+                ],
+            ),
+        ]
+    )
+    def test_removes_cache_control(self, name, input_messages, expected_contents):
+        result = self.summarizer._construct_messages(input_messages)
+
+        # Extract the actual messages from the prompt template
+        messages = result.messages[1:-1]  # Skip system prompt and user prompt
+
+        self.assertEqual(len(messages), len(expected_contents), f"Wrong number of messages in test case: {name}")
+
+        for i, (message, expected_content) in enumerate(zip(messages, expected_contents)):
+            self.assertEqual(
+                message.content,
+                expected_content,
+                f"Message {i} content mismatch in test case: {name}",
+            )
+
+    @parameterized.expand(
+        [
+            (
+                "string_content",
+                [LangchainHumanMessage(content="Simple string")],
+            ),
+            (
+                "empty_list_content",
+                [LangchainHumanMessage(content=[])],
+            ),
+            (
+                "non_dict_items_in_list",
+                [LangchainHumanMessage(content=["string_item", {"type": "text", "text": "dict_item"}])],
+            ),
+        ]
+    )
+    def test_handles_non_dict_content_without_errors(self, name, input_messages):
+        result = self.summarizer._construct_messages(input_messages)
+        self.assertIsNotNone(result)
+
+    def test_original_message_not_modified(self):
+        original_content: list[str | dict[Any, Any]] = [
+            {"type": "text", "text": "Hello", "cache_control": {"type": "ephemeral"}},
+        ]
+        message = LangchainHumanMessage(content=original_content)
+
+        # Store the original cache_control to verify it's not modified
+        content_list = cast(list[dict[str, Any]], message.content)
+        self.assertIn("cache_control", content_list[0])
+
+        self.summarizer._construct_messages([message])
+
+        # Verify original message still has cache_control
+        content_list = cast(list[dict[str, Any]], message.content)
+        self.assertIn("cache_control", content_list[0])
+        self.assertEqual(content_list[0]["cache_control"], {"type": "ephemeral"})
+
+    def test_deep_copy_prevents_modification(self):
+        original_content: list[str | dict[Any, Any]] = [
+            {
+                "type": "text",
+                "text": "Test",
+                "cache_control": {"type": "ephemeral"},
+                "other_key": "value",
+            },
+        ]
+        message = LangchainHumanMessage(content=original_content)
+
+        content_list = cast(list[dict[str, Any]], message.content)
+        original_keys = set(content_list[0].keys())
+
+        self.summarizer._construct_messages([message])
+
+        # Verify original message structure unchanged
+        content_list = cast(list[dict[str, Any]], message.content)
+        self.assertEqual(set(content_list[0].keys()), original_keys)
+        self.assertIn("cache_control", content_list[0])
+
+    def test_preserves_other_content_properties(self):
+        input_messages = [
+            LangchainHumanMessage(
+                content=[
+                    {
+                        "type": "text",
+                        "text": "Hello",
+                        "cache_control": {"type": "ephemeral"},
+                        "custom_field": "custom_value",
+                        "another_field": 123,
+                    },
+                ]
+            )
+        ]
+
+        result = self.summarizer._construct_messages(input_messages)
+        messages = result.messages[1:-1]
+
+        # Verify other fields are preserved
+        content = messages[0].content[0]
+        self.assertEqual(content["custom_field"], "custom_value")
+        self.assertEqual(content["another_field"], 123)
+        self.assertNotIn("cache_control", content)
+
+    def test_empty_messages_list(self):
+        result = self.summarizer._construct_messages([])
+        # Should return prompt template with just system and user prompts
+        self.assertEqual(len(result.messages), 2)

ee/hogai/graph/root/compaction_manager.py

Lines changed: 63 additions & 10 deletions
@@ -2,6 +2,7 @@
 from abc import ABC, abstractmethod
 from collections.abc import Callable, Sequence
 from typing import TYPE_CHECKING, Any, TypeVar
+from uuid import uuid4
 
 from langchain_anthropic import ChatAnthropic
 from langchain_core.language_models import BaseChatModel
@@ -10,11 +11,13 @@
     HumanMessage as LangchainHumanMessage,
 )
 from langchain_core.tools import BaseTool
+from pydantic import BaseModel
 
-from posthog.schema import AssistantMessage, AssistantToolCallMessage, HumanMessage
+from posthog.schema import AssistantMessage, AssistantToolCallMessage, ContextMessage, HumanMessage
 
 from posthog.sync import database_sync_to_async
 
+from ee.hogai.utils.helpers import find_start_message, find_start_message_idx, insert_messages_before_start
 from ee.hogai.utils.types import AssistantMessageUnion
 
 if TYPE_CHECKING:
@@ -25,6 +28,12 @@
 LangchainTools = Sequence[dict[str, Any] | type | Callable | BaseTool]
 
 
+class InsertionResult(BaseModel):
+    messages: Sequence[AssistantMessageUnion]
+    updated_start_id: str
+    updated_window_start_id: str
+
+
 class ConversationCompactionManager(ABC):
     """
     Manages conversation window boundaries, message filtering, and summarization decisions.
@@ -39,27 +48,27 @@ class ConversationCompactionManager(ABC):
     Determines the approximate number of characters per token.
     """
 
-    def find_window_boundary(
-        self, messages: list[AssistantMessageUnion], max_messages: int = 10, max_tokens: int = 1000
-    ) -> str:
+    def find_window_boundary(self, messages: Sequence[T], max_messages: int = 10, max_tokens: int = 1000) -> str | None:
         """
         Find the optimal window start ID based on message count and token limits.
         Ensures the window starts at a human or assistant message.
         """
 
-        new_window_id: str = str(messages[-1].id)
+        new_window_id: str | None = None
         for message in reversed(messages):
+            # Handle limits before assigning the window ID.
+            max_tokens -= self._get_estimated_tokens(message)
+            max_messages -= 1
+            if max_tokens < 0 or max_messages < 0:
+                break
+
+            # Assign the new window ID.
             if message.id is not None:
                 if isinstance(message, HumanMessage):
                     new_window_id = message.id
                 if isinstance(message, AssistantMessage):
                     new_window_id = message.id
 
-            max_messages -= 1
-            max_tokens -= self._get_estimated_tokens(message)
-            if max_messages <= 0 or max_tokens <= 0:
-                break
-
         return new_window_id
 
     def get_messages_in_window(self, messages: Sequence[T], window_start_id: str | None = None) -> Sequence[T]:
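
The reordered loop is the behavioral fix here: the budget is now spent before a message may become the boundary, so a message that overflows the budget can no longer start the window, and the function returns None (rather than defaulting to the last message's id) when even the newest message is too large. A self-contained toy model of the scan, with (id, role, text) tuples standing in for PostHog's message classes:

def find_window_boundary(messages, max_messages=10, max_tokens=1000):
    new_window_id = None
    for msg_id, role, text in reversed(messages):
        # Consume the budget first: a message that blows the budget can't start the window.
        max_tokens -= len(text) // 4  # character/4 token heuristic
        max_messages -= 1
        if max_tokens < 0 or max_messages < 0:
            break
        if msg_id is not None and role in ("human", "assistant"):
            new_window_id = msg_id
    return new_window_id  # None if not even the newest message fits


history = [("a", "human", "x" * 8000), ("b", "assistant", "short answer")]
print(find_window_boundary(history))                 # "b" — only the short tail fits
print(find_window_boundary(history, max_tokens=1))   # None — nothing fits the budget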
@@ -84,6 +93,50 @@ async def should_compact_conversation(
         token_count = await self._get_token_count(model, messages, tools, **kwargs)
         return token_count > self.CONVERSATION_WINDOW_SIZE
 
+    def update_window(
+        self, messages: Sequence[T], summary_message: ContextMessage, start_id: str | None = None
+    ) -> InsertionResult:
+        """Finds the optimal position to insert the summary message in the conversation window."""
+        window_start_id_candidate = self.find_window_boundary(messages, max_messages=16, max_tokens=2048)
+        start_message = find_start_message(messages, start_id)
+        if not start_message:
+            raise ValueError("Start message not found")
+
+        start_message_copy = start_message.model_copy(deep=True)
+        start_message_copy.id = str(uuid4())
+
+        # The last messages were too large to fit into the window. Copy the last human message to the start of the window.
+        if not window_start_id_candidate:
+            return InsertionResult(
+                messages=[*messages, summary_message, start_message_copy],
+                updated_start_id=start_message_copy.id,
+                updated_window_start_id=summary_message.id,
+            )
+
+        # Find the updated window
+        start_message_idx = find_start_message_idx(messages, window_start_id_candidate)
+        new_window = messages[start_message_idx:]
+
+        # If the start human message is in the window, insert the summary message before it
+        # and update the window start.
+        if next((m for m in new_window if m.id == start_id), None):
+            updated_messages = insert_messages_before_start(messages, [summary_message], start_id=start_id)
+            return InsertionResult(
+                messages=updated_messages,
+                updated_start_id=start_id,
+                updated_window_start_id=window_start_id_candidate,
+            )
+
+        # If the start message is not in the window, insert the summary message and human message at the start of the window.
+        updated_messages = insert_messages_before_start(
+            new_window, [summary_message, start_message_copy], start_id=window_start_id_candidate
+        )
+        return InsertionResult(
+            messages=updated_messages,
+            updated_start_id=start_message_copy.id,
+            updated_window_start_id=window_start_id_candidate,
+        )
+
     def _get_estimated_tokens(self, message: AssistantMessageUnion) -> int:
         """
         Estimate token count for a message using character/4 heuristic.
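
In short, update_window has three outcomes: if no boundary candidate fits the 16-message/2048-token budget, the summary and a re-identified copy of the start message are appended after the full history; if the original start message still falls inside the new window, the summary is inserted just before it; otherwise both the summary and the copied start message are inserted at the new window boundary. The trailing context also names the character/4 heuristic that the boundary scan consumes; a minimal sketch of that estimate, assuming plain-string input (the real _get_estimated_tokens takes a message object):

def estimate_tokens(text: str, chars_per_token: int = 4) -> int:
    # Rough approximation: about four characters of English text per token.
    return len(text) // chars_per_token


print(estimate_tokens("Show me weekly active users for the last 90 days"))  # 12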
