Index discrete memories in long-term storage, add a new memory prompt REST API endpoint, and YAML config #9

Merged: 7 commits, May 23, 2025

2 changes: 1 addition & 1 deletion .env.example
@@ -6,7 +6,7 @@ PORT=8000

# Memory settings
LONG_TERM_MEMORY=true
-MAX_WINDOW_SIZE=12
+WINDOW_SIZE=12
GENERATION_MODEL=gpt-4o-mini
EMBEDDING_MODEL=text-embedding-3-small

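For context on where this rename lands: api.py below reads configuration through the `settings` object imported from agent_memory_server.config. Here is a minimal sketch of such a config, assuming the project uses pydantic-settings; the Settings class and its field names are not shown in this PR and are assumptions mirroring the .env keys.

from pydantic_settings import BaseSettings, SettingsConfigDict

class Settings(BaseSettings):
    # Assumed: load values from .env, matching fields case-insensitively
    model_config = SettingsConfigDict(env_file=".env")

    port: int = 8000
    long_term_memory: bool = True
    window_size: int = 12  # populated from WINDOW_SIZE (formerly MAX_WINDOW_SIZE)
    generation_model: str = "gpt-4o-mini"
    embedding_model: str = "text-embedding-3-small"

settings = Settings()  # values present in .env override the defaults above
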
174 changes: 129 additions & 45 deletions agent_memory_server/api.py
@@ -1,6 +1,6 @@
-from typing import Literal

from fastapi import APIRouter, Depends, HTTPException
+from mcp.server.fastmcp.prompts import base
+from mcp.types import TextContent

from agent_memory_server import long_term_memory, messages
from agent_memory_server.config import settings
@@ -9,48 +9,44 @@
from agent_memory_server.logging import get_logger
from agent_memory_server.models import (
    AckResponse,
-    CreateLongTermMemoryPayload,
+    CreateLongTermMemoryRequest,
    GetSessionsQuery,
    LongTermMemoryResultsResponse,
-    SearchPayload,
+    MemoryPromptRequest,
+    MemoryPromptResponse,
+    ModelNameLiteral,
+    SearchRequest,
    SessionListResponse,
    SessionMemory,
    SessionMemoryResponse,
+    SystemMessage,
)
from agent_memory_server.utils.redis import get_redis_conn


logger = get_logger(__name__)

-ModelNameLiteral = Literal[
-    "gpt-3.5-turbo",
-    "gpt-3.5-turbo-16k",
-    "gpt-4",
-    "gpt-4-32k",
-    "gpt-4o",
-    "gpt-4o-mini",
-    "o1",
-    "o1-mini",
-    "o3-mini",
-    "text-embedding-ada-002",
-    "text-embedding-3-small",
-    "text-embedding-3-large",
-    "claude-3-opus-20240229",
-    "claude-3-sonnet-20240229",
-    "claude-3-haiku-20240307",
-    "claude-3-5-sonnet-20240620",
-    "claude-3-7-sonnet-20250219",
-    "claude-3-5-sonnet-20241022",
-    "claude-3-5-haiku-20241022",
-    "claude-3-7-sonnet-latest",
-    "claude-3-5-sonnet-latest",
-    "claude-3-5-haiku-latest",
-    "claude-3-opus-latest",
-]

router = APIRouter()


+def _get_effective_window_size(
+    window_size: int,
+    context_window_max: int | None,
+    model_name: ModelNameLiteral | None,
+) -> int:
+    # If context_window_max is explicitly provided, use that
+    if context_window_max is not None:
+        effective_window_size = min(window_size, context_window_max)
+    # If model_name is provided, get its max_tokens from our config
+    elif model_name is not None:
+        model_config = get_model_config(model_name)
+        effective_window_size = min(window_size, model_config.max_tokens)
+    # Otherwise use the default window_size
+    else:
+        effective_window_size = window_size
+    return effective_window_size


@router.get("/sessions/", response_model=SessionListResponse)
async def list_sessions(
    options: GetSessionsQuery = Depends(),
@@ -103,17 +99,11 @@ async def get_session_memory(
        Conversation history and context
    """
    redis = await get_redis_conn()

-    # If context_window_max is explicitly provided, use that
-    if context_window_max is not None:
-        effective_window_size = min(window_size, context_window_max)
-    # If model_name is provided, get its max_tokens from our config
-    elif model_name is not None:
-        model_config = get_model_config(model_name)
-        effective_window_size = min(window_size, model_config.max_tokens)
-    # Otherwise use the default window_size
-    else:
-        effective_window_size = window_size
+    effective_window_size = _get_effective_window_size(
+        window_size=window_size,
+        context_window_max=context_window_max,
+        model_name=model_name,
+    )

    session = await messages.get_session_memory(
        redis=redis,
@@ -181,7 +171,7 @@ async def delete_session_memory(

@router.post("/long-term-memory", response_model=AckResponse)
async def create_long_term_memory(
-    payload: CreateLongTermMemoryPayload,
+    payload: CreateLongTermMemoryRequest,
    background_tasks=Depends(get_background_tasks),
):
    """
@@ -205,7 +195,7 @@ async def create_long_term_memory(


@router.post("/long-term-memory/search", response_model=LongTermMemoryResultsResponse)
-async def search_long_term_memory(payload: SearchPayload):
+async def search_long_term_memory(payload: SearchRequest):
    """
    Run a semantic search on long-term memory with filtering options.

@@ -215,11 +205,11 @@ async def search_long_term_memory(payload: SearchPayload):
    Returns:
        List of search results
    """
-    redis = await get_redis_conn()
-
    if not settings.long_term_memory:
        raise HTTPException(status_code=400, detail="Long-term memory is disabled")

+    redis = await get_redis_conn()
+
    # Extract filter objects from the payload
    filters = payload.get_filters()

@@ -236,3 +226,97 @@

    # Pass text, redis, and filter objects to the search function
    return await long_term_memory.search_long_term_memories(**kwargs)


+@router.post("/memory-prompt", response_model=MemoryPromptResponse)
+async def memory_prompt(params: MemoryPromptRequest) -> MemoryPromptResponse:
+    """
+    Hydrate a user query with memory context and return a prompt
+    ready to send to an LLM.
+
+    `query` is the input text that the caller of this API wants to use to find
+    relevant context. If `session_id` is provided and matches an existing
+    session, the resulting prompt will include that session's messages as the
+    immediate conversation history, ending with a message containing `query`.
+
+    If `long_term_search` is provided, the resulting prompt will include
+    relevant long-term memories found via semantic search with the options
+    provided in the payload.
+
+    Args:
+        params: MemoryPromptRequest
+
+    Returns:
+        List of messages to send to an LLM, hydrated with relevant memory context
+    """
+    if not params.session and not params.long_term_search:
+        raise HTTPException(
+            status_code=400,
+            detail="Either session or long_term_search must be provided",
+        )
+
+    redis = await get_redis_conn()
+    _messages = []
+
+    if params.session:
+        effective_window_size = _get_effective_window_size(
+            window_size=params.session.window_size,
+            context_window_max=params.session.context_window_max,
+            model_name=params.session.model_name,
+        )
+        session_memory = await messages.get_session_memory(
+            redis=redis,
+            session_id=params.session.session_id,
+            window_size=effective_window_size,
+            namespace=params.session.namespace,
+        )
+
+        if session_memory:
+            if session_memory.context:
+                # TODO: Weird to use MCP types here?
+                _messages.append(
+                    SystemMessage(
+                        content=TextContent(
+                            type="text",
+                            text=f"## A summary of the conversation so far:\n{session_memory.context}",
+                        ),
+                    )
+                )
+            # Ignore past system messages as the latest context may have changed
+            for msg in session_memory.messages:
+                if msg.role == "user":
+                    msg_class = base.UserMessage
+                else:
+                    msg_class = base.AssistantMessage
+                _messages.append(
+                    msg_class(
+                        content=TextContent(type="text", text=msg.content),
+                    )
+                )
+
+    if params.long_term_search:
+        # TODO: Exclude session messages if we already included them from session memory
+        long_term_memories = await search_long_term_memory(
+            params.long_term_search,
+        )
+
+        if long_term_memories.total > 0:
+            long_term_memories_text = "\n".join(
+                [f"- {m.text}" for m in long_term_memories.memories]
+            )
+            _messages.append(
+                SystemMessage(
+                    content=TextContent(
+                        type="text",
+                        text=f"## Long term memories related to the user's query\n {long_term_memories_text}",
+                    ),
+                )
+            )
+
+    _messages.append(
+        base.UserMessage(
+            content=TextContent(type="text", text=params.query),
+        )
+    )
+
+    return MemoryPromptResponse(messages=_messages)
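
To illustrate the new endpoint end to end, here is a hedged usage sketch. The top-level keys `query`, `session`, and `long_term_search`, and the session fields (`session_id`, `namespace`, `window_size`, `model_name`, `context_window_max`), come straight from the handler above; the host and port (PORT=8000 from .env.example), the example values, and the exact serialized response shape are assumptions. Note how `window_size` is clamped by `_get_effective_window_size`: with window_size=12 and context_window_max=8, only the last 8 messages are fetched.

# Hedged sketch: hydrate a prompt for an existing session via POST /memory-prompt.
import httpx

payload = {
    "query": "What did we decide about the deployment schedule?",
    "session": {
        "session_id": "session-123",   # assumed value; must match an existing session
        "namespace": "demo",           # optional, passed through to get_session_memory
        "window_size": 12,
        "model_name": "gpt-4o-mini",   # one of the ModelNameLiteral values
        "context_window_max": None,    # explicit token cap, if set
    },
}

resp = httpx.post("http://localhost:8000/memory-prompt", json=payload)
resp.raise_for_status()
# Assumes FastAPI's default serialization of the MCP message models:
# each message carries a role and a TextContent body.
for message in resp.json()["messages"]:
    print(message["role"], ":", message["content"]["text"][:80])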
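
Similarly, the search endpoint can be called directly. Only the path, the `payload.get_filters()` call, and the response fields (`total`, `memories`, each memory's `text`) are visible in this diff; the request field names used below (`text`, `limit`, and the filter shape) are assumptions about SearchRequest, which is defined elsewhere.

# Hedged sketch: semantic search over long-term memory via POST /long-term-memory/search.
import httpx

search = {
    "text": "deployment schedule",        # query text passed to the semantic search
    "session_id": {"eq": "session-123"},  # assumed filter shape, via get_filters()
    "limit": 5,                           # assumed pagination field
}

resp = httpx.post("http://localhost:8000/long-term-memory/search", json=search)
resp.raise_for_status()
results = resp.json()
print(f"{results['total']} memories matched")
for memory in results["memories"]:
    print("-", memory["text"])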