Skip to content

Commit 7eb6416

Browse files
committed
Support keyword and hybrid memory search
1 parent 70f6d2d commit 7eb6416

File tree

22 files changed

+650
-65
lines changed

22 files changed

+650
-65
lines changed

agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/longtermemory/MemoryRecordResult.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
public class MemoryRecordResult extends MemoryRecord {
77

88
private double dist;
9+
private Double score;
10+
private String scoreType;
911

1012
public MemoryRecordResult() {
1113
super();
@@ -19,10 +21,29 @@ public void setDist(double dist) {
1921
this.dist = dist;
2022
}
2123

24+
public Double getScore() {
25+
return score;
26+
}
27+
28+
public void setScore(Double score) {
29+
this.score = score;
30+
}
31+
32+
public String getScoreType() {
33+
return scoreType;
34+
}
35+
36+
@com.fasterxml.jackson.annotation.JsonProperty("score_type")
37+
public void setScoreType(String scoreType) {
38+
this.scoreType = scoreType;
39+
}
40+
2241
@Override
2342
public String toString() {
2443
return "MemoryRecordResult{" +
2544
"dist=" + dist +
45+
", score=" + score +
46+
", scoreType='" + scoreType + '\'' +
2647
", " + super.toString() +
2748
'}';
2849
}

agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/longtermemory/SearchRequest.java

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,16 @@ public class SearchRequest {
1515
@Nullable
1616
private String text;
1717

18+
@Nullable
19+
@JsonProperty("search_mode")
20+
private String searchMode;
21+
22+
@JsonProperty("hybrid_alpha")
23+
private double hybridAlpha = 0.7;
24+
25+
@JsonProperty("text_scorer")
26+
private String textScorer = "BM25STD";
27+
1828
@Nullable
1929
@JsonProperty("session_id")
2030
private String sessionId;
@@ -85,6 +95,32 @@ public void setText(@Nullable String text) {
8595
this.text = text;
8696
}
8797

98+
@Nullable
99+
public String getSearchMode() {
100+
return searchMode;
101+
}
102+
103+
public void setSearchMode(@Nullable String searchMode) {
104+
this.searchMode = searchMode;
105+
}
106+
107+
public double getHybridAlpha() {
108+
return hybridAlpha;
109+
}
110+
111+
public void setHybridAlpha(double hybridAlpha) {
112+
this.hybridAlpha = hybridAlpha;
113+
}
114+
115+
@Nullable
116+
public String getTextScorer() {
117+
return textScorer;
118+
}
119+
120+
public void setTextScorer(@Nullable String textScorer) {
121+
this.textScorer = textScorer;
122+
}
123+
88124
@Nullable
89125
public String getSessionId() {
90126
return sessionId;
@@ -231,6 +267,9 @@ public void setServerSideRecency(@Nullable Boolean serverSideRecency) {
231267
public String toString() {
232268
return "SearchRequest{" +
233269
"text='" + text + '\'' +
270+
", searchMode='" + searchMode + '\'' +
271+
", hybridAlpha=" + hybridAlpha +
272+
", textScorer='" + textScorer + '\'' +
234273
", sessionId='" + sessionId + '\'' +
235274
", namespace='" + namespace + '\'' +
236275
", topics=" + topics +
@@ -269,6 +308,21 @@ public Builder text(@Nullable String text) {
269308
return this;
270309
}
271310

311+
public Builder searchMode(@Nullable String searchMode) {
312+
request.searchMode = searchMode;
313+
return this;
314+
}
315+
316+
public Builder hybridAlpha(double hybridAlpha) {
317+
request.hybridAlpha = hybridAlpha;
318+
return this;
319+
}
320+
321+
public Builder textScorer(@Nullable String textScorer) {
322+
request.textScorer = textScorer;
323+
return this;
324+
}
325+
272326
public Builder sessionId(@Nullable String sessionId) {
273327
request.sessionId = sessionId;
274328
return this;

agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/services/LongTermMemoryService.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,9 @@ public MemoryRecordResults searchLongTermMemories(@NotNull SearchRequest request
7474
// Build payload
7575
Map<String, Object> payload = new HashMap<>();
7676
payload.put("text", request.getText());
77+
payload.put("search_mode", request.getSearchMode() != null ? request.getSearchMode() : "semantic");
78+
payload.put("hybrid_alpha", request.getHybridAlpha());
79+
payload.put("text_scorer", request.getTextScorer() != null ? request.getTextScorer() : "BM25STD");
7780
payload.put("limit", request.getLimit());
7881
payload.put("offset", request.getOffset());
7982

agent-memory-client/agent-memory-client-js/src/client.test.ts

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -577,6 +577,21 @@ describe("MemoryAPIClient", () => {
577577
expect(callBody.distance_threshold).toBe(0.5);
578578
});
579579

580+
it("should include search mode parameters in request body", async () => {
581+
mockFetch = createMockFetch({ memories: [], total: 0 });
582+
client["fetchFn"] = mockFetch;
583+
await client.searchLongTermMemory({
584+
text: "test",
585+
searchMode: "hybrid",
586+
hybridAlpha: 0.55,
587+
textScorer: "BM25",
588+
});
589+
const callBody = JSON.parse(mockFetch.mock.calls[0][1].body);
590+
expect(callBody.search_mode).toBe("hybrid");
591+
expect(callBody.hybrid_alpha).toBe(0.55);
592+
expect(callBody.text_scorer).toBe("BM25");
593+
});
594+
580595
it("should handle SessionId filter class", async () => {
581596
const { SessionId } = await import("./filters");
582597
mockFetch = createMockFetch({ memories: [], total: 0 });

agent-memory-client/agent-memory-client-js/src/client.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,9 @@ export interface MemoryClientConfig {
7171
*/
7272
export interface SearchOptions {
7373
text: string;
74+
searchMode?: "semantic" | "keyword" | "hybrid";
75+
hybridAlpha?: number;
76+
textScorer?: string;
7477
sessionId?: SessionId | { eq?: string; in_?: string[]; not_eq?: string; not_in?: string[] };
7578
namespace?: Namespace | { eq?: string; in_?: string[]; not_eq?: string; not_in?: string[] };
7679
topics?: Topics | { any?: string[]; all?: string[]; none?: string[] };
@@ -382,6 +385,9 @@ export class MemoryAPIClient {
382385
async searchLongTermMemory(options: SearchOptions): Promise<MemoryRecordResults> {
383386
const body: Record<string, unknown> = {
384387
text: options.text,
388+
search_mode: options.searchMode ?? "semantic",
389+
hybrid_alpha: options.hybridAlpha ?? 0.7,
390+
text_scorer: options.textScorer ?? "BM25STD",
385391
limit: options.limit,
386392
offset: options.offset,
387393
distance_threshold: options.distanceThreshold,

agent-memory-client/agent-memory-client-js/src/models.ts

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,10 @@ export interface SessionListResponse {
202202
export interface MemoryRecordResult extends MemoryRecord {
203203
/** Distance/similarity score */
204204
dist: number;
205+
/** Normalized relevance score for the selected search mode */
206+
score?: number | null;
207+
/** Search mode used to produce the normalized score */
208+
score_type?: "semantic" | "keyword" | "hybrid" | null;
205209
}
206210

207211
/**
@@ -267,6 +271,9 @@ export interface MemoryPromptRequest {
267271
*/
268272
export interface SearchRequestParams {
269273
text?: string;
274+
search_mode?: "semantic" | "keyword" | "hybrid";
275+
hybrid_alpha?: number;
276+
text_scorer?: string;
270277
session_id?: SessionIdFilter | null;
271278
namespace?: NamespaceFilter | null;
272279
topics?: TopicsFilter | null;

agent-memory-client/agent_memory_client/client.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
MemoryTypeEnum,
4949
ModelNameLiteral,
5050
RecencyConfig,
51+
SearchModeEnum,
5152
SessionListResponse,
5253
SummaryView,
5354
SummaryViewPartitionResult,
@@ -1033,6 +1034,9 @@ async def get_task(self, task_id: str) -> Task | None:
10331034
async def search_long_term_memory(
10341035
self,
10351036
text: str,
1037+
search_mode: SearchModeEnum | str = SearchModeEnum.SEMANTIC,
1038+
hybrid_alpha: float = 0.7,
1039+
text_scorer: str = "BM25STD",
10361040
session_id: SessionId | dict[str, Any] | None = None,
10371041
namespace: Namespace | dict[str, Any] | None = None,
10381042
topics: Topics | dict[str, Any] | None = None,
@@ -1048,10 +1052,13 @@ async def search_long_term_memory(
10481052
optimize_query: bool = False,
10491053
) -> MemoryRecordResults:
10501054
"""
1051-
Search long-term memories using semantic search and filters.
1055+
Search long-term memories using semantic, keyword, or hybrid search.
10521056
10531057
Args:
1054-
text: Query for vector search - will be used for semantic similarity matching
1058+
text: Query text used for semantic, keyword, or hybrid search
1059+
search_mode: Search strategy to use
1060+
hybrid_alpha: Weight assigned to vector similarity in hybrid search
1061+
text_scorer: Redis full-text scoring algorithm for keyword and hybrid search
10551062
session_id: Optional session ID filter
10561063
namespace: Optional namespace filter
10571064
topics: Optional topics filter
@@ -1141,6 +1148,13 @@ async def search_long_term_memory(
11411148
payload["memory_type"] = memory_type.model_dump(exclude_none=True)
11421149
if distance_threshold is not None:
11431150
payload["distance_threshold"] = distance_threshold
1151+
payload["search_mode"] = (
1152+
search_mode.value
1153+
if isinstance(search_mode, SearchModeEnum)
1154+
else str(search_mode)
1155+
)
1156+
payload["hybrid_alpha"] = hybrid_alpha
1157+
payload["text_scorer"] = text_scorer
11441158

11451159
# Add recency config if provided
11461160
if recency is not None:
@@ -1185,6 +1199,7 @@ async def search_long_term_memory(
11851199
async def search_memory_tool(
11861200
self,
11871201
query: str,
1202+
search_mode: SearchModeEnum | str = SearchModeEnum.SEMANTIC,
11881203
topics: Sequence[str] | None = None,
11891204
entities: Sequence[str] | None = None,
11901205
memory_type: str | None = None,
@@ -1258,6 +1273,7 @@ async def search_memory_tool(
12581273

12591274
results = await self.search_long_term_memory(
12601275
text=query,
1276+
search_mode=search_mode,
12611277
topics=topics_filter,
12621278
entities=entities_filter,
12631279
memory_type=memory_type_filter,
@@ -1281,9 +1297,13 @@ async def search_memory_tool(
12811297
"created_at": memory.created_at.isoformat()
12821298
if memory.created_at
12831299
else None,
1284-
"relevance_score": 1.0 - memory.dist
1285-
if hasattr(memory, "dist") and memory.dist is not None
1286-
else None,
1300+
"relevance_score": memory.score
1301+
if hasattr(memory, "score") and memory.score is not None
1302+
else (
1303+
1.0 - memory.dist
1304+
if hasattr(memory, "dist") and memory.dist is not None
1305+
else None
1306+
),
12871307
}
12881308
)
12891309

agent-memory-client/agent_memory_client/models.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,22 @@ class MemoryTypeEnum(str, Enum):
5656
MESSAGE = "message"
5757

5858

59+
class SearchModeEnum(str, Enum):
60+
"""Enum for supported search strategies."""
61+
62+
SEMANTIC = "semantic"
63+
KEYWORD = "keyword"
64+
HYBRID = "hybrid"
65+
66+
67+
class SearchScoreTypeEnum(str, Enum):
68+
"""Enum describing how the normalized score field was produced."""
69+
70+
SEMANTIC = "semantic"
71+
KEYWORD = "keyword"
72+
HYBRID = "hybrid"
73+
74+
5975
class MemoryStrategyConfig(BaseModel):
6076
"""Configuration for memory extraction strategy."""
6177

@@ -352,6 +368,14 @@ class MemoryRecordResult(MemoryRecord):
352368
"""Result from a memory search"""
353369

354370
dist: float
371+
score: float | None = Field(
372+
default=None,
373+
description="Normalized relevance score for the selected search mode (0-1)",
374+
)
375+
score_type: SearchScoreTypeEnum | None = Field(
376+
default=None,
377+
description="Search mode used to produce the normalized score",
378+
)
355379

356380

357381
class RecencyConfig(BaseModel):

agent-memory-client/tests/test_client.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
MemoryRecordResults,
2222
MemoryTypeEnum,
2323
RecencyConfig,
24+
SearchModeEnum,
2425
WorkingMemoryResponse,
2526
)
2627

@@ -345,6 +346,31 @@ async def test_recency_config_descriptive_parameters(self, enhanced_test_client)
345346
assert body["recency_half_life_created_days"] == 30
346347
assert body["server_side_recency"] is True
347348

349+
@pytest.mark.asyncio
350+
async def test_search_mode_parameters_are_sent_to_api(self, enhanced_test_client):
351+
"""Test that keyword/hybrid search parameters are sent to the API."""
352+
with patch.object(enhanced_test_client._client, "post") as mock_post:
353+
mock_response = Mock()
354+
mock_response.raise_for_status.return_value = None
355+
mock_response.json.return_value = MemoryRecordResults(
356+
total=0, memories=[], next_offset=None
357+
).model_dump()
358+
mock_post.return_value = mock_response
359+
360+
await enhanced_test_client.search_long_term_memory(
361+
text="alpha beta",
362+
search_mode=SearchModeEnum.HYBRID,
363+
hybrid_alpha=0.55,
364+
text_scorer="BM25",
365+
limit=5,
366+
)
367+
368+
_, kwargs = mock_post.call_args
369+
body = kwargs["json"]
370+
assert body["search_mode"] == "hybrid"
371+
assert body["hybrid_alpha"] == 0.55
372+
assert body["text_scorer"] == "BM25"
373+
348374

349375
class TestClientSideValidation:
350376
"""Tests for client-side validation methods."""

agent_memory_server/api.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -645,7 +645,7 @@ async def search_long_term_memory(
645645
current_user: UserInfo = Depends(get_current_user),
646646
):
647647
"""
648-
Run a semantic search on long-term memory with filtering options.
648+
Run a long-term memory search with semantic, keyword, or hybrid options.
649649
650650
Args:
651651
payload: Search payload with filter objects for precise queries
@@ -663,6 +663,9 @@ async def search_long_term_memory(
663663
logger.debug(f"Long-term search filters: {filters}")
664664

665665
kwargs = {
666+
"search_mode": payload.search_mode,
667+
"hybrid_alpha": payload.hybrid_alpha,
668+
"text_scorer": payload.text_scorer,
666669
"distance_threshold": payload.distance_threshold,
667670
"limit": payload.limit,
668671
"offset": payload.offset,

0 commit comments

Comments
 (0)