
Commit fcf94e1

fix: remove cache control headers

Parent: 6f30cf5
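Context: the summarizer re-sends prior conversation messages whose content blocks may carry Anthropic prompt-caching markers (a cache_control key on individual blocks, as in the test fixtures below). This commit strips those markers before building the summarization prompt, and the same diff also switches the model from claude-sonnet-4-5 to claude-haiku-4-5. A minimal sketch of the message shape in question, mirroring the fixtures in the new tests:

    from langchain_core.messages import HumanMessage

    # A content block annotated for Anthropic prompt caching; the summarizer
    # must drop the "cache_control" key before reusing the message in its own prompt.
    message = HumanMessage(
        content=[{"type": "text", "text": "Hello", "cache_control": {"type": "ephemeral"}}]
    )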

File tree

3 files changed: +199, -6 lines

ee/hogai/graph/conversation_summarizer/nodes.py (22 additions, 6 deletions)
@@ -18,11 +18,7 @@ def __init__(self, team: Team, user: User):
         self._team = team
 
     async def summarize(self, messages: Sequence[BaseMessage]) -> str:
-        prompt = (
-            ChatPromptTemplate.from_messages([("system", SYSTEM_PROMPT)])
-            + messages
-            + ChatPromptTemplate.from_messages([("user", USER_PROMPT)])
-        )
+        prompt = self._construct_messages(messages)
         model = self._get_model()
         chain = prompt | model | StrOutputParser() | self._parse_xml_tags
         response: str = await chain.ainvoke({})  # Do not pass config here, so the node doesn't stream
@@ -31,6 +27,13 @@ async def summarize(self, messages: Sequence[BaseMessage]) -> str:
     @abstractmethod
     def _get_model(self): ...
 
+    def _construct_messages(self, messages: Sequence[BaseMessage]):
+        return (
+            ChatPromptTemplate.from_messages([("system", SYSTEM_PROMPT)])
+            + messages
+            + ChatPromptTemplate.from_messages([("user", USER_PROMPT)])
+        )
+
     def _parse_xml_tags(self, message: str) -> str:
         """
         Extract analysis and summary tags from a message.
@@ -54,11 +57,24 @@ def _parse_xml_tags(self, message: str) -> str:
 class AnthropicConversationSummarizer(ConversationSummarizer):
     def _get_model(self):
         return MaxChatAnthropic(
-            model="claude-sonnet-4-5",
+            model="claude-haiku-4-5",
             streaming=False,
             stream_usage=False,
             max_tokens=8192,
             disable_streaming=True,
             user=self._user,
             team=self._team,
         )
+
+    def _construct_messages(self, messages: Sequence[BaseMessage]):
+        """Removes cache_control headers."""
+        messages_without_cache: list[BaseMessage] = []
+        for message in messages:
+            if isinstance(message.content, list):
+                message = message.model_copy(deep=True)
+                for content in message.content:
+                    if isinstance(content, dict) and "cache_control" in content:
+                        content.pop("cache_control")
+            messages_without_cache.append(message)
+
+        return super()._construct_messages(messages_without_cache)
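Note on the override: a message is deep-copied only when its content is a list of content blocks, so plain string content passes through untouched and the caller's message objects are never mutated. A standalone sketch of the same stripping behavior, using a hypothetical strip_cache_control helper rather than the class above:

    from collections.abc import Sequence

    from langchain_core.messages import BaseMessage, HumanMessage

    def strip_cache_control(messages: Sequence[BaseMessage]) -> list[BaseMessage]:
        # Hypothetical helper mirroring the override above: deep-copy only
        # messages whose content is a block list, then drop the cache_control
        # marker on the copy, leaving the originals untouched.
        cleaned: list[BaseMessage] = []
        for message in messages:
            if isinstance(message.content, list):
                message = message.model_copy(deep=True)
                for block in message.content:
                    if isinstance(block, dict) and "cache_control" in block:
                        block.pop("cache_control")
            cleaned.append(message)
        return cleaned

    original = HumanMessage(
        content=[{"type": "text", "text": "Hi", "cache_control": {"type": "ephemeral"}}]
    )
    stripped = strip_cache_control([original])
    assert "cache_control" in original.content[0]  # original left intact
    assert "cache_control" not in stripped[0].content[0]  # copy cleaned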

ee/hogai/graph/conversation_summarizer/test/__init__.py

Whitespace-only changes.

New test file under ee/hogai/graph/conversation_summarizer/test/ (177 additions, 0 deletions)
@@ -0,0 +1,177 @@
+from typing import Any, cast
+
+from posthog.test.base import BaseTest
+
+from langchain_core.messages import (
+    AIMessage as LangchainAIMessage,
+    HumanMessage as LangchainHumanMessage,
+)
+from parameterized import parameterized
+
+from ee.hogai.graph.conversation_summarizer.nodes import AnthropicConversationSummarizer
+
+
+class TestAnthropicConversationSummarizer(BaseTest):
+    def setUp(self):
+        super().setUp()
+        self.summarizer = AnthropicConversationSummarizer(team=self.team, user=self.user)
+
+    @parameterized.expand(
+        [
+            (
+                "single_message_with_cache_control",
+                [
+                    LangchainHumanMessage(
+                        content=[
+                            {"type": "text", "text": "Hello", "cache_control": {"type": "ephemeral"}},
+                        ]
+                    )
+                ],
+                [[{"type": "text", "text": "Hello"}]],
+            ),
+            (
+                "multiple_items_with_cache_control",
+                [
+                    LangchainAIMessage(
+                        content=[
+                            {"type": "text", "text": "First", "cache_control": {"type": "ephemeral"}},
+                            {"type": "text", "text": "Second", "cache_control": {"type": "ephemeral"}},
+                        ]
+                    )
+                ],
+                [[{"type": "text", "text": "First"}, {"type": "text", "text": "Second"}]],
+            ),
+            (
+                "mixed_items_some_with_cache_control",
+                [
+                    LangchainHumanMessage(
+                        content=[
+                            {"type": "text", "text": "With cache", "cache_control": {"type": "ephemeral"}},
+                            {"type": "text", "text": "Without cache"},
+                        ]
+                    )
+                ],
+                [[{"type": "text", "text": "With cache"}, {"type": "text", "text": "Without cache"}]],
+            ),
+            (
+                "multiple_messages_with_cache_control",
+                [
+                    LangchainHumanMessage(
+                        content=[
+                            {"type": "text", "text": "Message 1", "cache_control": {"type": "ephemeral"}},
+                        ]
+                    ),
+                    LangchainAIMessage(
+                        content=[
+                            {"type": "text", "text": "Message 2", "cache_control": {"type": "ephemeral"}},
+                        ]
+                    ),
+                ],
+                [
+                    [{"type": "text", "text": "Message 1"}],
+                    [{"type": "text", "text": "Message 2"}],
+                ],
+            ),
+        ]
+    )
+    def test_removes_cache_control(self, name, input_messages, expected_contents):
+        result = self.summarizer._construct_messages(input_messages)
+
+        # Extract the actual messages from the prompt template
+        messages = result.messages[1:-1]  # Skip system prompt and user prompt
+
+        self.assertEqual(len(messages), len(expected_contents), f"Wrong number of messages in test case: {name}")
+
+        for i, (message, expected_content) in enumerate(zip(messages, expected_contents)):
+            self.assertEqual(
+                message.content,
+                expected_content,
+                f"Message {i} content mismatch in test case: {name}",
+            )
+
+    @parameterized.expand(
+        [
+            (
+                "string_content",
+                [LangchainHumanMessage(content="Simple string")],
+            ),
+            (
+                "empty_list_content",
+                [LangchainHumanMessage(content=[])],
+            ),
+            (
+                "non_dict_items_in_list",
+                [LangchainHumanMessage(content=["string_item", {"type": "text", "text": "dict_item"}])],
+            ),
+        ]
+    )
+    def test_handles_non_dict_content_without_errors(self, name, input_messages):
+        result = self.summarizer._construct_messages(input_messages)
+        self.assertIsNotNone(result)
+
+    def test_original_message_not_modified(self):
+        original_content: list[str | dict[Any, Any]] = [
+            {"type": "text", "text": "Hello", "cache_control": {"type": "ephemeral"}},
+        ]
+        message = LangchainHumanMessage(content=original_content)
+
+        # Store the original cache_control to verify it's not modified
+        content_list = cast(list[dict[str, Any]], message.content)
+        self.assertIn("cache_control", content_list[0])
+
+        self.summarizer._construct_messages([message])
+
+        # Verify original message still has cache_control
+        content_list = cast(list[dict[str, Any]], message.content)
+        self.assertIn("cache_control", content_list[0])
+        self.assertEqual(content_list[0]["cache_control"], {"type": "ephemeral"})
+
+    def test_deep_copy_prevents_modification(self):
+        original_content: list[str | dict[Any, Any]] = [
+            {
+                "type": "text",
+                "text": "Test",
+                "cache_control": {"type": "ephemeral"},
+                "other_key": "value",
+            },
+        ]
+        message = LangchainHumanMessage(content=original_content)
+
+        content_list = cast(list[dict[str, Any]], message.content)
+        original_keys = set(content_list[0].keys())
+
+        self.summarizer._construct_messages([message])
+
+        # Verify original message structure unchanged
+        content_list = cast(list[dict[str, Any]], message.content)
+        self.assertEqual(set(content_list[0].keys()), original_keys)
+        self.assertIn("cache_control", content_list[0])
+
+    def test_preserves_other_content_properties(self):
+        input_messages = [
+            LangchainHumanMessage(
+                content=[
+                    {
+                        "type": "text",
+                        "text": "Hello",
+                        "cache_control": {"type": "ephemeral"},
+                        "custom_field": "custom_value",
+                        "another_field": 123,
+                    },
+                ]
+            )
+        ]
+
+        result = self.summarizer._construct_messages(input_messages)
+        messages = result.messages[1:-1]
+
+        # Verify other fields are preserved
+        content = messages[0].content[0]
+        self.assertEqual(content["custom_field"], "custom_value")
+        self.assertEqual(content["another_field"], 123)
+        self.assertNotIn("cache_control", content)
+
+    def test_empty_messages_list(self):
+        result = self.summarizer._construct_messages([])
+        # Should return prompt template with just system and user prompts
+        self.assertEqual(len(result.messages), 2)
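The slice result.messages[1:-1] in these tests reflects how the base _construct_messages sandwiches the conversation between the system prompt and a trailing user prompt: adding ChatPromptTemplate objects and message lists with + concatenates them into a single template, which is also why test_empty_messages_list expects exactly two messages. A minimal sketch of that concatenation, with placeholder prompt strings standing in for SYSTEM_PROMPT and USER_PROMPT:

    from langchain_core.messages import HumanMessage
    from langchain_core.prompts import ChatPromptTemplate

    # Placeholder prompts; the real SYSTEM_PROMPT and USER_PROMPT live in the
    # summarizer module.
    prompt = (
        ChatPromptTemplate.from_messages([("system", "You summarize conversations.")])
        + [HumanMessage(content="Hello")]
        + ChatPromptTemplate.from_messages([("user", "Summarize the conversation above.")])
    )
    assert len(prompt.messages) == 3  # system prompt, conversation, trailing user prompt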
