Skip to content

Commit 72b923d

Browse files
committed
Wire up so we can successfully answer queries that do full text search.
1 parent aff113d commit 72b923d

File tree

10 files changed

+261
-56
lines changed

10 files changed

+261
-56
lines changed

src/.DS_Store

2 KB
Binary file not shown.

src/main/.DS_Store

0 Bytes
Binary file not shown.

src/main/kotlin/com/embabel/agent/rag/neo/drivine/DrivineCypherSearch.kt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,11 +188,14 @@ class DrivineCypherSearch(
188188
val loggerToUse = logger ?: this.logger
189189
val cypher = if (query.contains(" ")) query else queryResolver.resolve(query)!!
190190
loggerToUse.info("[{}] query\n\tparams: {}\n{}", purpose, params, cypher)
191+
192+
val myParams = params as Map<String, Any>
193+
191194
@Suppress("UNCHECKED_SCAST")
192195
val rows = persistenceManager.query(
193196
QuerySpecification
194197
.withStatement(cypher)
195-
.bind(params as Map<String, Any>)
198+
.bind(myParams)
196199
.transform(Map::class.java)
197200
) as List<Map<String, Any>>
198201

src/main/kotlin/com/embabel/agent/rag/neo/drivine/DrivineStore.kt

Lines changed: 205 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package com.embabel.agent.rag.neo.drivine
33
import com.embabel.agent.api.common.Embedding
44
import com.embabel.agent.rag.ingestion.RetrievableEnhancer
55
import com.embabel.agent.rag.model.*
6+
import com.embabel.agent.rag.service.EntitySearch
67
import com.embabel.agent.rag.service.RagRequest
78
import com.embabel.agent.rag.service.support.FunctionRagFacet
89
import com.embabel.agent.rag.service.support.RagFacet
@@ -13,12 +14,17 @@ import com.embabel.agent.rag.store.ChunkingContentElementRepository
1314
import com.embabel.agent.rag.store.DocumentDeletionResult
1415
import com.embabel.common.ai.model.DefaultModelSelectionCriteria
1516
import com.embabel.common.ai.model.ModelProvider
17+
import com.embabel.common.core.types.SimilarityCutoff
1618
import com.embabel.common.core.types.SimilarityResult
19+
import com.embabel.common.core.types.SimpleSimilaritySearchResult
1720
import org.drivine.manager.PersistenceManager
1821
import org.drivine.query.QuerySpecification
1922
import org.slf4j.LoggerFactory
2023
import org.springframework.beans.factory.annotation.Qualifier
2124
import org.springframework.stereotype.Service
25+
import org.springframework.transaction.PlatformTransactionManager
26+
import org.springframework.transaction.TransactionDefinition
27+
import org.springframework.transaction.support.TransactionTemplate
2228
import kotlin.collections.get
2329

2430
@Service
@@ -28,6 +34,7 @@ class DrivineStore(
2834
val properties: NeoRagServiceProperties,
2935
private val cypherSearch: CypherSearch,
3036
modelProvider: ModelProvider,
37+
platformTransactionManager: PlatformTransactionManager,
3138
) : AbstractChunkingContentElementRepository(properties), ChunkingContentElementRepository, RagFacetProvider {
3239

3340
private val logger = LoggerFactory.getLogger(DrivineStore::class.java)
@@ -193,26 +200,25 @@ class DrivineStore(
193200
}
194201

195202
fun search(ragRequest: RagRequest): RagFacetResults<Retrievable> {
196-
// val embedding = embeddingService.model.embed(ragRequest.query)
197-
// val allResults = mutableListOf<SimilarityResult<out Retrievable>>()
198-
// if (ragRequest.contentElementSearch.types.contains(Chunk::class.java)) {
199-
// allResults += safelyExecuteInTransaction { chunkSearch(ragRequest, embedding) }
200-
// } else {
201-
// logger.info("No chunk search specified, skipping chunk search")
202-
// }
203-
//
204-
// if (ragRequest.entitySearch != null) {
205-
// allResults += safelyExecuteInTransaction { entitySearch(ragRequest, embedding) }
206-
// } else {
207-
// logger.info("No entity search specified, skipping entity search")
208-
// }
209-
//
210-
// // TODO should reward multiple matches
211-
// val mergedResults: List<SimilarityResult<out Retrievable>> = allResults
212-
// .distinctBy { it.match.id }
213-
// .sortedByDescending { it.score }
214-
// .take(ragRequest.topK)
215-
val mergedResults: List<SimilarityResult<out Retrievable>> = TODO()
203+
val embedding = embeddingService.model.embed(ragRequest.query)
204+
val allResults = mutableListOf<SimilarityResult<out Retrievable>>()
205+
if (ragRequest.contentElementSearch.types.contains(Chunk::class.java)) {
206+
allResults += safelyExecuteInTransaction { chunkSearch(ragRequest, embedding) }
207+
} else {
208+
logger.info("No chunk search specified, skipping chunk search")
209+
}
210+
211+
if (ragRequest.entitySearch != null) {
212+
allResults += safelyExecuteInTransaction { entitySearch(ragRequest, embedding) }
213+
} else {
214+
logger.info("No entity search specified, skipping entity search")
215+
}
216+
217+
// TODO should reward multiple matches
218+
val mergedResults: List<SimilarityResult<out Retrievable>> = allResults
219+
.distinctBy { it.match.id }
220+
.sortedByDescending { it.score }
221+
.take(ragRequest.topK)
216222
return RagFacetResults(
217223
facetName = this.name,
218224
results = mergedResults,
@@ -276,4 +282,183 @@ class DrivineStore(
276282
throw RuntimeException("Don't know how to map: $labels")
277283
}
278284

285+
private val readonlyTransactionTemplate = TransactionTemplate(platformTransactionManager).apply {
286+
isReadOnly = true
287+
propagationBehavior = TransactionDefinition.PROPAGATION_REQUIRED
288+
}
289+
290+
private fun safelyExecuteInTransaction(block: () -> List<SimilarityResult<out Retrievable>>): List<SimilarityResult<out Retrievable>> {
291+
return try {
292+
readonlyTransactionTemplate.execute { block() } as List<SimilarityResult<out Retrievable>>
293+
} catch (e: Exception) {
294+
logger.error("Error during RAG search transaction", e)
295+
emptyList()
296+
}
297+
}
298+
299+
private fun chunkSearch(
300+
ragRequest: RagRequest,
301+
embedding: Embedding,
302+
): List<SimilarityResult<out Chunk>> {
303+
val chunkSimilarityResults = cypherSearch.chunkSimilaritySearch(
304+
"Chunk similarity search",
305+
query = "chunk_vector_search",
306+
params = commonParameters(ragRequest) + mapOf(
307+
"vectorIndex" to properties.contentElementIndex,
308+
"queryVector" to embedding,
309+
),
310+
logger = logger,
311+
)
312+
logger.info("{} chunk similarity results for query '{}'", chunkSimilarityResults.size, ragRequest.query)
313+
314+
val chunkFullTextResults = cypherSearch.chunkFullTextSearch(
315+
purpose = "Chunk full text search",
316+
query = "chunk_fulltext_search",
317+
params = commonParameters(ragRequest) + mapOf(
318+
"fulltextIndex" to properties.contentElementFullTextIndex,
319+
"searchText" to "\"${ragRequest.query}\"",
320+
),
321+
logger = logger,
322+
)
323+
logger.info("{} chunk full-text results for query '{}'", chunkFullTextResults.size, ragRequest.query)
324+
return chunkSimilarityResults + chunkFullTextResults
325+
}
326+
327+
private fun entitySearch(
328+
ragRequest: RagRequest,
329+
embedding: FloatArray,
330+
): List<SimilarityResult<out Retrievable>> {
331+
val allEntityResults = mutableListOf<SimilarityResult<out Retrievable>>()
332+
val labels = ragRequest.entitySearch?.labels ?: error("No entity search specified")
333+
val entityResults = entityVectorSearch(
334+
ragRequest,
335+
embedding,
336+
labels,
337+
)
338+
allEntityResults += entityResults
339+
logger.info("{} entity vector results for query '{}'", entityResults.size, ragRequest.query)
340+
val entityFullTextResults = cypherSearch.entityFullTextSearch(
341+
purpose = "Entity full text search",
342+
query = "entity_fulltext_search",
343+
params = commonParameters(ragRequest) + mapOf(
344+
"fulltextIndex" to properties.entityFullTextIndex,
345+
"searchText" to ragRequest.query,
346+
"labels" to labels,
347+
),
348+
logger = logger,
349+
)
350+
logger.info("{} entity full-text results for query '{}'", entityFullTextResults.size, ragRequest.query)
351+
allEntityResults += entityFullTextResults
352+
353+
if (ragRequest.entitySearch?.generateQueries == true) {
354+
val cypherResults =
355+
generateAndExecuteCypher(ragRequest, ragRequest.entitySearch!!).also { cypherResults ->
356+
logger.info("{} Cypher results for query '{}'", cypherResults.size, ragRequest.query)
357+
}
358+
allEntityResults += cypherResults
359+
} else {
360+
logger.info("No query generation specified, skipping Cypher generation and execution")
361+
}
362+
logger.info("{} total entity results for query '{}'", entityFullTextResults.size, ragRequest.query)
363+
return allEntityResults
364+
}
365+
366+
fun entityVectorSearch(
367+
request: SimilarityCutoff,
368+
embedding: FloatArray,
369+
labels: Set<String>,
370+
): List<SimilarityResult<out EntityData>> {
371+
return cypherSearch.entityDataSimilaritySearch(
372+
purpose = "Mapped entity search",
373+
query = "entity_vector_search",
374+
params = commonParameters(request) + mapOf(
375+
"index" to properties.entityIndex,
376+
"queryVector" to embedding,
377+
"labels" to labels,
378+
),
379+
logger,
380+
)
381+
}
382+
383+
private fun generateAndExecuteCypher(
384+
request: RagRequest,
385+
entitySearch: EntitySearch,
386+
): List<SimilarityResult<out Retrievable>> {
387+
TODO("Not yet implemented")
388+
// val schema = schemaResolver.getSchema(entitySearch)
389+
// if (schema == null) {
390+
// logger.info("No schema found for entity search {}, skipping Cypher execution", entitySearch)
391+
// return emptyList()
392+
// }
393+
//
394+
// val cypherRagQueryGenerator = SchemaDrivenCypherRagQueryGenerator(
395+
// modelProvider,
396+
// schema,
397+
// )
398+
// val cypher = cypherRagQueryGenerator.generateQuery(request = request)
399+
// logger.info("Generated Cypher query: $cypher")
400+
//
401+
// val cypherResults = readonlyTransactionTemplate.execute {
402+
// executeGeneratedCypher(cypher)
403+
// } ?: Result.failure(
404+
// IllegalStateException("Transaction failed or returned null while executing Cypher query: $cypher")
405+
// )
406+
// if (cypherResults.isSuccess) {
407+
// val results = cypherResults.getOrThrow()
408+
// if (results.isNotEmpty()) {
409+
// logger.info("Cypher query executed successfully, results: {}", results)
410+
// return results.map {
411+
// // Most similar as we found them by a query
412+
// SimpleSimilaritySearchResult(
413+
// it,
414+
// score = 1.0,
415+
// )
416+
// }
417+
// }
418+
// }
419+
// return emptyList()
420+
}
421+
422+
// /**
423+
// * Execute generate Cypher query, being sure to handle exceptions gracefully.
424+
// */
425+
// private fun executeGeneratedCypher(
426+
// query: CypherQuery,
427+
// ): Result<List<EntityData>> {
428+
// TODO("Not yet implemented")
429+
// try {
430+
// return Result.success(
431+
// ogmCypherSearch.queryForEntities(
432+
// purpose = "cypherGeneratedQuery",
433+
// query = query.query
434+
// )
435+
// )
436+
// } catch (e: Exception) {
437+
// logger.error("Error executing generated query: $query", e)
438+
// return Result.failure(e)
439+
// }
440+
// }
441+
442+
private fun createVectorIndex(
443+
name: String,
444+
on: String,
445+
) {
446+
val statement = """
447+
CREATE VECTOR INDEX `$name` IF NOT EXISTS
448+
FOR (n:$on) ON (n.embedding)
449+
OPTIONS {indexConfig: {
450+
`vector.dimensions`: ${embeddingService.model.dimensions()},
451+
`vector.similarity_function`: 'cosine'
452+
}}"""
453+
454+
persistenceManager.execute(QuerySpecification.withStatement(statement))
455+
456+
}
457+
458+
private fun commonParameters(request: SimilarityCutoff) = mapOf(
459+
"topK" to request.topK,
460+
"similarityThreshold" to request.similarityThreshold,
461+
)
462+
463+
279464
}

src/main/resources/cypher/chunk_fulltext_search.cypher

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@ UNWIND results AS result
77
WITH result.node AS chunk,
88
result.score / maxScore AS normalizedScore
99
WHERE normalizedScore >= $similarityThreshold
10-
RETURN chunk.text AS text,
11-
chunk.id AS id,
12-
normalizedScore AS score
13-
ORDER BY score DESC
14-
LIMIT $topK
10+
RETURN {
11+
text: chunk.text,
12+
id: chunk.id,
13+
score: normalizedScore
14+
} AS result
15+
ORDER BY result.score DESC
16+
LIMIT $topK
Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
CALL db.index.vector.queryNodes($vectorIndex, $topK, $queryVector)
22
YIELD node AS chunk, score
33
WHERE score >= $similarityThreshold
4-
RETURN chunk.text AS text, chunk.id AS id,
5-
score
6-
ORDER BY score DESC
4+
RETURN {
5+
text: chunk.text,
6+
id: chunk.id,
7+
score: score
8+
} AS result
9+
ORDER BY result.score DESC
Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,13 @@
11
MATCH (chunk:$($chunkNodeName) {id: $basisId})
2-
CREATE (chunk)-[:HAS_ENTITY]->(e:$($entityLabels) {id: $id, name: $name, description: $description, createdDate: timestamp()})
2+
CREATE (chunk)-[:HAS_ENTITY]->(e:$($entityLabels) {
3+
id: $id,
4+
name: $name,
5+
description: $description,
6+
createdDate: timestamp()
7+
})
38
SET e += $properties,
49
e.lastModifiedDate = timestamp()
5-
RETURN e.id as id, COUNT(e) as nodesCreated
10+
RETURN {
11+
id: e.id,
12+
nodesCreated: COUNT(e)
13+
} AS result

src/main/resources/cypher/entity_fulltext_search.cypher

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,17 @@ WITH collect({node: m, score: score}) AS results, max(score) AS maxScore
66
UNWIND results AS result
77
WITH result.node AS match,
88
COALESCE(result.score / maxScore, 0.0) AS score,
9-
result.node.name as name,
10-
result.node.description as description,
11-
result.node.id AS id,
12-
labels(result.node) AS labels
13-
WHERE score >= $similarityThreshold
14-
RETURN
15-
COALESCE(name, '') as name,
16-
COALESCE(description, '') as description,
17-
COALESCE(id, '') as id,
18-
properties(match) AS properties,
19-
labels,
20-
score
21-
ORDER BY score DESC
9+
result.node.name AS name,
10+
result.node.description AS description,
11+
result.node.id AS id,
12+
labels(result.node) AS labels
13+
WHERE score >= $similarityThreshold
14+
RETURN {
15+
name: COALESCE(name, ''),
16+
description: COALESCE(description, ''),
17+
id: COALESCE(id, ''),
18+
properties: properties(match),
19+
labels: labels,
20+
score: score
21+
} AS result
22+
ORDER BY result.score DESC
Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
11
CALL db.index.vector.queryNodes($index, $topK, $queryVector)
22
YIELD node AS m, score
3-
WHERE score >= $similarityThreshold AND
4-
any(label IN labels(m) WHERE label IN $labels)
5-
RETURN properties(m) AS properties, m.name as name, m.description as description, m.id AS id, labels(m) AS labels,
6-
score
7-
ORDER BY score DESC
8-
3+
WHERE score >= $similarityThreshold
4+
AND any(label IN labels(m) WHERE label IN $labels)
5+
RETURN {
6+
properties: properties(m),
7+
name: COALESCE(m.name, ''),
8+
description: COALESCE(m.description, ''),
9+
id: COALESCE(m.id, ''),
10+
labels: labels(m),
11+
score: score
12+
} AS result
13+
ORDER BY result.score DESC
Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
MATCH (existing:$($labels))
2-
RETURN existing,
3-
labels(existing) AS labels,
4-
existing.name as name,
5-
existing.id AS id,
6-
existing.description AS description,
7-
LIMIT $limit;
2+
RETURN
3+
properties(existing) +
4+
{ labels: labels(existing) } AS result
5+
LIMIT $limit

0 commit comments

Comments
 (0)