Skip to content

Commit 6f30cf5

Browse files
committed
fix: improved compaction
1 parent b4d71e2 commit 6f30cf5

File tree

6 files changed

+645
-21
lines changed

6 files changed

+645
-21
lines changed

ee/hogai/graph/root/compaction_manager.py

Lines changed: 63 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from abc import ABC, abstractmethod
33
from collections.abc import Callable, Sequence
44
from typing import TYPE_CHECKING, Any, TypeVar
5+
from uuid import uuid4
56

67
from langchain_anthropic import ChatAnthropic
78
from langchain_core.language_models import BaseChatModel
@@ -10,11 +11,13 @@
1011
HumanMessage as LangchainHumanMessage,
1112
)
1213
from langchain_core.tools import BaseTool
14+
from pydantic import BaseModel
1315

14-
from posthog.schema import AssistantMessage, AssistantToolCallMessage, HumanMessage
16+
from posthog.schema import AssistantMessage, AssistantToolCallMessage, ContextMessage, HumanMessage
1517

1618
from posthog.sync import database_sync_to_async
1719

20+
from ee.hogai.utils.helpers import find_start_message, find_start_message_idx, insert_messages_before_start
1821
from ee.hogai.utils.types import AssistantMessageUnion
1922

2023
if TYPE_CHECKING:
@@ -25,6 +28,12 @@
2528
LangchainTools = Sequence[dict[str, Any] | type | Callable | BaseTool]
2629

2730

31+
class InsertionResult(BaseModel):
32+
messages: Sequence[AssistantMessageUnion]
33+
updated_start_id: str
34+
updated_window_start_id: str
35+
36+
2837
class ConversationCompactionManager(ABC):
2938
"""
3039
Manages conversation window boundaries, message filtering, and summarization decisions.
@@ -39,27 +48,27 @@ class ConversationCompactionManager(ABC):
3948
Determines the approximate number of characters per token.
4049
"""
4150

42-
def find_window_boundary(
43-
self, messages: list[AssistantMessageUnion], max_messages: int = 10, max_tokens: int = 1000
44-
) -> str:
51+
def find_window_boundary(self, messages: Sequence[T], max_messages: int = 10, max_tokens: int = 1000) -> str | None:
4552
"""
4653
Find the optimal window start ID based on message count and token limits.
4754
Ensures the window starts at a human or assistant message.
4855
"""
4956

50-
new_window_id: str = str(messages[-1].id)
57+
new_window_id: str | None = None
5158
for message in reversed(messages):
59+
# Handle limits before assigning the window ID.
60+
max_tokens -= self._get_estimated_tokens(message)
61+
max_messages -= 1
62+
if max_tokens < 0 or max_messages < 0:
63+
break
64+
65+
# Assign the new window ID.
5266
if message.id is not None:
5367
if isinstance(message, HumanMessage):
5468
new_window_id = message.id
5569
if isinstance(message, AssistantMessage):
5670
new_window_id = message.id
5771

58-
max_messages -= 1
59-
max_tokens -= self._get_estimated_tokens(message)
60-
if max_messages <= 0 or max_tokens <= 0:
61-
break
62-
6372
return new_window_id
6473

6574
def get_messages_in_window(self, messages: Sequence[T], window_start_id: str | None = None) -> Sequence[T]:
@@ -84,6 +93,50 @@ async def should_compact_conversation(
8493
token_count = await self._get_token_count(model, messages, tools, **kwargs)
8594
return token_count > self.CONVERSATION_WINDOW_SIZE
8695

96+
def update_window(
97+
self, messages: Sequence[T], summary_message: ContextMessage, start_id: str | None = None
98+
) -> InsertionResult:
99+
"""Finds the optimal position to insert the summary message in the conversation window."""
100+
window_start_id_candidate = self.find_window_boundary(messages, max_messages=16, max_tokens=2048)
101+
start_message = find_start_message(messages, start_id)
102+
if not start_message:
103+
raise ValueError("Start message not found")
104+
105+
start_message_copy = start_message.model_copy(deep=True)
106+
start_message_copy.id = str(uuid4())
107+
108+
# The last messages were too large to fit into the window. Copy the last human message to the start of the window.
109+
if not window_start_id_candidate:
110+
return InsertionResult(
111+
messages=[*messages, summary_message, start_message_copy],
112+
updated_start_id=start_message_copy.id,
113+
updated_window_start_id=summary_message.id,
114+
)
115+
116+
# Find the updated window
117+
start_message_idx = find_start_message_idx(messages, window_start_id_candidate)
118+
new_window = messages[start_message_idx:]
119+
120+
# If the start human message is in the window, insert the summary message before it
121+
# and update the window start.
122+
if next((m for m in new_window if m.id == start_id), None):
123+
updated_messages = insert_messages_before_start(messages, [summary_message], start_id=start_id)
124+
return InsertionResult(
125+
messages=updated_messages,
126+
updated_start_id=start_id,
127+
updated_window_start_id=window_start_id_candidate,
128+
)
129+
130+
# If the start message is not in the window, insert the summary message and human message at the start of the window.
131+
updated_messages = insert_messages_before_start(
132+
new_window, [summary_message, start_message_copy], start_id=window_start_id_candidate
133+
)
134+
return InsertionResult(
135+
messages=updated_messages,
136+
updated_start_id=start_message_copy.id,
137+
updated_window_start_id=window_start_id_candidate,
138+
)
139+
87140
def _get_estimated_tokens(self, message: AssistantMessageUnion) -> int:
88141
"""
89142
Estimate token count for a message using character/4 heuristic.

ee/hogai/graph/root/nodes.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
from ee.hogai.llm import MaxChatAnthropic
3636
from ee.hogai.tool import CONTEXTUAL_TOOL_NAME_TO_TOOL, ToolMessagesArtifact
3737
from ee.hogai.utils.anthropic import add_cache_control, convert_to_anthropic_messages, normalize_ai_anthropic_message
38-
from ee.hogai.utils.helpers import convert_tool_messages_to_dict, insert_messages_before_start
38+
from ee.hogai.utils.helpers import convert_tool_messages_to_dict
3939
from ee.hogai.utils.prompt import format_prompt_string
4040
from ee.hogai.utils.types import (
4141
AssistantMessageUnion,
@@ -130,6 +130,7 @@ async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAs
130130
messages_to_replace or state.messages, state.root_conversation_start_id, state.root_tool_calls_count
131131
)
132132
window_id = state.root_conversation_start_id
133+
start_id = state.start_id
133134

134135
# Summarize the conversation if it's too long.
135136
if await self._window_manager.should_compact_conversation(
@@ -144,12 +145,14 @@ async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAs
144145
)
145146

146147
# Insert the summary message before the last human message
147-
messages_to_replace = insert_messages_before_start(
148-
messages_to_replace or state.messages, [summary_message], start_id=state.start_id
148+
insertion_result = self._window_manager.update_window(
149+
messages_to_replace or state.messages, summary_message, start_id=start_id
149150
)
151+
window_id = insertion_result.updated_window_start_id
152+
start_id = insertion_result.updated_start_id
153+
messages_to_replace = insertion_result.messages
150154

151-
# Update window
152-
window_id = self._window_manager.find_window_boundary(messages_to_replace)
155+
# Update the window
153156
langchain_messages = self._construct_messages(messages_to_replace, window_id, state.root_tool_calls_count)
154157

155158
system_prompts = ChatPromptTemplate.from_messages(
@@ -174,7 +177,11 @@ async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAs
174177
if messages_to_replace:
175178
new_messages = ReplaceMessages([*messages_to_replace, assistant_message])
176179

177-
return PartialAssistantState(root_conversation_start_id=window_id, messages=new_messages)
180+
return PartialAssistantState(
181+
messages=new_messages,
182+
root_conversation_start_id=window_id,
183+
start_id=start_id,
184+
)
178185

179186
async def get_reasoning_message(
180187
self, input: BaseState, default_message: Optional[str] = None

0 commit comments

Comments (0)