Skip to content

Commit ce537c4

Browse files
committed
checkpoint
1 parent 159f7d4 commit ce537c4

File tree

10 files changed

+661
-186
lines changed

10 files changed

+661
-186
lines changed

.env.example

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ PORT=8000
66

77
# Memory settings
88
LONG_TERM_MEMORY=true
9-
MAX_WINDOW_SIZE=12
9+
WINDOW_SIZE=12
1010
GENERATION_MODEL=gpt-4o-mini
1111
EMBEDDING_MODEL=text-embedding-3-small
1212

agent_memory_server/api.py

Lines changed: 129 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
from typing import Literal
2-
31
from fastapi import APIRouter, Depends, HTTPException
2+
from mcp.server.fastmcp.prompts import base
3+
from mcp.types import TextContent
44

55
from agent_memory_server import long_term_memory, messages
66
from agent_memory_server.config import settings
@@ -9,48 +9,44 @@
99
from agent_memory_server.logging import get_logger
1010
from agent_memory_server.models import (
1111
AckResponse,
12-
CreateLongTermMemoryPayload,
12+
CreateLongTermMemoryRequest,
1313
GetSessionsQuery,
1414
LongTermMemoryResultsResponse,
15-
SearchPayload,
15+
MemoryPromptRequest,
16+
MemoryPromptResponse,
17+
ModelNameLiteral,
18+
SearchRequest,
1619
SessionListResponse,
1720
SessionMemory,
1821
SessionMemoryResponse,
22+
SystemMessage,
1923
)
2024
from agent_memory_server.utils.redis import get_redis_conn
2125

2226

2327
logger = get_logger(__name__)
2428

25-
ModelNameLiteral = Literal[
26-
"gpt-3.5-turbo",
27-
"gpt-3.5-turbo-16k",
28-
"gpt-4",
29-
"gpt-4-32k",
30-
"gpt-4o",
31-
"gpt-4o-mini",
32-
"o1",
33-
"o1-mini",
34-
"o3-mini",
35-
"text-embedding-ada-002",
36-
"text-embedding-3-small",
37-
"text-embedding-3-large",
38-
"claude-3-opus-20240229",
39-
"claude-3-sonnet-20240229",
40-
"claude-3-haiku-20240307",
41-
"claude-3-5-sonnet-20240620",
42-
"claude-3-7-sonnet-20250219",
43-
"claude-3-5-sonnet-20241022",
44-
"claude-3-5-haiku-20241022",
45-
"claude-3-7-sonnet-latest",
46-
"claude-3-5-sonnet-latest",
47-
"claude-3-5-haiku-latest",
48-
"claude-3-opus-latest",
49-
]
50-
5129
router = APIRouter()
5230

5331

32+
def _get_effective_window_size(
33+
window_size: int,
34+
context_window_max: int | None,
35+
model_name: ModelNameLiteral | None,
36+
) -> int:
37+
# If context_window_max is explicitly provided, use that
38+
if context_window_max is not None:
39+
effective_window_size = min(window_size, context_window_max)
40+
# If model_name is provided, get its max_tokens from our config
41+
elif model_name is not None:
42+
model_config = get_model_config(model_name)
43+
effective_window_size = min(window_size, model_config.max_tokens)
44+
# Otherwise use the default window_size
45+
else:
46+
effective_window_size = window_size
47+
return effective_window_size
48+
49+
5450
@router.get("/sessions/", response_model=SessionListResponse)
5551
async def list_sessions(
5652
options: GetSessionsQuery = Depends(),
@@ -103,17 +99,11 @@ async def get_session_memory(
10399
Conversation history and context
104100
"""
105101
redis = await get_redis_conn()
106-
107-
# If context_window_max is explicitly provided, use that
108-
if context_window_max is not None:
109-
effective_window_size = min(window_size, context_window_max)
110-
# If model_name is provided, get its max_tokens from our config
111-
elif model_name is not None:
112-
model_config = get_model_config(model_name)
113-
effective_window_size = min(window_size, model_config.max_tokens)
114-
# Otherwise use the default window_size
115-
else:
116-
effective_window_size = window_size
102+
effective_window_size = _get_effective_window_size(
103+
window_size=window_size,
104+
context_window_max=context_window_max,
105+
model_name=model_name,
106+
)
117107

118108
session = await messages.get_session_memory(
119109
redis=redis,
@@ -181,7 +171,7 @@ async def delete_session_memory(
181171

182172
@router.post("/long-term-memory", response_model=AckResponse)
183173
async def create_long_term_memory(
184-
payload: CreateLongTermMemoryPayload,
174+
payload: CreateLongTermMemoryRequest,
185175
background_tasks=Depends(get_background_tasks),
186176
):
187177
"""
@@ -205,7 +195,7 @@ async def create_long_term_memory(
205195

206196

207197
@router.post("/long-term-memory/search", response_model=LongTermMemoryResultsResponse)
208-
async def search_long_term_memory(payload: SearchPayload):
198+
async def search_long_term_memory(payload: SearchRequest):
209199
"""
210200
Run a semantic search on long-term memory with filtering options.
211201
@@ -215,11 +205,11 @@ async def search_long_term_memory(payload: SearchPayload):
215205
Returns:
216206
List of search results
217207
"""
218-
redis = await get_redis_conn()
219-
220208
if not settings.long_term_memory:
221209
raise HTTPException(status_code=400, detail="Long-term memory is disabled")
222210

211+
redis = await get_redis_conn()
212+
223213
# Extract filter objects from the payload
224214
filters = payload.get_filters()
225215

@@ -236,3 +226,97 @@ async def search_long_term_memory(payload: SearchPayload):
236226

237227
# Pass text, redis, and filter objects to the search function
238228
return await long_term_memory.search_long_term_memories(**kwargs)
229+
230+
231+
@router.post("/memory-prompt", response_model=MemoryPromptResponse)
232+
async def memory_prompt(params: MemoryPromptRequest) -> MemoryPromptResponse:
233+
"""
234+
Hydrate a user query with memory context and return a prompt
235+
ready to send to an LLM.
236+
237+
`query` is the input text that the caller of this API wants to use to find
238+
relevant context. If `session_id` is provided and matches an existing
239+
session, the resulting prompt will include those messages as the immediate
240+
history of messages leading to a message containing `query`.
241+
242+
If `long_term_search` is provided, the resulting prompt will include
243+
relevant long-term memories found via semantic search with the options
244+
provided in the payload.
245+
246+
Args:
247+
params: MemoryPromptRequest
248+
249+
Returns:
250+
List of messages to send to an LLM, hydrated with relevant memory context
251+
"""
252+
if not params.session and not params.long_term_search:
253+
raise HTTPException(
254+
status_code=400,
255+
detail="Either session or long_term_search must be provided",
256+
)
257+
258+
redis = await get_redis_conn()
259+
_messages = []
260+
261+
if params.session:
262+
effective_window_size = _get_effective_window_size(
263+
window_size=params.session.window_size,
264+
context_window_max=params.session.context_window_max,
265+
model_name=params.session.model_name,
266+
)
267+
session_memory = await messages.get_session_memory(
268+
redis=redis,
269+
session_id=params.session.session_id,
270+
window_size=effective_window_size,
271+
namespace=params.session.namespace,
272+
)
273+
274+
if session_memory:
275+
if session_memory.context:
276+
# TODO: Weird to use MCP types here?
277+
_messages.append(
278+
SystemMessage(
279+
content=TextContent(
280+
type="text",
281+
text=f"## A summary of the conversation so far\n{session_memory.context}",
282+
),
283+
)
284+
)
285+
# Ignore past system messages as the latest context may have changed
286+
for msg in session_memory.messages:
287+
if msg.role == "user":
288+
msg_class = base.UserMessage
289+
else:
290+
msg_class = base.AssistantMessage
291+
_messages.append(
292+
msg_class(
293+
content=TextContent(type="text", text=msg.content),
294+
)
295+
)
296+
297+
if params.long_term_search:
298+
# TODO: Exclude session messages if we already included them from session memory
299+
long_term_memories = await search_long_term_memory(
300+
params.long_term_search,
301+
)
302+
303+
if long_term_memories.total > 0:
304+
long_term_memories_text = "\n".join(
305+
[f"- {m.text}" for m in long_term_memories.memories]
306+
)
307+
_messages.append(
308+
SystemMessage(
309+
content=TextContent(
310+
type="text",
311+
text=f"## Long term memories related to the user's query\n {long_term_memories_text}",
312+
),
313+
)
314+
)
315+
316+
_messages.append(
317+
base.UserMessage(
318+
content=TextContent(type="text", text=params.query),
319+
)
320+
)
321+
322+
return MemoryPromptResponse(messages=_messages)

agent_memory_server/client/api.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,11 @@
2121
)
2222
from agent_memory_server.models import (
2323
AckResponse,
24-
CreateLongTermMemoryPayload,
24+
CreateLongTermMemoryRequest,
2525
HealthCheckResponse,
2626
LongTermMemory,
2727
LongTermMemoryResults,
28-
SearchPayload,
28+
SearchRequest,
2929
SessionListResponse,
3030
SessionMemory,
3131
SessionMemoryResponse,
@@ -256,7 +256,7 @@ async def create_long_term_memory(
256256
if memory.namespace is None:
257257
memory.namespace = self.config.default_namespace
258258

259-
payload = CreateLongTermMemoryPayload(memories=memories)
259+
payload = CreateLongTermMemoryRequest(memories=memories)
260260
response = await self._client.post(
261261
"/long-term-memory", json=payload.model_dump(exclude_none=True)
262262
)
@@ -322,7 +322,7 @@ async def search_long_term_memory(
322322
if namespace is None and self.config.default_namespace is not None:
323323
namespace = Namespace(eq=self.config.default_namespace)
324324

325-
payload = SearchPayload(
325+
payload = SearchRequest(
326326
text=text,
327327
session_id=session_id,
328328
namespace=namespace,

agent_memory_server/config.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,28 @@
11
import os
22
from typing import Literal
33

4+
import yaml
45
from dotenv import load_dotenv
56
from pydantic_settings import BaseSettings
67

78

89
load_dotenv()
910

1011

12+
def load_yaml_settings():
13+
config_path = os.getenv("APP_CONFIG_FILE", "config.yaml")
14+
if os.path.exists(config_path):
15+
with open(config_path) as f:
16+
return yaml.safe_load(f) or {}
17+
return {}
18+
19+
1120
class Settings(BaseSettings):
1221
redis_url: str = "redis://localhost:6379"
1322
long_term_memory: bool = True
1423
window_size: int = 20
15-
openai_api_key: str = os.getenv("OPENAI_API_KEY", "")
16-
anthropic_api_key: str = os.getenv("ANTHROPIC_API_KEY", "")
24+
openai_api_key: str | None = None
25+
anthropic_api_key: str | None = None
1726
generation_model: str = "gpt-4o-mini"
1827
embedding_model: str = "text-embedding-3-small"
1928
port: int = 8000
@@ -40,5 +49,11 @@ class Settings(BaseSettings):
4049
# Other Application settings
4150
log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO"
4251

52+
class Config:
53+
env_file = ".env"
54+
env_file_encoding = "utf-8"
55+
4356

44-
settings = Settings()
57+
# Load YAML config first, then let env vars override
58+
yaml_settings = load_yaml_settings()
59+
settings = Settings(**yaml_settings)

0 commit comments

Comments
 (0)