diff --git a/.env.example b/.env.example index b3a33cb..0acfee4 100644 --- a/.env.example +++ b/.env.example @@ -6,7 +6,7 @@ PORT=8000 # Memory settings LONG_TERM_MEMORY=true -MAX_WINDOW_SIZE=12 +WINDOW_SIZE=12 GENERATION_MODEL=gpt-4o-mini EMBEDDING_MODEL=text-embedding-3-small diff --git a/agent_memory_server/api.py b/agent_memory_server/api.py index a957c7a..c6ff659 100644 --- a/agent_memory_server/api.py +++ b/agent_memory_server/api.py @@ -1,6 +1,6 @@ -from typing import Literal - from fastapi import APIRouter, Depends, HTTPException +from mcp.server.fastmcp.prompts import base +from mcp.types import TextContent from agent_memory_server import long_term_memory, messages from agent_memory_server.config import settings @@ -9,48 +9,44 @@ from agent_memory_server.logging import get_logger from agent_memory_server.models import ( AckResponse, - CreateLongTermMemoryPayload, + CreateLongTermMemoryRequest, GetSessionsQuery, LongTermMemoryResultsResponse, - SearchPayload, + MemoryPromptRequest, + MemoryPromptResponse, + ModelNameLiteral, + SearchRequest, SessionListResponse, SessionMemory, SessionMemoryResponse, + SystemMessage, ) from agent_memory_server.utils.redis import get_redis_conn logger = get_logger(__name__) -ModelNameLiteral = Literal[ - "gpt-3.5-turbo", - "gpt-3.5-turbo-16k", - "gpt-4", - "gpt-4-32k", - "gpt-4o", - "gpt-4o-mini", - "o1", - "o1-mini", - "o3-mini", - "text-embedding-ada-002", - "text-embedding-3-small", - "text-embedding-3-large", - "claude-3-opus-20240229", - "claude-3-sonnet-20240229", - "claude-3-haiku-20240307", - "claude-3-5-sonnet-20240620", - "claude-3-7-sonnet-20250219", - "claude-3-5-sonnet-20241022", - "claude-3-5-haiku-20241022", - "claude-3-7-sonnet-latest", - "claude-3-5-sonnet-latest", - "claude-3-5-haiku-latest", - "claude-3-opus-latest", -] - router = APIRouter() +def _get_effective_window_size( + window_size: int, + context_window_max: int | None, + model_name: ModelNameLiteral | None, +) -> int: + # If context_window_max is explicitly provided, use that + if context_window_max is not None: + effective_window_size = min(window_size, context_window_max) + # If model_name is provided, get its max_tokens from our config + elif model_name is not None: + model_config = get_model_config(model_name) + effective_window_size = min(window_size, model_config.max_tokens) + # Otherwise use the default window_size + else: + effective_window_size = window_size + return effective_window_size + + @router.get("/sessions/", response_model=SessionListResponse) async def list_sessions( options: GetSessionsQuery = Depends(), @@ -103,17 +99,11 @@ async def get_session_memory( Conversation history and context """ redis = await get_redis_conn() - - # If context_window_max is explicitly provided, use that - if context_window_max is not None: - effective_window_size = min(window_size, context_window_max) - # If model_name is provided, get its max_tokens from our config - elif model_name is not None: - model_config = get_model_config(model_name) - effective_window_size = min(window_size, model_config.max_tokens) - # Otherwise use the default window_size - else: - effective_window_size = window_size + effective_window_size = _get_effective_window_size( + window_size=window_size, + context_window_max=context_window_max, + model_name=model_name, + ) session = await messages.get_session_memory( redis=redis, @@ -181,7 +171,7 @@ async def delete_session_memory( @router.post("/long-term-memory", response_model=AckResponse) async def create_long_term_memory( - payload: 
CreateLongTermMemoryPayload,
+    payload: CreateLongTermMemoryRequest,
     background_tasks=Depends(get_background_tasks),
 ):
     """
@@ -205,7 +195,7 @@ async def create_long_term_memory(


 @router.post("/long-term-memory/search", response_model=LongTermMemoryResultsResponse)
-async def search_long_term_memory(payload: SearchPayload):
+async def search_long_term_memory(payload: SearchRequest):
     """
     Run a semantic search on long-term memory with filtering options.

@@ -215,11 +205,11 @@ async def search_long_term_memory(payload: SearchPayload):
     Returns:
         List of search results
     """
-    redis = await get_redis_conn()
-
     if not settings.long_term_memory:
         raise HTTPException(status_code=400, detail="Long-term memory is disabled")

+    redis = await get_redis_conn()
+
     # Extract filter objects from the payload
     filters = payload.get_filters()

@@ -236,3 +226,97 @@ async def search_long_term_memory(payload: SearchPayload):

     # Pass text, redis, and filter objects to the search function
     return await long_term_memory.search_long_term_memories(**kwargs)
+
+
+@router.post("/memory-prompt", response_model=MemoryPromptResponse)
+async def memory_prompt(params: MemoryPromptRequest) -> MemoryPromptResponse:
+    """
+    Hydrate a user query with memory context and return a prompt
+    ready to send to an LLM.
+
+    `query` is the input text that the caller of this API wants to use to find
+    relevant context. If `session` is provided and references an existing
+    session, the resulting prompt will include that session's messages as the
+    immediate history leading up to a message containing `query`.
+
+    If `long_term_search` is provided, the resulting prompt will include
+    relevant long-term memories found via semantic search with the options
+    provided in the payload.
+
+    Args:
+        params: MemoryPromptRequest
+
+    Returns:
+        A MemoryPromptResponse whose messages are ready to send to an LLM,
+        hydrated with relevant memory context
+    """
+    if not params.session and not params.long_term_search:
+        raise HTTPException(
+            status_code=400,
+            detail="Either session or long_term_search must be provided",
+        )
+
+    redis = await get_redis_conn()
+    _messages = []
+
+    if params.session:
+        effective_window_size = _get_effective_window_size(
+            window_size=params.session.window_size,
+            context_window_max=params.session.context_window_max,
+            model_name=params.session.model_name,
+        )
+        session_memory = await messages.get_session_memory(
+            redis=redis,
+            session_id=params.session.session_id,
+            window_size=effective_window_size,
+            namespace=params.session.namespace,
+        )
+
+        if session_memory:
+            if session_memory.context:
+                # TODO: Weird to use MCP types here?
+                _messages.append(
+                    SystemMessage(
+                        content=TextContent(
+                            type="text",
+                            text=f"## A summary of the conversation so far:\n{session_memory.context}",
+                        ),
+                    )
+                )
+            # Ignore past system messages as the latest context may have changed
+            for msg in session_memory.messages:
+                if msg.role == "user":
+                    msg_class = base.UserMessage
+                else:
+                    msg_class = base.AssistantMessage
+                _messages.append(
+                    msg_class(
+                        content=TextContent(type="text", text=msg.content),
+                    )
+                )
+
+    if params.long_term_search:
+        # TODO: Exclude session messages if we already included them from session memory
+        long_term_memories = await search_long_term_memory(
+            params.long_term_search,
+        )
+
+        if long_term_memories.total > 0:
+            long_term_memories_text = "\n".join(
+                [f"- {m.text}" for m in long_term_memories.memories]
+            )
+            _messages.append(
+                SystemMessage(
+                    content=TextContent(
+                        type="text",
+                        text=f"## Long term memories related to the user's query\n {long_term_memories_text}",
+                    ),
+                )
+            )
+
+    _messages.append(
+        base.UserMessage(
+            content=TextContent(type="text", text=params.query),
+        )
+    )
+
+    return MemoryPromptResponse(messages=_messages)
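For reviewers, a minimal sketch of calling the new endpoint; the host, session ID, and payload values here are illustrative, not part of the patch:

```python
# Hypothetical request against a locally running server. The payload mirrors
# MemoryPromptRequest: query plus optional session and long_term_search parts.
import httpx

payload = {
    "query": "What should I eat?",
    "session": {"session_id": "demo-session", "namespace": "demo"},
    "long_term_search": {"text": "food preferences", "limit": 5},
}
response = httpx.post("http://localhost:8000/memory-prompt", json=payload)
response.raise_for_status()
# The MemoryPromptResponse body is a list of role/content messages ready
# to forward to an LLM.
for message in response.json()["messages"]:
    print(message["role"], message["content"]["text"])
```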
diff --git a/agent_memory_server/client/api.py b/agent_memory_server/client/api.py
index b68d138..bf5b25e 100644
--- a/agent_memory_server/client/api.py
+++ b/agent_memory_server/client/api.py
@@ -21,13 +21,16 @@
 )
 from agent_memory_server.models import (
     AckResponse,
-    CreateLongTermMemoryPayload,
+    CreateLongTermMemoryRequest,
     HealthCheckResponse,
     LongTermMemory,
     LongTermMemoryResults,
-    SearchPayload,
+    MemoryPromptRequest,
+    MemoryPromptResponse,
+    SearchRequest,
     SessionListResponse,
     SessionMemory,
+    SessionMemoryRequest,
     SessionMemoryResponse,
 )
@@ -129,8 +132,8 @@ async def list_sessions(
             SessionListResponse containing session IDs and total count
         """
         params = {
-            "limit": limit,
-            "offset": offset,
+            "limit": str(limit),
+            "offset": str(offset),
         }
         if namespace is not None:
             params["namespace"] = namespace
@@ -256,7 +259,7 @@ async def create_long_term_memory(
             if memory.namespace is None:
                 memory.namespace = self.config.default_namespace

-        payload = CreateLongTermMemoryPayload(memories=memories)
+        payload = CreateLongTermMemoryRequest(memories=memories)
         response = await self._client.post(
             "/long-term-memory", json=payload.model_dump(exclude_none=True)
         )
@@ -322,7 +325,7 @@ async def search_long_term_memory(
         if namespace is None and self.config.default_namespace is not None:
             namespace = Namespace(eq=self.config.default_namespace)

-        payload = SearchPayload(
+        payload = SearchRequest(
             text=text,
             session_id=session_id,
             namespace=namespace,
@@ -343,6 +346,201 @@ async def search_long_term_memory(
         response.raise_for_status()
         return LongTermMemoryResults(**response.json())

+    async def memory_prompt(
+        self,
+        query: str,
+        session_id: str | None = None,
+        namespace: str | None = None,
+        window_size: int | None = None,
+        model_name: ModelNameLiteral | None = None,
+        context_window_max: int | None = None,
+        long_term_search: SearchRequest | None = None,
+    ) -> MemoryPromptResponse:
+        """
+        Hydrate a user query with memory context and return a prompt
+        ready to send to an LLM.
+
+        This method can retrieve relevant session history and long-term memories
+        to provide context for the query.
+
+        Args:
+            query: The user's query text
+            session_id: Optional session ID to retrieve history from
+            namespace: Optional namespace for session and long-term memories
+            window_size: Optional number of messages to include from session history
+            model_name: Optional model name to determine context window size
+            context_window_max: Optional direct specification of context window max tokens
+            long_term_search: Optional SearchRequest for specific long-term memory filtering
+
+        Returns:
+            MemoryPromptResponse containing a list of messages with context
+
+        Raises:
+            httpx.HTTPStatusError: If the request fails or if neither session_id nor long_term_search is provided
+        """
+        # Prepare the request payload
+        session_params = None
+        if session_id is not None:
+            session_params = SessionMemoryRequest(
+                session_id=session_id,
+                namespace=namespace or self.config.default_namespace,
+                window_size=window_size or 12,  # Hardcoded fallback window size
+                model_name=model_name,
+                context_window_max=context_window_max,
+            )
+
+        # If no explicit long_term_search is provided but we have a query, create a basic one
+        if long_term_search is None and query:
+            # Use default namespace from config if none provided
+            _namespace = None
+            if namespace is not None:
+                _namespace = Namespace(eq=namespace)
+            elif self.config.default_namespace is not None:
+                _namespace = Namespace(eq=self.config.default_namespace)
+
+            long_term_search = SearchRequest(
+                text=query,
+                namespace=_namespace,
+            )
+
+        # Create the request payload
+        payload = MemoryPromptRequest(
+            query=query,
+            session=session_params,
+            long_term_search=long_term_search,
+        )
+
+        # Make the API call
+        response = await self._client.post(
+            "/memory-prompt", json=payload.model_dump(exclude_none=True)
+        )
+        response.raise_for_status()
+        data = response.json()
+        return MemoryPromptResponse(**data)
+
+    async def hydrate_memory_prompt(
+        self,
+        query: str,
+        session_id: SessionId | dict[str, Any] | None = None,
+        namespace: Namespace | dict[str, Any] | None = None,
+        topics: Topics | dict[str, Any] | None = None,
+        entities: Entities | dict[str, Any] | None = None,
+        created_at: CreatedAt | dict[str, Any] | None = None,
+        last_accessed: LastAccessed | dict[str, Any] | None = None,
+        user_id: UserId | dict[str, Any] | None = None,
+        distance_threshold: float | None = None,
+        memory_type: MemoryType | dict[str, Any] | None = None,
+        limit: int = 10,
+        offset: int = 0,
+        window_size: int = 12,
+        model_name: ModelNameLiteral | None = None,
+        context_window_max: int | None = None,
+    ) -> MemoryPromptResponse:
+        """
+        Hydrate a user query with relevant session history and long-term memories.
+
+        This method enriches the user's query by retrieving:
+        1. Context from the conversation session (if session_id is provided)
+        2. Relevant long-term memories related to the query
+
+        Args:
+            query: The user's query text
+            session_id: Optional filter for session ID
+            namespace: Optional filter for namespace
+            topics: Optional filter for topics in long-term memories
+            entities: Optional filter for entities in long-term memories
+            created_at: Optional filter for creation date
+            last_accessed: Optional filter for last access date
+            user_id: Optional filter for user ID
+            distance_threshold: Optional distance threshold for semantic search
+            memory_type: Optional filter for memory type
+            limit: Maximum number of long-term memory results (default: 10)
+            offset: Offset for pagination (default: 0)
+            window_size: Number of messages to include from session history (default: 12)
+            model_name: Optional model name to determine context window size
+            context_window_max: Optional direct specification of context window max tokens
+
+        Returns:
+            MemoryPromptResponse containing a list of messages with context
+
+        Raises:
+            httpx.HTTPStatusError: If the request fails
+        """
+        # Convert dictionary filters to their proper filter objects if needed
+        if isinstance(session_id, dict):
+            session_id = SessionId(**session_id)
+        if isinstance(namespace, dict):
+            namespace = Namespace(**namespace)
+        if isinstance(topics, dict):
+            topics = Topics(**topics)
+        if isinstance(entities, dict):
+            entities = Entities(**entities)
+        if isinstance(created_at, dict):
+            created_at = CreatedAt(**created_at)
+        if isinstance(last_accessed, dict):
+            last_accessed = LastAccessed(**last_accessed)
+        if isinstance(user_id, dict):
+            user_id = UserId(**user_id)
+        if isinstance(memory_type, dict):
+            memory_type = MemoryType(**memory_type)
+
+        # Apply default namespace if needed and no namespace filter specified
+        if namespace is None and self.config.default_namespace is not None:
+            namespace = Namespace(eq=self.config.default_namespace)
+
+        # Extract session_id value if it exists
+        session_params = None
+        _session_id = None
+        if session_id and hasattr(session_id, "eq") and session_id.eq:
+            _session_id = session_id.eq
+
+        if _session_id:
+            # Get namespace value if it exists
+            _namespace = None
+            if namespace and hasattr(namespace, "eq"):
+                _namespace = namespace.eq
+            elif self.config.default_namespace:
+                _namespace = self.config.default_namespace
+
+            session_params = SessionMemoryRequest(
+                session_id=_session_id,
+                namespace=_namespace,
+                window_size=window_size,
+                model_name=model_name,
+                context_window_max=context_window_max,
+            )
+
+        # Create search request for long-term memory
+        search_payload = SearchRequest(
+            text=query,
+            session_id=session_id,
+            namespace=namespace,
+            topics=topics,
+            entities=entities,
+            created_at=created_at,
+            last_accessed=last_accessed,
+            user_id=user_id,
+            distance_threshold=distance_threshold,
+            memory_type=memory_type,
+            limit=limit,
+            offset=offset,
+        )
+
+        # Create the request payload
+        payload = MemoryPromptRequest(
+            query=query,
+            session=session_params,
+            long_term_search=search_payload,
+        )
+
+        # Make the API call
+        response = await self._client.post(
+            "/memory-prompt", json=payload.model_dump(exclude_none=True)
+        )
+        response.raise_for_status()
+        data = response.json()
+        return MemoryPromptResponse(**data)
+

 # Helper function to create a memory client
 async def create_memory_client(
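A usage sketch of the two client methods added above, assuming a server at localhost:8000 (all IDs and filter values are examples):

```python
# Illustrative only: demonstrates the flat memory_prompt() convenience method
# and the filter-based hydrate_memory_prompt() variant added in this diff.
from agent_memory_server.client.api import MemoryAPIClient, MemoryClientConfig

async def demo() -> None:
    config = MemoryClientConfig(base_url="http://localhost:8000")
    async with MemoryAPIClient(config) as client:
        # Flat form: plain strings configure the session part of the prompt.
        prompt = await client.memory_prompt(
            query="What was my favorite color?",
            session_id="demo-session",
            window_size=6,
        )
        # Filter form: dicts (or filter objects) scope the long-term search.
        prompt = await client.hydrate_memory_prompt(
            query="What was my favorite color?",
            topics={"any": ["preferences"]},
            limit=5,
        )
        for message in prompt.messages:
            print(message.role, message.content.text)
```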
diff --git a/agent_memory_server/config.py b/agent_memory_server/config.py
index e9801ea..bbeb791 100644
--- a/agent_memory_server/config.py
+++ b/agent_memory_server/config.py
@@ -1,6 +1,7 @@
 import os
 from typing import Literal

+import yaml
 from dotenv import load_dotenv
 from pydantic_settings import BaseSettings


 load_dotenv()


+def load_yaml_settings():
+    config_path = os.getenv("APP_CONFIG_FILE", "config.yaml")
+    if os.path.exists(config_path):
+        with open(config_path) as f:
+            return yaml.safe_load(f) or {}
+    return {}
+
+
 class Settings(BaseSettings):
     redis_url: str = "redis://localhost:6379"
     long_term_memory: bool = True
     window_size: int = 20
-    openai_api_key: str = os.getenv("OPENAI_API_KEY", "")
-    anthropic_api_key: str = os.getenv("ANTHROPIC_API_KEY", "")
+    openai_api_key: str | None = None
+    anthropic_api_key: str | None = None
     generation_model: str = "gpt-4o-mini"
     embedding_model: str = "text-embedding-3-small"
     port: int = 8000
     mcp_port: int = 9000

-    # Topic and NER model settings
-    topic_model_source: Literal["NER", "LLM"] = "LLM"
-    topic_model: str = "MaartenGr/BERTopic_Wikipedia"  # LLM model here if using LLM
-    ner_model: str = "dbmdz/bert-large-cased-finetuned-conll03-english"
+    # The server indexes messages in long-term memory by default. If this
+    # setting is enabled, we also extract discrete memories from message text
+    # and save them as separate long-term memory records.
+    enable_discrete_memory_extraction: bool = True
+
+    # Topic modeling
+    topic_model_source: Literal["BERTopic", "LLM"] = "LLM"
+    topic_model: str = (
+        "MaartenGr/BERTopic_Wikipedia"  # Use an LLM model name here if using LLM
+    )
     enable_topic_extraction: bool = True
-    enable_ner: bool = True
     top_k_topics: int = 3

+    # Used for extracting entities from text
+    ner_model: str = "dbmdz/bert-large-cased-finetuned-conll03-english"
+    enable_ner: bool = True
+
     # RedisVL Settings
     redisvl_distance_metric: str = "COSINE"
     redisvl_vector_dimensions: str = "1536"
@@ -40,5 +58,11 @@ class Settings(BaseSettings):
     # Other Application settings
     log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO"

+    class Config:
+        env_file = ".env"
+        env_file_encoding = "utf-8"
+

-settings = Settings()
+# Load YAML config first, then let env vars override
+yaml_settings = load_yaml_settings()
+settings = Settings(**yaml_settings)
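To make the loading order concrete, a sketch of how the pieces combine; the file path and values are hypothetical, and the override expectation matches the test_env_overrides_yaml test added below:

```python
# Hypothetical walkthrough of the YAML-then-env flow introduced above.
import os

import yaml

from agent_memory_server.config import Settings, load_yaml_settings

with open("/tmp/example-config.yaml", "w") as f:  # illustrative path
    yaml.dump({"redis_url": "redis://yaml-host:6379", "port": 1111}, f)

os.environ["APP_CONFIG_FILE"] = "/tmp/example-config.yaml"
os.environ["PORT"] = "2222"

settings = Settings(**load_yaml_settings())
# Expected per the accompanying tests: redis_url comes from YAML, while the
# PORT environment variable overrides the YAML value.
print(settings.redis_url, settings.port)
```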
""" @@ -345,5 +344,5 @@ async def extract_discrete_memories(redis: Redis | None = None): await index_long_term_memories( long_term_memories, - deduplicate=True, + deduplicate=deduplicate, ) diff --git a/agent_memory_server/long_term_memory.py b/agent_memory_server/long_term_memory.py index 7a4dd88..4c40373 100644 --- a/agent_memory_server/long_term_memory.py +++ b/agent_memory_server/long_term_memory.py @@ -500,8 +500,6 @@ async def index_long_term_memories( memories: list[LongTermMemory], redis_client: Redis | None = None, deduplicate: bool = False, - deduplicate_hash: bool = True, - deduplicate_semantic: bool = True, vector_distance_threshold: float = 0.12, llm_client: Any = None, ) -> None: @@ -612,7 +610,14 @@ async def index_long_term_memories( await pipe.execute() logger.info(f"Indexed {len(processed_memories)} memories") - await background_tasks.add_task(extract_discrete_memories) + if settings.enable_discrete_memory_extraction: + # Extract discrete memories from the indexed messages and persist + # them as separate long-term memory records. This process also + # runs deduplication if requested. + await background_tasks.add_task( + extract_discrete_memories, + deduplicate=deduplicate, + ) async def search_long_term_memories( diff --git a/agent_memory_server/mcp.py b/agent_memory_server/mcp.py index a412418..17a94d5 100644 --- a/agent_memory_server/mcp.py +++ b/agent_memory_server/mcp.py @@ -1,16 +1,11 @@ -import asyncio import logging import os -import sys -from fastapi import HTTPException from mcp.server.fastmcp import FastMCP as _FastMCPBase -from mcp.server.fastmcp.prompts import base -from mcp.types import TextContent from agent_memory_server.api import ( create_long_term_memory as core_create_long_term_memory, - get_session_memory as core_get_session_memory, + memory_prompt as core_memory_prompt, search_long_term_memory as core_search_long_term_memory, ) from agent_memory_server.config import settings @@ -19,6 +14,7 @@ CreatedAt, Entities, LastAccessed, + MemoryType, Namespace, SessionId, Topics, @@ -26,10 +22,14 @@ ) from agent_memory_server.models import ( AckResponse, - CreateLongTermMemoryPayload, + CreateLongTermMemoryRequest, LongTermMemory, LongTermMemoryResults, - SearchPayload, + MemoryPromptRequest, + MemoryPromptResponse, + ModelNameLiteral, + SearchRequest, + SessionMemoryRequest, ) @@ -213,7 +213,7 @@ async def create_long_term_memories( if mem.namespace is None: mem.namespace = DEFAULT_NAMESPACE - payload = CreateLongTermMemoryPayload(memories=memories) + payload = CreateLongTermMemoryRequest(memories=memories) return await core_create_long_term_memory( payload, background_tasks=get_background_tasks() ) @@ -229,6 +229,7 @@ async def search_long_term_memory( created_at: CreatedAt | None = None, last_accessed: LastAccessed | None = None, user_id: UserId | None = None, + memory_type: MemoryType | None = None, distance_threshold: float | None = None, limit: int = 10, offset: int = 0, @@ -284,6 +285,7 @@ async def search_long_term_memory( created_at: Filter by creation date last_accessed: Filter by last access date user_id: Filter by user ID + memory_type: Filter by memory type distance_threshold: Distance threshold for semantic search limit: Maximum number of results offset: Offset for pagination @@ -292,7 +294,7 @@ async def search_long_term_memory( LongTermMemoryResults containing matched memories sorted by relevance """ try: - payload = SearchPayload( + payload = SearchRequest( text=text, session_id=session_id, namespace=namespace, @@ -301,6 +303,7 @@ async def 
diff --git a/agent_memory_server/mcp.py b/agent_memory_server/mcp.py
index a412418..17a94d5 100644
--- a/agent_memory_server/mcp.py
+++ b/agent_memory_server/mcp.py
@@ -1,16 +1,11 @@
-import asyncio
 import logging
 import os
-import sys

-from fastapi import HTTPException
 from mcp.server.fastmcp import FastMCP as _FastMCPBase
-from mcp.server.fastmcp.prompts import base
-from mcp.types import TextContent

 from agent_memory_server.api import (
     create_long_term_memory as core_create_long_term_memory,
-    get_session_memory as core_get_session_memory,
+    memory_prompt as core_memory_prompt,
     search_long_term_memory as core_search_long_term_memory,
 )
 from agent_memory_server.config import settings
@@ -19,6 +14,7 @@
     CreatedAt,
     Entities,
     LastAccessed,
+    MemoryType,
     Namespace,
     SessionId,
     Topics,
@@ -26,10 +22,14 @@
 )
 from agent_memory_server.models import (
     AckResponse,
-    CreateLongTermMemoryPayload,
+    CreateLongTermMemoryRequest,
     LongTermMemory,
     LongTermMemoryResults,
-    SearchPayload,
+    MemoryPromptRequest,
+    MemoryPromptResponse,
+    ModelNameLiteral,
+    SearchRequest,
+    SessionMemoryRequest,
 )
@@ -213,7 +213,7 @@ async def create_long_term_memories(
         if mem.namespace is None:
             mem.namespace = DEFAULT_NAMESPACE

-    payload = CreateLongTermMemoryPayload(memories=memories)
+    payload = CreateLongTermMemoryRequest(memories=memories)
     return await core_create_long_term_memory(
         payload, background_tasks=get_background_tasks()
     )
@@ -229,6 +229,7 @@ async def search_long_term_memory(
     created_at: CreatedAt | None = None,
     last_accessed: LastAccessed | None = None,
     user_id: UserId | None = None,
+    memory_type: MemoryType | None = None,
     distance_threshold: float | None = None,
     limit: int = 10,
     offset: int = 0,
@@ -284,6 +285,7 @@ async def search_long_term_memory(
         created_at: Filter by creation date
         last_accessed: Filter by last access date
         user_id: Filter by user ID
+        memory_type: Filter by memory type
         distance_threshold: Distance threshold for semantic search
         limit: Maximum number of results
         offset: Offset for pagination
@@ -292,7 +294,7 @@
         LongTermMemoryResults containing matched memories sorted by relevance
     """
     try:
-        payload = SearchPayload(
+        payload = SearchRequest(
             text=text,
             session_id=session_id,
             namespace=namespace,
@@ -301,6 +303,7 @@ async def search_long_term_memory(
             created_at=created_at,
             last_accessed=last_accessed,
             user_id=user_id,
+            memory_type=memory_type,
             distance_threshold=distance_threshold,
             limit=limit,
             offset=offset,
@@ -321,24 +324,30 @@ async def search_long_term_memory(

     return results


-# NOTE: Prompts don't support search filters in FastMCP, so we need to use a
-# tool instead.
+# Notes that exist outside of the docstring to avoid polluting the LLM prompt:
+# 1. The "prompt" abstraction in FastMCP doesn't support search filters, so we use a tool.
+# 2. Some applications, such as Cursor, get confused with nested objects in tool parameters,
+#    so we use a flat set of parameters instead.
 @mcp_app.tool()
-async def hydrate_memory_prompt(
-    text: str | None = None,
+async def memory_prompt(
+    query: str,
     session_id: SessionId | None = None,
     namespace: Namespace | None = None,
+    window_size: int = settings.window_size,
+    model_name: ModelNameLiteral | None = None,
+    context_window_max: int | None = None,
     topics: Topics | None = None,
     entities: Entities | None = None,
     created_at: CreatedAt | None = None,
     last_accessed: LastAccessed | None = None,
     user_id: UserId | None = None,
+    memory_type: MemoryType | None = None,
     distance_threshold: float | None = None,
     limit: int = 10,
     offset: int = 0,
-) -> list[base.Message]:
+) -> MemoryPromptResponse:
     """
-    Hydrate a user prompt with relevant session history and long-term memories.
+    Hydrate a user query with relevant session history and long-term memories.

     CRITICAL: Use this tool for EVERY question that might benefit from memory context,
     especially when you don't have sufficient information to answer confidently.

@@ -363,13 +372,13 @@
     COMMON USAGE PATTERNS:

    ```python
-    1. Basic search with just query text:
-       hydrate_memory_prompt(text="What was my favorite color?")
+    1. Hydrate a user prompt with long-term memory search:
+       memory_prompt(query="What was my favorite color?")
    ```

-    2. Search with simple session filter:
-       hydrate_memory_prompt(
-           text="What was my favorite color?",
+    2. Hydrate a user prompt with long-term memory search and session filter:
+       memory_prompt(
+           query="What is my favorite color?",
            session_id={
                "eq": "session_12345"
            },
            namespace={
                "eq": "user_namespace"
            }
        )

-    3. Search with complex filters:
-       hydrate_memory_prompt(
-           text="What was my favorite color?",
+    3. Hydrate a user prompt with long-term memory search and complex filters:
+       memory_prompt(
+           query="What was my favorite color?",
            topics={
                "any": ["preferences", "favorites"]
            },
            created_at={
                "gte": 1641024000,
                "lte": 1641110400
            },
            limit=5
        )
    ```

     Args:
-    - text: The user's query/message (required)
-    - session_id: Filter by session ID
-    - namespace: Filter by namespace
-    - topics: Filter by topics
-    - entities: Filter by entities
-    - created_at: Filter by creation date
-    - last_accessed: Filter by last access date
-    - user_id: Filter by user ID
+    - query: The user's query
+    - session_id: Add conversation history from a session
+    - namespace: Filter session and long-term memory namespace
+    - window_size: Number of recent session messages to include
+    - model_name: Optional model name used to determine the context window size
+    - context_window_max: Optional direct cap on the context window, in tokens
+    - topics: Search for long-term memories matching topics
+    - entities: Search for long-term memories matching entities
+    - created_at: Search for long-term memories matching creation date
+    - last_accessed: Search for long-term memories matching last access date
+    - user_id: Search for long-term memories matching user ID
+    - memory_type: Search for long-term memories matching memory type
     - distance_threshold: Distance threshold for semantic search
-    - limit: Maximum number of results
-    - offset: Offset for pagination
+    - limit: Maximum number of long-term memory results
+    - offset: Offset for pagination of long-term memory results

     Returns:
-        A list of messages, including memory context and the user's query
+        A MemoryPromptResponse containing a list of messages with memory
+        context and the user's query
     """
-    messages = []
-    if session_id and session_id.eq:
-        try:
-            session_memory = await core_get_session_memory(session_id.eq)
-        except HTTPException:
-            session_memory = None
-    else:
-        session_memory = None
-
-    if session_memory:
-        if session_memory.context:
-            messages.append(
-                base.AssistantMessage(
-                    content=TextContent(
-                        type="text",
-                        text=f"## A summary of the conversation so far\n{session_memory.context}",
-                    ),
-                )
-            )
-        for msg in session_memory.messages:
-            if msg.role == "user":
-                msg_class = base.UserMessage
-            else:
-                msg_class = base.AssistantMessage
-            messages.append(
-                msg_class(
-                    content=TextContent(type="text", text=msg.content),
-                )
-            )
-
-    try:
-        long_term_memories = await core_search_long_term_memory(
-            SearchPayload(
-                text=text,
-                session_id=session_id,
-                namespace=namespace,
-                topics=topics,
-                entities=entities,
-                created_at=created_at,
-                last_accessed=last_accessed,
-                user_id=user_id,
-                distance_threshold=distance_threshold,
-                limit=limit,
-                offset=offset,
-            )
+    _session_id = session_id.eq if session_id and session_id.eq else None
+    session = None
+
+    if _session_id is not None:
+        session = SessionMemoryRequest(
+            session_id=_session_id,
+            namespace=namespace.eq if namespace and namespace.eq else None,
+            window_size=window_size,
+            model_name=model_name,
+            context_window_max=context_window_max,
         )
-        if long_term_memories.total > 0:
-            long_term_memories_text = "\n".join(
-                [f"- {m.text}" for m in long_term_memories.memories]
-            )
-            messages.append(
-                base.AssistantMessage(
-                    content=TextContent(
-                        type="text",
-                        text=f"## Long term memories related to the user's query\n {long_term_memories_text}",
-                    ),
-                )
-            )
-    except Exception as e:
-        logger.error(f"Error searching long-term memory: {e}")

-    # Ensure text is not None
-    safe_text = text or ""
-    messages.append(
-        base.UserMessage(
-            content=TextContent(type="text", text=safe_text),
-        )
+    search_payload = SearchRequest(
+        text=query,
+        session_id=session_id,
+        namespace=namespace,
+        topics=topics,
+        entities=entities,
+        created_at=created_at,
+        last_accessed=last_accessed,
+        user_id=user_id,
+        distance_threshold=distance_threshold,
+        memory_type=memory_type,
+        limit=limit,
+        offset=offset,
     )
+    _params = {}
+    if session is not None:
+        _params["session"] =
session + if search_payload is not None: + _params["long_term_search"] = search_payload - return messages - - -if __name__ == "__main__": - if len(sys.argv) > 1 and sys.argv[1] == "sse": - asyncio.run(mcp_app.run_sse_async()) - else: - asyncio.run(mcp_app.run_stdio_async()) + return await core_memory_prompt(params=MemoryPromptRequest(query=query, **_params)) diff --git a/agent_memory_server/models.py b/agent_memory_server/models.py index 44e2eb1..342ea79 100644 --- a/agent_memory_server/models.py +++ b/agent_memory_server/models.py @@ -2,8 +2,10 @@ import time from typing import Literal +from mcp.server.fastmcp.prompts import base from pydantic import BaseModel, Field +from agent_memory_server.config import settings from agent_memory_server.filters import ( CreatedAt, Entities, @@ -20,6 +22,33 @@ JSONTypes = str | float | int | bool | list | dict +# These should match the keys in MODEL_CONFIGS +ModelNameLiteral = Literal[ + "gpt-3.5-turbo", + "gpt-3.5-turbo-16k", + "gpt-4", + "gpt-4-32k", + "gpt-4o", + "gpt-4o-mini", + "o1", + "o1-mini", + "o3-mini", + "text-embedding-ada-002", + "text-embedding-3-small", + "text-embedding-3-large", + "claude-3-opus-20240229", + "claude-3-sonnet-20240229", + "claude-3-haiku-20240307", + "claude-3-5-sonnet-20240620", + "claude-3-7-sonnet-20250219", + "claude-3-5-sonnet-20241022", + "claude-3-5-haiku-20241022", + "claude-3-7-sonnet-latest", + "claude-3-5-sonnet-latest", + "claude-3-5-haiku-latest", + "claude-3-opus-latest", +] + class MemoryMessage(BaseModel): """A message in the memory system""" @@ -66,6 +95,14 @@ class SessionMemory(BaseModel): ) +class SessionMemoryRequest(BaseModel): + session_id: str + namespace: str | None = None + window_size: int = settings.window_size + model_name: ModelNameLiteral | None = None + context_window_max: int | None = None + + class SessionMemoryResponse(SessionMemory): """Response containing a session's memory""" @@ -155,7 +192,7 @@ class LongTermMemoryResultsResponse(LongTermMemoryResults): """Response containing long-term memory search results""" -class CreateLongTermMemoryPayload(BaseModel): +class CreateLongTermMemoryRequest(BaseModel): """Payload for creating a long-term memory""" memories: list[LongTermMemory] @@ -175,7 +212,7 @@ class HealthCheckResponse(BaseModel): now: int -class SearchPayload(BaseModel): +class SearchRequest(BaseModel): """Payload for long-term memory search""" text: str | None = Field( @@ -259,3 +296,25 @@ def get_filters(self): filters["memory_type"] = self.memory_type return filters + + +class MemoryPromptRequest(BaseModel): + query: str + session: SessionMemoryRequest | None = None + long_term_search: SearchRequest | None = None + + +class SystemMessage(base.Message): + """A system message""" + + role: Literal["system"] = "system" + + +class UserMessage(base.Message): + """A user message""" + + role: Literal["user"] = "user" + + +class MemoryPromptResponse(BaseModel): + messages: list[base.Message | SystemMessage] diff --git a/agent_memory_server/test_config.py b/agent_memory_server/test_config.py new file mode 100644 index 0000000..3342577 --- /dev/null +++ b/agent_memory_server/test_config.py @@ -0,0 +1,60 @@ +import tempfile + +import yaml + +from agent_memory_server.config import Settings, load_yaml_settings + + +def test_defaults(monkeypatch): + # Clear env vars + monkeypatch.delenv("APP_CONFIG_FILE", raising=False) + monkeypatch.delenv("redis_url", raising=False) + # No YAML file + monkeypatch.chdir(tempfile.gettempdir()) + settings = Settings() + assert settings.redis_url == 
"redis://localhost:6379" + assert settings.port == 8000 + assert settings.log_level == "INFO" + + +def test_yaml_loading(tmp_path, monkeypatch): + config = {"redis_url": "redis://test:6379", "port": 1234, "log_level": "DEBUG"} + yaml_path = tmp_path / "config.yaml" + with open(yaml_path, "w") as f: + yaml.dump(config, f) + monkeypatch.setenv("APP_CONFIG_FILE", str(yaml_path)) + # Remove env var overrides + monkeypatch.delenv("redis_url", raising=False) + monkeypatch.delenv("port", raising=False) + monkeypatch.delenv("log_level", raising=False) + loaded = load_yaml_settings() + settings = Settings(**loaded) + assert settings.redis_url == "redis://test:6379" + assert settings.port == 1234 + assert settings.log_level == "DEBUG" + + +def test_env_overrides_yaml(tmp_path, monkeypatch): + config = {"redis_url": "redis://yaml:6379", "port": 1111} + yaml_path = tmp_path / "config.yaml" + with open(yaml_path, "w") as f: + yaml.dump(config, f) + monkeypatch.setenv("APP_CONFIG_FILE", str(yaml_path)) + monkeypatch.setenv("redis_url", "redis://env:6379") + monkeypatch.setenv("port", "2222") + loaded = load_yaml_settings() + settings = Settings(**loaded) + # Env vars should override YAML + assert settings.redis_url == "redis://env:6379" + assert settings.port == 2222 # Pydantic auto-casts + + +def test_custom_config_path(tmp_path, monkeypatch): + config = {"redis_url": "redis://custom:6379"} + custom_path = tmp_path / "custom.yaml" + with open(custom_path, "w") as f: + yaml.dump(config, f) + monkeypatch.setenv("APP_CONFIG_FILE", str(custom_path)) + loaded = load_yaml_settings() + settings = Settings(**loaded) + assert settings.redis_url == "redis://custom:6379" diff --git a/tests/test_api.py b/tests/test_api.py index 020907a..65bd6a3 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -121,6 +121,12 @@ async def test_put_memory(self, client): assert "status" in data assert data["status"] == "ok" + updated_session = await client.get( + "/sessions/test-session/memory?namespace=test-namespace" + ) + assert updated_session.status_code == 200 + assert updated_session.json()["messages"] == payload["messages"] + @pytest.mark.requires_api_keys @pytest.mark.asyncio async def test_put_memory_stores_messages_in_long_term_memory( @@ -258,3 +264,268 @@ async def test_search(self, mock_search, client): assert data["memories"][1]["id_"] == "2" assert data["memories"][1]["text"] == "Assistant: Hi there!" assert data["memories"][1]["dist"] == 0.75 + + +@pytest.mark.requires_api_keys +class TestMemoryPromptEndpoint: + @patch("agent_memory_server.api.messages.get_session_memory") + @pytest.mark.asyncio + async def test_memory_prompt_with_session_id(self, mock_get_session_memory, client): + """Test the memory_prompt endpoint with only session_id provided""" + # Mock the session memory + mock_session_memory = SessionMemoryResponse( + messages=[ + MemoryMessage(role="user", content="Hello"), + MemoryMessage(role="assistant", content="Hi there"), + ], + context="Previous conversation context", + namespace="test-namespace", + tokens=150, + ) + mock_get_session_memory.return_value = mock_session_memory + + # Call the endpoint + query = "What's the weather like?" 
+        response = await client.post(
+            "/memory-prompt",
+            json={
+                "query": query,
+                "session": {
+                    "session_id": "test-session",
+                    "namespace": "test-namespace",
+                    "window_size": 10,
+                    "model_name": "gpt-4o",
+                    "context_window_max": 1000,
+                },
+            },
+        )
+
+        # Check status code
+        assert response.status_code == 200
+
+        # Check response data
+        data = response.json()
+        assert isinstance(data, dict)
+        assert (
+            len(data["messages"]) == 4
+        )  # Context message + 2 session messages + query
+
+        # Verify the messages content
+        assert data["messages"][0]["role"] == "system"
+        assert "Previous conversation context" in data["messages"][0]["content"]["text"]
+        assert data["messages"][1]["role"] == "user"
+        assert data["messages"][1]["content"]["text"] == "Hello"
+        assert data["messages"][2]["role"] == "assistant"
+        assert data["messages"][2]["content"]["text"] == "Hi there"
+        assert data["messages"][3]["role"] == "user"
+        assert data["messages"][3]["content"]["text"] == query
+
+    @patch("agent_memory_server.api.long_term_memory.search_long_term_memories")
+    @pytest.mark.asyncio
+    async def test_memory_prompt_with_long_term_memory(self, mock_search, client):
+        """Test the memory_prompt endpoint with only long_term_search provided"""
+        # Mock the long-term memory search
+        mock_search.return_value = LongTermMemoryResultsResponse(
+            memories=[
+                LongTermMemoryResult(id_="1", text="User likes coffee", dist=0.25),
+                LongTermMemoryResult(
+                    id_="2", text="User is allergic to peanuts", dist=0.35
+                ),
+            ],
+            total=2,
+        )
+
+        # Prepare the payload
+        payload = {
+            "query": "What should I eat?",
+            "long_term_search": {
+                "text": "food preferences allergies",
+            },
+        }
+
+        # Call the endpoint
+        response = await client.post("/memory-prompt", json=payload)
+
+        # Check status code
+        assert response.status_code == 200
+
+        # Check response data
+        data = response.json()
+        assert isinstance(data, dict)
+        assert len(data["messages"]) == 2  # Long-term memory message + query
+
+        # Verify the messages content
+        assert data["messages"][0]["role"] == "system"
+        assert "Long term memories" in data["messages"][0]["content"]["text"]
+        assert "User likes coffee" in data["messages"][0]["content"]["text"]
+        assert "User is allergic to peanuts" in data["messages"][0]["content"]["text"]
+        assert data["messages"][1]["role"] == "user"
+        assert data["messages"][1]["content"]["text"] == "What should I eat?"
+
+    @patch("agent_memory_server.api.messages.get_session_memory")
+    @patch("agent_memory_server.api.long_term_memory.search_long_term_memories")
+    @pytest.mark.asyncio
+    async def test_memory_prompt_with_both_sources(
+        self, mock_search, mock_get_session_memory, client
+    ):
+        """Test the memory_prompt endpoint with both session and long_term_search"""
+        # Mock session memory
+        mock_session_memory = SessionMemoryResponse(
+            messages=[
+                MemoryMessage(role="user", content="How do you make pasta?"),
+                MemoryMessage(
+                    role="assistant",
+                    content="Boil water, add pasta, cook until al dente.",
+                ),
+            ],
+            context="Cooking conversation",
+            namespace="test-namespace",
+            tokens=200,
+        )
+        mock_get_session_memory.return_value = mock_session_memory
+
+        # Mock the long-term memory search
+        mock_search.return_value = LongTermMemoryResultsResponse(
+            memories=[
+                LongTermMemoryResult(
+                    id_="1", text="User prefers gluten-free pasta", dist=0.3
+                ),
+            ],
+            total=1,
+        )
+
+        # Prepare the payload
+        payload = {
+            "query": "What pasta should I buy?",
+            "session": {
+                "session_id": "test-session",
+                "namespace": "test-namespace",
+            },
+            "long_term_search": {
+                "text": "pasta preferences",
+            },
+        }
+
+        # Call the endpoint
+        response = await client.post("/memory-prompt", json=payload)
+
+        # Check status code
+        assert response.status_code == 200
+
+        # Check response data
+        data = response.json()
+        assert isinstance(data, dict)
+        assert (
+            len(data["messages"]) == 5
+        )  # Context + 2 session messages + long-term memory + query
+
+        # Verify the messages content (order matters)
+        assert data["messages"][0]["role"] == "system"
+        assert "Cooking conversation" in data["messages"][0]["content"]["text"]
+        assert data["messages"][1]["role"] == "user"
+        assert data["messages"][1]["content"]["text"] == "How do you make pasta?"
+        assert data["messages"][2]["role"] == "assistant"
+        assert (
+            data["messages"][2]["content"]["text"]
+            == "Boil water, add pasta, cook until al dente."
+        )
+        assert data["messages"][3]["role"] == "system"
+        assert "Long term memories" in data["messages"][3]["content"]["text"]
+        assert (
+            "User prefers gluten-free pasta" in data["messages"][3]["content"]["text"]
+        )
+        assert data["messages"][4]["role"] == "user"
+        assert data["messages"][4]["content"]["text"] == "What pasta should I buy?"
+
+    @pytest.mark.asyncio
+    async def test_memory_prompt_without_required_params(self, client):
+        """Test the memory_prompt endpoint without required parameters"""
+        # Call the endpoint without session or long_term_search
+        response = await client.post("/memory-prompt", json={"query": "test"})
+
+        # Check status code (should be 400 Bad Request)
+        assert response.status_code == 400
+
+        # Check error message
+        data = response.json()
+        assert "detail" in data
+        assert "Either session or long_term_search must be provided" in data["detail"]
+
+    @patch("agent_memory_server.api.messages.get_session_memory")
+    @pytest.mark.asyncio
+    async def test_memory_prompt_session_not_found(
+        self, mock_get_session_memory, client
+    ):
+        """Test the memory_prompt endpoint when session is not found"""
+        # Mock the session memory to return None (session not found)
+        mock_get_session_memory.return_value = None
+
+        # Call the endpoint
+        query = "What's the weather like?"
+ response = await client.post( + "/memory-prompt", + json={ + "query": query, + "session": { + "session_id": "nonexistent-session", + "namespace": "test-namespace", + }, + }, + ) + + # Check status code (should be successful) + assert response.status_code == 200 + + # Check response data (should only contain the query) + data = response.json() + assert isinstance(data, dict) + assert len(data["messages"]) == 1 + assert data["messages"][0]["role"] == "user" + assert data["messages"][0]["content"]["text"] == query + + @patch("agent_memory_server.api.messages.get_session_memory") + @patch("agent_memory_server.api.get_model_config") + @pytest.mark.asyncio + async def test_memory_prompt_with_model_name( + self, mock_get_model_config, mock_get_session_memory, client + ): + """Test the memory_prompt endpoint with model_name parameter""" + # Mock the model config + model_config = MagicMock() + model_config.max_tokens = 4000 + mock_get_model_config.return_value = model_config + + # Mock the session memory + mock_session_memory = SessionMemoryResponse( + messages=[ + MemoryMessage(role="user", content="Hello"), + MemoryMessage(role="assistant", content="Hi there"), + ], + context="Previous context", + namespace="test-namespace", + tokens=150, + ) + mock_get_session_memory.return_value = mock_session_memory + + # Call the endpoint with model_name + query = "What's the weather like?" + response = await client.post( + "/memory-prompt", + json={ + "query": query, + "session": { + "session_id": "test-session", + "model_name": "gpt-4o", + }, + }, + ) + + # Check the model config was used + mock_get_model_config.assert_called_once_with("gpt-4o") + + # Check status code + assert response.status_code == 200 + + # Verify the effective window size was used in get_session_memory + mock_get_session_memory.assert_called_once() + assert mock_get_session_memory.call_args[1]["window_size"] <= 4000 diff --git a/tests/test_client_api.py b/tests/test_client_api.py index ba9d314..eb94a1c 100644 --- a/tests/test_client_api.py +++ b/tests/test_client_api.py @@ -5,23 +5,27 @@ """ from collections.abc import AsyncGenerator -from unittest.mock import patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest from fastapi import FastAPI from httpx import ASGITransport, AsyncClient +from mcp.server.fastmcp.prompts import base +from mcp.types import TextContent from agent_memory_server.api import router as memory_router from agent_memory_server.client.api import MemoryAPIClient, MemoryClientConfig -from agent_memory_server.filters import Namespace +from agent_memory_server.filters import Namespace, SessionId, Topics from agent_memory_server.healthcheck import router as health_router from agent_memory_server.models import ( LongTermMemory, LongTermMemoryResult, LongTermMemoryResultsResponse, MemoryMessage, + MemoryPromptResponse, SessionMemory, SessionMemoryResponse, + SystemMessage, ) @@ -226,64 +230,206 @@ async def test_client_with_context_manager(memory_app: FastAPI): # The client will be automatically closed when the context block exits -# Example usage is left in the file for documentation purposes, -# but commented out so it doesn't run during tests -""" -# This example demonstrates basic usage of the API client -if __name__ == "__main__": - async def example(): - # Create a client - base_url = "http://localhost:8000" # Adjust to your server URL - - # Using context manager for automatic cleanup - async with MemoryAPIClient( - MemoryClientConfig( - base_url=base_url, - default_namespace="example-namespace" 
- ) - ) as client: - # Check server health - health = await client.health_check() - print(f"Server is healthy, current time: {health.now}") - - # Store a conversation - session_id = "example-session" - memory = SessionMemory( - messages=[ - MemoryMessage(role="user", content="What is the weather like today?"), - MemoryMessage(role="assistant", content="It's sunny and warm!"), - ] - ) - await client.put_session_memory(session_id, memory) - print(f"Stored conversation in session {session_id}") - - # Retrieve the conversation - session = await client.get_session_memory(session_id) - print(f"Retrieved {len(session.messages)} messages from session {session_id}") - - # Create long-term memory - memories = [ - LongTermMemory( - text="User lives in San Francisco", - topics=["location", "personal_info"], - ), - ] - await client.create_long_term_memory(memories) - print("Created long-term memory") - - # Search for relevant memories - results = await client.search_long_term_memory( - text="Where does the user live?", - limit=5, - ) - print(f"Found {results.total} relevant memories") - for memory in results.memories: - print(f"- {memory.text} (relevance: {1.0 - memory.dist:.2f})") +@pytest.mark.asyncio +async def test_memory_prompt(memory_test_client: MemoryAPIClient): + """Test the memory_prompt method""" + session_id = "test-client-session" + query = "What was my favorite color?" - # Clean up - await client.delete_session_memory(session_id) - print(f"Deleted session {session_id}") + # Create expected response + expected_messages = [ + base.UserMessage( + content=TextContent(type="text", text="What is your favorite color?"), + ), + base.AssistantMessage( + content=TextContent(type="text", text="I like blue, how about you?"), + ), + base.UserMessage( + content=TextContent(type="text", text=query), + ), + ] - # Run the example - asyncio.run(example()) -""" + # Create expected response payload + expected_response = MemoryPromptResponse(messages=expected_messages) + + # Mock the HTTP client's post method directly + with patch.object(memory_test_client._client, "post") as mock_post: + mock_response = AsyncMock() + mock_response.status_code = 200 + mock_response.raise_for_status = MagicMock(return_value=None) + mock_response.json = MagicMock(return_value=expected_response.model_dump()) + mock_post.return_value = mock_response + + # Test the client method + response = await memory_test_client.memory_prompt( + query=query, + session_id=session_id, + namespace="test-namespace", + window_size=5, + model_name="gpt-4o", + context_window_max=4000, + ) + + # Verify the response + assert len(response.messages) == 3 + assert isinstance(response.messages[0].content, TextContent) + assert response.messages[0].content.text.startswith( + "What is your favorite color?" + ) + assert isinstance(response.messages[-1].content, TextContent) + assert response.messages[-1].content.text == query + + # Test without session_id (only semantic search) + mock_post.reset_mock() + mock_post.return_value = mock_response + + response = await memory_test_client.memory_prompt( + query=query, + ) + + # Verify the response is the same (it's mocked) + assert len(response.messages) == 3 + + +@pytest.mark.asyncio +async def test_hydrate_memory_prompt(memory_test_client: MemoryAPIClient): + """Test the hydrate_memory_prompt method with filters""" + query = "What was my favorite color?" 
+ + # Create expected response + expected_messages = [ + base.AssistantMessage( + content=TextContent( + type="text", + text="The user's favorite color is blue", + ), + ), + base.UserMessage( + content=TextContent(type="text", text=query), + ), + ] + + # Create expected response payload + expected_response = MemoryPromptResponse(messages=expected_messages) + + # Mock the HTTP client's post method directly + with patch.object(memory_test_client._client, "post") as mock_post: + mock_response = AsyncMock() + mock_response.status_code = 200 + mock_response.raise_for_status = MagicMock(return_value=None) + mock_response.json = MagicMock(return_value=expected_response.model_dump()) + mock_post.return_value = mock_response + + # Test with filter dictionaries + response = await memory_test_client.hydrate_memory_prompt( + query=query, + session_id={"eq": "test-session"}, + namespace={"eq": "test-namespace"}, + topics={"any": ["preferences", "colors"]}, + limit=5, + ) + + # Verify the response + assert len(response.messages) == 2 + assert isinstance(response.messages[0].content, TextContent) + assert "favorite color" in response.messages[0].content.text + assert isinstance(response.messages[1].content, TextContent) + assert response.messages[1].content.text == query + + # Test with filter objects + mock_post.reset_mock() + mock_post.return_value = mock_response + + response = await memory_test_client.hydrate_memory_prompt( + query=query, + session_id=SessionId(eq="test-session"), + namespace=Namespace(eq="test-namespace"), + topics=Topics(any=["preferences"]), + window_size=10, + model_name="gpt-4o", + ) + + # Response should be the same because it's mocked + assert len(response.messages) == 2 + + # Test with no filters (just query) + mock_post.reset_mock() + mock_post.return_value = mock_response + + response = await memory_test_client.hydrate_memory_prompt( + query=query, + ) + + # Response should still be the same (mocked) + assert len(response.messages) == 2 + + +@pytest.mark.asyncio +async def test_memory_prompt_integration(memory_test_client: MemoryAPIClient): + """Test the memory_prompt method with both session and long-term search""" + session_id = "test-client-session" + query = "What was my favorite color?" + + # Create expected response with both session and LTM content + expected_messages = [ + SystemMessage( + content=TextContent( + type="text", + text="## A summary of the conversation so far\nPrevious conversation about website design preferences.", + ), + ), + base.UserMessage( + content=TextContent( + type="text", text="What is a good color for a website?" + ), + ), + base.AssistantMessage( + content=TextContent( + type="text", + text="It depends on the website's purpose. 
Blue is often used for professional sites.", + ), + ), + SystemMessage( + content=TextContent( + type="text", + text="## Long term memories related to the user's query\n - The user's favorite color is blue", + ), + ), + base.UserMessage( + content=TextContent(type="text", text=query), + ), + ] + + # Create expected response payload + expected_response = MemoryPromptResponse(messages=expected_messages) + + # Mock the HTTP client's post method directly + with patch.object(memory_test_client._client, "post") as mock_post: + mock_response = AsyncMock() + mock_response.status_code = 200 + mock_response.raise_for_status = MagicMock(return_value=None) + mock_response.json = MagicMock(return_value=expected_response.model_dump()) + mock_post.return_value = mock_response + + # Let the client method run with our mocked response + response = await memory_test_client.memory_prompt( + query=query, + session_id=session_id, + namespace="test-namespace", + ) + + # Check that both session memory and LTM are in the response + assert len(response.messages) == 5 + + # Extract text from contents + message_texts = [] + for m in response.messages: + if isinstance(m.content, TextContent): + message_texts.append(m.content.text) + + # The messages should include at least one from the session + assert any("website" in text for text in message_texts) + # And at least one from LTM + assert any("favorite color is blue" in text for text in message_texts) + # And the query itself + assert query in message_texts[-1] diff --git a/tests/test_extraction.py b/tests/test_extraction.py index 73d39d3..2bab667 100644 --- a/tests/test_extraction.py +++ b/tests/test_extraction.py @@ -6,7 +6,8 @@ from agent_memory_server.config import settings from agent_memory_server.extraction import ( extract_entities, - extract_topics_ner, + extract_topics_bertopic, + extract_topics_llm, handle_extraction, ) @@ -45,7 +46,7 @@ async def test_extract_topics_success(self, mock_get_topic_model, mock_bertopic) mock_get_topic_model.return_value = mock_bertopic text = "Discussion about AI technology and business" - topics = extract_topics_ner(text) + topics = extract_topics_bertopic(text) assert set(topics) == {"technology", "business"} mock_bertopic.transform.assert_called_once_with([text]) @@ -58,7 +59,7 @@ async def test_extract_topics_no_valid_topics( mock_bertopic.transform.return_value = (np.array([-1]), np.array([0.0])) mock_get_topic_model.return_value = mock_bertopic - topics = extract_topics_ner("Test message") + topics = extract_topics_bertopic("Test message") assert topics == [] mock_bertopic.transform.assert_called_once() @@ -135,3 +136,106 @@ async def test_handle_extraction_disabled_features( # Restore settings settings.enable_topic_extraction = original_topic_setting settings.enable_ner = original_ner_setting + + +@pytest.mark.requires_api_keys +class TestTopicExtractionIntegration: + @pytest.mark.asyncio + async def test_bertopic_integration(self): + """Integration test for BERTopic topic extraction (skipped if not available)""" + + # Save and set topic_model_source + original_source = settings.topic_model_source + settings.topic_model_source = "BERTopic" + sample_text = ( + "OpenAI and Google are leading companies in artificial intelligence." 
+ ) + try: + try: + # Try to import BERTopic and check model loading + topics = extract_topics_bertopic(sample_text) + # print(f"[DEBUG] BERTopic returned topics: {topics}") + except Exception as e: + pytest.skip(f"BERTopic integration test skipped: {e}") + assert isinstance(topics, list) + expected_keywords = { + "generative", + "transformer", + "neural", + "learning", + "trained", + "multimodal", + "generates", + "models", + "encoding", + "text", + } + assert any(t.lower() in expected_keywords for t in topics) + finally: + settings.topic_model_source = original_source + + @pytest.mark.asyncio + async def test_llm_integration(self): + """Integration test for LLM-based topic extraction (skipped if no API key)""" + + # Save and set topic_model_source + original_source = settings.topic_model_source + settings.topic_model_source = "LLM" + sample_text = ( + "OpenAI and Google are leading companies in artificial intelligence." + ) + try: + # Check for API key + if not (settings.openai_api_key or settings.anthropic_api_key): + pytest.skip("No LLM API key available for integration test.") + topics = await extract_topics_llm(sample_text) + assert isinstance(topics, list) + assert any( + t.lower() in ["technology", "business", "artificial intelligence"] + for t in topics + ) + finally: + settings.topic_model_source = original_source + + +class TestHandleExtractionPathSelection: + @pytest.mark.asyncio + @patch("agent_memory_server.extraction.extract_topics_bertopic") + @patch("agent_memory_server.extraction.extract_topics_llm") + async def test_handle_extraction_path_selection( + self, mock_extract_topics_llm, mock_extract_topics_bertopic + ): + """Test that handle_extraction uses the correct extraction path based on settings.topic_model_source""" + + sample_text = ( + "OpenAI and Google are leading companies in artificial intelligence." 
+ ) + original_source = settings.topic_model_source + original_enable_topic_extraction = settings.enable_topic_extraction + original_enable_ner = settings.enable_ner + try: + # Enable topic extraction and disable NER for clarity + settings.enable_topic_extraction = True + settings.enable_ner = False + + # Test BERTopic path + settings.topic_model_source = "BERTopic" + mock_extract_topics_bertopic.return_value = ["technology"] + mock_extract_topics_llm.return_value = ["should not be called"] + topics, _ = await handle_extraction(sample_text) + mock_extract_topics_bertopic.assert_called_once() + mock_extract_topics_llm.assert_not_called() + assert topics == ["technology"] + mock_extract_topics_bertopic.reset_mock() + + # Test LLM path + settings.topic_model_source = "LLM" + mock_extract_topics_llm.return_value = ["ai"] + topics, _ = await handle_extraction(sample_text) + mock_extract_topics_llm.assert_called_once() + mock_extract_topics_bertopic.assert_not_called() + assert topics == ["ai"] + finally: + settings.topic_model_source = original_source + settings.enable_topic_extraction = original_enable_topic_extraction + settings.enable_ner = original_enable_ner diff --git a/tests/test_mcp.py b/tests/test_mcp.py index 350bcce..aa08286 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -10,6 +10,10 @@ from agent_memory_server.mcp import mcp_app from agent_memory_server.models import ( LongTermMemory, + LongTermMemoryResult, + MemoryPromptRequest, + MemoryPromptResponse, + SystemMessage, ) @@ -86,34 +90,47 @@ async def test_memory_prompt(self, session, mcp_test_setup): """Test memory prompt with various parameter combinations.""" async with client_session(mcp_app._mcp_server) as client: prompt = await client.call_tool( - "hydrate_memory_prompt", + "memory_prompt", { - "text": "Test query", + "query": "Test query", "session_id": {"eq": session}, "namespace": {"eq": "test-namespace"}, }, ) assert isinstance(prompt, CallToolResult) - # Parse the response content - ensure we're getting text content assert prompt.content[0].type == "text" - message = json.loads(prompt.content[0].text) + messages = json.loads(prompt.content[0].text) - # The result should be a dictionary with content and role - assert isinstance(message, dict) - assert "content" in message - assert "role" in message + assert isinstance(messages, dict) + assert "messages" in messages + assert len(messages["messages"]) == 5 + + # The returned messages structure is: + # 0: system (summary) + # 1: user ("Hello") + # 2: assistant ("Hi there") + # 3: system (long term memories) + # 4: user ("Test query") + assert messages["messages"][0]["role"] == "system" + assert messages["messages"][0]["content"]["type"] == "text" + assert "summary" in messages["messages"][0]["content"]["text"] - # Check the message content and role - accept either user or assistant roles - assert message["role"] in ["user", "assistant"] - assert message["content"]["type"] == "text" + assert messages["messages"][1]["role"] == "user" + assert messages["messages"][1]["content"]["type"] == "text" + assert messages["messages"][1]["content"]["text"] == "Hello" - # If it's an assistant message, check for some basic structure - if message["role"] == "assistant": - assert "Long term memories" in message["content"]["text"] - # If it's a user message, it should contain the original query - else: - assert "Test query" in message["content"]["text"] + assert messages["messages"][2]["role"] == "assistant" + assert messages["messages"][2]["content"]["type"] == "text" + assert 
messages["messages"][2]["content"]["text"] == "Hi there" + + assert messages["messages"][3]["role"] == "system" + assert messages["messages"][3]["content"]["type"] == "text" + assert "Long term memories" in messages["messages"][3]["content"]["text"] + + assert messages["messages"][4]["role"] == "user" + assert messages["messages"][4]["content"]["type"] == "text" + assert "Test query" in messages["messages"][4]["content"]["text"] @pytest.mark.asyncio async def test_memory_prompt_error_handling(self, session, mcp_test_setup): @@ -121,10 +138,10 @@ async def test_memory_prompt_error_handling(self, session, mcp_test_setup): async with client_session(mcp_app._mcp_server) as client: # Test with a non-existent session id prompt = await client.call_tool( - "hydrate_memory_prompt", + "memory_prompt", { - "text": "Test query", - "session_id": {"eq": "non-existent"}, + "query": "Test query", + "session": {"session_id": {"eq": "non-existent"}}, "namespace": {"eq": "test-namespace"}, }, ) @@ -134,15 +151,18 @@ async def test_memory_prompt_error_handling(self, session, mcp_test_setup): assert prompt.content[0].type == "text" message = json.loads(prompt.content[0].text) - # The result should be a dictionary with content and role + # The result should be a dictionary containing messages, each with content and role assert isinstance(message, dict) - assert "content" in message - assert "role" in message + assert "messages" in message # Check that we have a user message with the test query - assert message["role"] == "user" - assert message["content"]["type"] == "text" - assert message["content"]["text"] == "Test query" + assert message["messages"][0]["role"] == "system" + assert message["messages"][0]["content"]["type"] == "text" + assert "Long term memories" in message["messages"][0]["content"]["text"] + + assert message["messages"][1]["role"] == "user" + assert message["messages"][1]["content"]["type"] == "text" + assert "Test query" in message["messages"][1]["content"]["text"] @pytest.mark.asyncio async def test_default_namespace_injection(self, monkeypatch): @@ -150,7 +170,6 @@ async def test_default_namespace_injection(self, monkeypatch): Ensure that when default_namespace is set on mcp_app, search_long_term_memory injects it automatically. """ from agent_memory_server.models import ( - LongTermMemoryResult, LongTermMemoryResults, ) @@ -198,3 +217,60 @@ async def fake_core_search(payload): finally: # Restore original namespace mcp_app.default_namespace = original_ns + + @pytest.mark.asyncio + async def test_memory_prompt_parameter_passing(self, session, monkeypatch): + """ + Test that memory_prompt correctly passes parameters to core_memory_prompt. + This test verifies the implementation details to catch bugs like the _params issue. 
+ """ + # Capture the parameters passed to core_memory_prompt + captured_params = {} + + async def mock_core_memory_prompt(params: MemoryPromptRequest): + captured_params["query"] = params.query + captured_params["session"] = params.session + captured_params["long_term_search"] = params.long_term_search + + # Return a minimal valid response + return MemoryPromptResponse( + messages=[ + SystemMessage(content={"type": "text", "text": "Test response"}) + ] + ) + + # Patch the core function + monkeypatch.setattr( + "agent_memory_server.mcp.core_memory_prompt", mock_core_memory_prompt + ) + + async with client_session(mcp_app._mcp_server) as client: + prompt = await client.call_tool( + "memory_prompt", + { + "query": "Test query", + "session_id": {"eq": session}, + "namespace": {"eq": "test-namespace"}, + "topics": {"any": ["test-topic"]}, + "entities": {"any": ["test-entity"]}, + "limit": 5, + }, + ) + + # Verify the tool was called successfully + assert isinstance(prompt, CallToolResult) + + # Verify that core_memory_prompt was called with the correct parameters + assert captured_params["query"] == "Test query" + + # Verify session parameters were passed correctly + assert captured_params["session"] is not None + assert captured_params["session"].session_id == session + assert captured_params["session"].namespace == "test-namespace" + + # Verify long_term_search parameters were passed correctly + assert captured_params["long_term_search"] is not None + assert captured_params["long_term_search"].text == "Test query" + assert captured_params["long_term_search"].limit == 5 + assert captured_params["long_term_search"].topics is not None + assert captured_params["long_term_search"].entities is not None diff --git a/tests/test_models.py b/tests/test_models.py index 0ede2c0..727c7e2 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -10,7 +10,7 @@ from agent_memory_server.models import ( LongTermMemoryResult, MemoryMessage, - SearchPayload, + SearchRequest, SessionMemory, SessionMemoryResponse, ) @@ -127,7 +127,7 @@ def test_search_payload_with_filter_objects(self): user_id = UserId(eq="test-user") # Create payload with filter objects - payload = SearchPayload( + payload = SearchRequest( text="Test query", session_id=session_id, namespace=namespace,