@@ -3,6 +3,7 @@ package com.embabel.agent.rag.neo.drivine
33import com.embabel.agent.api.common.Embedding
44import com.embabel.agent.rag.ingestion.RetrievableEnhancer
55import com.embabel.agent.rag.model.*
6+ import com.embabel.agent.rag.service.EntitySearch
67import com.embabel.agent.rag.service.RagRequest
78import com.embabel.agent.rag.service.support.FunctionRagFacet
89import com.embabel.agent.rag.service.support.RagFacet
@@ -13,12 +14,17 @@ import com.embabel.agent.rag.store.ChunkingContentElementRepository
1314import com.embabel.agent.rag.store.DocumentDeletionResult
1415import com.embabel.common.ai.model.DefaultModelSelectionCriteria
1516import com.embabel.common.ai.model.ModelProvider
17+ import com.embabel.common.core.types.SimilarityCutoff
1618import com.embabel.common.core.types.SimilarityResult
19+ import com.embabel.common.core.types.SimpleSimilaritySearchResult
1720import org.drivine.manager.PersistenceManager
1821import org.drivine.query.QuerySpecification
1922import org.slf4j.LoggerFactory
2023import org.springframework.beans.factory.annotation.Qualifier
2124import org.springframework.stereotype.Service
25+ import org.springframework.transaction.PlatformTransactionManager
26+ import org.springframework.transaction.TransactionDefinition
27+ import org.springframework.transaction.support.TransactionTemplate
2228import kotlin.collections.get
2329
2430@Service
@@ -28,6 +34,7 @@ class DrivineStore(
2834 val properties : NeoRagServiceProperties ,
2935 private val cypherSearch : CypherSearch ,
3036 modelProvider : ModelProvider ,
37+ platformTransactionManager : PlatformTransactionManager ,
3138) : AbstractChunkingContentElementRepository(properties), ChunkingContentElementRepository, RagFacetProvider {
3239
3340 private val logger = LoggerFactory .getLogger(DrivineStore ::class .java)
@@ -193,26 +200,25 @@ class DrivineStore(
193200 }
194201
195202 fun search (ragRequest : RagRequest ): RagFacetResults <Retrievable > {
196- // val embedding = embeddingService.model.embed(ragRequest.query)
197- // val allResults = mutableListOf<SimilarityResult<out Retrievable>>()
198- // if (ragRequest.contentElementSearch.types.contains(Chunk::class.java)) {
199- // allResults += safelyExecuteInTransaction { chunkSearch(ragRequest, embedding) }
200- // } else {
201- // logger.info("No chunk search specified, skipping chunk search")
202- // }
203- //
204- // if (ragRequest.entitySearch != null) {
205- // allResults += safelyExecuteInTransaction { entitySearch(ragRequest, embedding) }
206- // } else {
207- // logger.info("No entity search specified, skipping entity search")
208- // }
209- //
210- // // TODO should reward multiple matches
211- // val mergedResults: List<SimilarityResult<out Retrievable>> = allResults
212- // .distinctBy { it.match.id }
213- // .sortedByDescending { it.score }
214- // .take(ragRequest.topK)
215- val mergedResults: List <SimilarityResult <out Retrievable >> = TODO ()
203+ val embedding = embeddingService.model.embed(ragRequest.query)
204+ val allResults = mutableListOf<SimilarityResult <out Retrievable >>()
205+ if (ragRequest.contentElementSearch.types.contains(Chunk ::class .java)) {
206+ allResults + = safelyExecuteInTransaction { chunkSearch(ragRequest, embedding) }
207+ } else {
208+ logger.info(" No chunk search specified, skipping chunk search" )
209+ }
210+
211+ if (ragRequest.entitySearch != null ) {
212+ allResults + = safelyExecuteInTransaction { entitySearch(ragRequest, embedding) }
213+ } else {
214+ logger.info(" No entity search specified, skipping entity search" )
215+ }
216+
217+ // TODO should reward multiple matches
218+ val mergedResults: List <SimilarityResult <out Retrievable >> = allResults
219+ .distinctBy { it.match.id }
220+ .sortedByDescending { it.score }
221+ .take(ragRequest.topK)
216222 return RagFacetResults (
217223 facetName = this .name,
218224 results = mergedResults,
@@ -276,4 +282,183 @@ class DrivineStore(
276282 throw RuntimeException (" Don't know how to map: $labels " )
277283 }
278284
285+ private val readonlyTransactionTemplate = TransactionTemplate (platformTransactionManager).apply {
286+ isReadOnly = true
287+ propagationBehavior = TransactionDefinition .PROPAGATION_REQUIRED
288+ }
289+
290+ private fun safelyExecuteInTransaction (block : () -> List <SimilarityResult <out Retrievable >>): List <SimilarityResult <out Retrievable >> {
291+ return try {
292+ readonlyTransactionTemplate.execute { block() } as List <SimilarityResult <out Retrievable >>
293+ } catch (e: Exception ) {
294+ logger.error(" Error during RAG search transaction" , e)
295+ emptyList()
296+ }
297+ }
298+
299+ private fun chunkSearch (
300+ ragRequest : RagRequest ,
301+ embedding : Embedding ,
302+ ): List <SimilarityResult <out Chunk >> {
303+ val chunkSimilarityResults = cypherSearch.chunkSimilaritySearch(
304+ " Chunk similarity search" ,
305+ query = " chunk_vector_search" ,
306+ params = commonParameters(ragRequest) + mapOf (
307+ " vectorIndex" to properties.contentElementIndex,
308+ " queryVector" to embedding,
309+ ),
310+ logger = logger,
311+ )
312+ logger.info(" {} chunk similarity results for query '{}'" , chunkSimilarityResults.size, ragRequest.query)
313+
314+ val chunkFullTextResults = cypherSearch.chunkFullTextSearch(
315+ purpose = " Chunk full text search" ,
316+ query = " chunk_fulltext_search" ,
317+ params = commonParameters(ragRequest) + mapOf (
318+ " fulltextIndex" to properties.contentElementFullTextIndex,
319+ " searchText" to " \" ${ragRequest.query} \" " ,
320+ ),
321+ logger = logger,
322+ )
323+ logger.info(" {} chunk full-text results for query '{}'" , chunkFullTextResults.size, ragRequest.query)
324+ return chunkSimilarityResults + chunkFullTextResults
325+ }
326+
327+ private fun entitySearch (
328+ ragRequest : RagRequest ,
329+ embedding : FloatArray ,
330+ ): List <SimilarityResult <out Retrievable >> {
331+ val allEntityResults = mutableListOf<SimilarityResult <out Retrievable >>()
332+ val labels = ragRequest.entitySearch?.labels ? : error(" No entity search specified" )
333+ val entityResults = entityVectorSearch(
334+ ragRequest,
335+ embedding,
336+ labels,
337+ )
338+ allEntityResults + = entityResults
339+ logger.info(" {} entity vector results for query '{}'" , entityResults.size, ragRequest.query)
340+ val entityFullTextResults = cypherSearch.entityFullTextSearch(
341+ purpose = " Entity full text search" ,
342+ query = " entity_fulltext_search" ,
343+ params = commonParameters(ragRequest) + mapOf (
344+ " fulltextIndex" to properties.entityFullTextIndex,
345+ " searchText" to ragRequest.query,
346+ " labels" to labels,
347+ ),
348+ logger = logger,
349+ )
350+ logger.info(" {} entity full-text results for query '{}'" , entityFullTextResults.size, ragRequest.query)
351+ allEntityResults + = entityFullTextResults
352+
353+ if (ragRequest.entitySearch?.generateQueries == true ) {
354+ val cypherResults =
355+ generateAndExecuteCypher(ragRequest, ragRequest.entitySearch!! ).also { cypherResults ->
356+ logger.info(" {} Cypher results for query '{}'" , cypherResults.size, ragRequest.query)
357+ }
358+ allEntityResults + = cypherResults
359+ } else {
360+ logger.info(" No query generation specified, skipping Cypher generation and execution" )
361+ }
362+ logger.info(" {} total entity results for query '{}'" , entityFullTextResults.size, ragRequest.query)
363+ return allEntityResults
364+ }
365+
366+ fun entityVectorSearch (
367+ request : SimilarityCutoff ,
368+ embedding : FloatArray ,
369+ labels : Set <String >,
370+ ): List <SimilarityResult <out EntityData >> {
371+ return cypherSearch.entityDataSimilaritySearch(
372+ purpose = " Mapped entity search" ,
373+ query = " entity_vector_search" ,
374+ params = commonParameters(request) + mapOf (
375+ " index" to properties.entityIndex,
376+ " queryVector" to embedding,
377+ " labels" to labels,
378+ ),
379+ logger,
380+ )
381+ }
382+
383+ private fun generateAndExecuteCypher (
384+ request : RagRequest ,
385+ entitySearch : EntitySearch ,
386+ ): List <SimilarityResult <out Retrievable >> {
387+ TODO (" Not yet implemented" )
388+ // val schema = schemaResolver.getSchema(entitySearch)
389+ // if (schema == null) {
390+ // logger.info("No schema found for entity search {}, skipping Cypher execution", entitySearch)
391+ // return emptyList()
392+ // }
393+ //
394+ // val cypherRagQueryGenerator = SchemaDrivenCypherRagQueryGenerator(
395+ // modelProvider,
396+ // schema,
397+ // )
398+ // val cypher = cypherRagQueryGenerator.generateQuery(request = request)
399+ // logger.info("Generated Cypher query: $cypher")
400+ //
401+ // val cypherResults = readonlyTransactionTemplate.execute {
402+ // executeGeneratedCypher(cypher)
403+ // } ?: Result.failure(
404+ // IllegalStateException("Transaction failed or returned null while executing Cypher query: $cypher")
405+ // )
406+ // if (cypherResults.isSuccess) {
407+ // val results = cypherResults.getOrThrow()
408+ // if (results.isNotEmpty()) {
409+ // logger.info("Cypher query executed successfully, results: {}", results)
410+ // return results.map {
411+ // // Most similar as we found them by a query
412+ // SimpleSimilaritySearchResult(
413+ // it,
414+ // score = 1.0,
415+ // )
416+ // }
417+ // }
418+ // }
419+ // return emptyList()
420+ }
421+
422+ // /**
423+ // * Execute generate Cypher query, being sure to handle exceptions gracefully.
424+ // */
425+ // private fun executeGeneratedCypher(
426+ // query: CypherQuery,
427+ // ): Result<List<EntityData>> {
428+ // TODO("Not yet implemented")
429+ // try {
430+ // return Result.success(
431+ // ogmCypherSearch.queryForEntities(
432+ // purpose = "cypherGeneratedQuery",
433+ // query = query.query
434+ // )
435+ // )
436+ // } catch (e: Exception) {
437+ // logger.error("Error executing generated query: $query", e)
438+ // return Result.failure(e)
439+ // }
440+ // }
441+
442+ private fun createVectorIndex (
443+ name : String ,
444+ on : String ,
445+ ) {
446+ val statement = """
447+ CREATE VECTOR INDEX `$name ` IF NOT EXISTS
448+ FOR (n:$on ) ON (n.embedding)
449+ OPTIONS {indexConfig: {
450+ `vector.dimensions`: ${embeddingService.model.dimensions()} ,
451+ `vector.similarity_function`: 'cosine'
452+ }}"""
453+
454+ persistenceManager.execute(QuerySpecification .withStatement(statement))
455+
456+ }
457+
458+ private fun commonParameters (request : SimilarityCutoff ) = mapOf (
459+ " topK" to request.topK,
460+ " similarityThreshold" to request.similarityThreshold,
461+ )
462+
463+
279464}
0 commit comments