 from abc import ABC, abstractmethod
 from collections.abc import Callable, Sequence
 from typing import TYPE_CHECKING, Any, TypeVar
+from uuid import uuid4
 
 from langchain_anthropic import ChatAnthropic
 from langchain_core.language_models import BaseChatModel
     HumanMessage as LangchainHumanMessage,
 )
 from langchain_core.tools import BaseTool
+from pydantic import BaseModel
 
-from posthog.schema import AssistantMessage, AssistantToolCallMessage, HumanMessage
+from posthog.schema import AssistantMessage, AssistantToolCallMessage, ContextMessage, HumanMessage
 
 from posthog.sync import database_sync_to_async
 
+from ee.hogai.utils.helpers import find_start_message, find_start_message_idx, insert_messages_before_start
 from ee.hogai.utils.types import AssistantMessageUnion
 
 if TYPE_CHECKING:
 LangchainTools = Sequence[dict[str, Any] | type | Callable | BaseTool]
 
 
+class InsertionResult(BaseModel):
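+    """Outcome of inserting a summary message into the conversation.
+
+    `messages` is the updated message sequence, `updated_start_id` is the ID of the conversation's
+    start message after the update, and `updated_window_start_id` is the ID of the first message
+    in the new context window.
+    """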
+    messages: Sequence[AssistantMessageUnion]
+    updated_start_id: str
+    updated_window_start_id: str
+
+
 class ConversationCompactionManager(ABC):
     """
     Manages conversation window boundaries, message filtering, and summarization decisions.
@@ -39,27 +48,27 @@ class ConversationCompactionManager(ABC):
     Determines the approximate number of characters per token.
     """
 
-    def find_window_boundary(
-        self, messages: list[AssistantMessageUnion], max_messages: int = 10, max_tokens: int = 1000
-    ) -> str:
+    def find_window_boundary(self, messages: Sequence[T], max_messages: int = 10, max_tokens: int = 1000) -> str | None:
         """
         Find the optimal window start ID based on message count and token limits.
         Ensures the window starts at a human or assistant message.
         """
 
-        new_window_id: str = str(messages[-1].id)
+        new_window_id: str | None = None
         for message in reversed(messages):
+            # Handle limits before assigning the window ID.
+            max_tokens -= self._get_estimated_tokens(message)
+            max_messages -= 1
+            if max_tokens < 0 or max_messages < 0:
+                break
+
+            # Assign the new window ID.
             if message.id is not None:
                 if isinstance(message, HumanMessage):
                     new_window_id = message.id
                 if isinstance(message, AssistantMessage):
                     new_window_id = message.id
 
-            max_messages -= 1
-            max_tokens -= self._get_estimated_tokens(message)
-            if max_messages <= 0 or max_tokens <= 0:
-                break
-
         return new_window_id
 
     def get_messages_in_window(self, messages: Sequence[T], window_start_id: str | None = None) -> Sequence[T]:
@@ -84,6 +93,50 @@ async def should_compact_conversation(
         token_count = await self._get_token_count(model, messages, tools, **kwargs)
         return token_count > self.CONVERSATION_WINDOW_SIZE
 
+    def update_window(
+        self, messages: Sequence[T], summary_message: ContextMessage, start_id: str | None = None
+    ) -> InsertionResult:
+        """Finds the optimal position to insert the summary message in the conversation window."""
+        window_start_id_candidate = self.find_window_boundary(messages, max_messages=16, max_tokens=2048)
+        start_message = find_start_message(messages, start_id)
+        if not start_message:
+            raise ValueError("Start message not found")
+
+        start_message_copy = start_message.model_copy(deep=True)
+        start_message_copy.id = str(uuid4())
+
+        # The last messages were too large to fit into the window. Copy the last human message to the start of the window.
+        if not window_start_id_candidate:
+            return InsertionResult(
+                messages=[*messages, summary_message, start_message_copy],
+                updated_start_id=start_message_copy.id,
+                updated_window_start_id=summary_message.id,
+            )
+
+        # Find the updated window
+        start_message_idx = find_start_message_idx(messages, window_start_id_candidate)
+        new_window = messages[start_message_idx:]
+
+        # If the start human message is in the window, insert the summary message before it
+        # and update the window start.
+        if next((m for m in new_window if m.id == start_id), None):
+            updated_messages = insert_messages_before_start(messages, [summary_message], start_id=start_id)
+            return InsertionResult(
+                messages=updated_messages,
+                updated_start_id=start_id,
+                updated_window_start_id=window_start_id_candidate,
+            )
+
+        # If the start message is not in the window, insert the summary message and human message at the start of the window.
+        updated_messages = insert_messages_before_start(
+            new_window, [summary_message, start_message_copy], start_id=window_start_id_candidate
+        )
+        return InsertionResult(
+            messages=updated_messages,
+            updated_start_id=start_message_copy.id,
+            updated_window_start_id=window_start_id_candidate,
+        )
+
     def _get_estimated_tokens(self, message: AssistantMessageUnion) -> int:
         """
         Estimate token count for a message using character/4 heuristic.
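
For orientation, here is a minimal sketch of how the new update_window result might be consumed by calling code. The concrete manager subclass, the summarize_messages helper, and any ContextMessage field besides id are assumptions for illustration, not part of this diff; the argument order of should_compact_conversation is inferred from its body above.

    from uuid import uuid4

    async def compact_if_needed(manager, model, messages, tools, start_id, window_start_id):
        # Check whether the conversation has outgrown the model's context window.
        if not await manager.should_compact_conversation(model, messages, tools):
            return messages, start_id, window_start_id

        # Summary produced elsewhere; `summarize_messages` and the `content` field
        # are assumptions for this sketch.
        summary = ContextMessage(id=str(uuid4()), content=summarize_messages(messages))

        # Insert the summary and pick up the new start/window boundaries.
        result = manager.update_window(messages, summary_message=summary, start_id=start_id)
        return list(result.messages), result.updated_start_id, result.updated_window_start_id

Subsequent model calls would then read only manager.get_messages_in_window(messages, window_start_id=window_start_id), so everything before the new boundary is represented by the summary message alone.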