From 4439002fe5b092116f9b0ba9d0bc43bf543dfcd5 Mon Sep 17 00:00:00 2001 From: mck Date: Thu, 25 Apr 2024 12:53:29 +0200 Subject: [PATCH] WORK IN PROGRESS Implement manual reranking --- pom.xml | 5 +++++ .../com/datastax/ai/agent/vector/AiAgentVector.java | 13 ++++++++++--- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/pom.xml b/pom.xml index 4f20e63..53ef84c 100644 --- a/pom.xml +++ b/pom.xml @@ -66,6 +66,11 @@ java-uuid-generator 4.0.1 + + io.github.jbellis + jvector + 3.0.0-beta.3 + org.springframework.boot diff --git a/src/main/java/com/datastax/ai/agent/vector/AiAgentVector.java b/src/main/java/com/datastax/ai/agent/vector/AiAgentVector.java index 47983d5..71896b7 100644 --- a/src/main/java/com/datastax/ai/agent/vector/AiAgentVector.java +++ b/src/main/java/com/datastax/ai/agent/vector/AiAgentVector.java @@ -21,9 +21,10 @@ import com.datastax.ai.agent.base.AiAgent; import com.datastax.ai.agent.base.AiAgentDelegator; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; +import io.github.jbellis.jvector.vector.VectorizationProvider; +import io.github.jbellis.jvector.vector.types.VectorFloat; +import io.github.jbellis.jvector.vector.types.VectorTypeSupport; +import org.apache.commons.lang3.ArrayUtils; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.prompt.Prompt; @@ -35,6 +36,9 @@ public class AiAgentVector extends AiAgentDelegator { + + private static final VectorTypeSupport VTS = VectorizationProvider.getInstance().getVectorTypeSupport(); + private static final int CHAT_DOCUMENTS_SIZE = 3; private final CassandraVectorStore store; @@ -56,6 +60,9 @@ public Prompt createPrompt(UserMessage message, Map promptPropert promptProperties = promptProperties(promptProperties); // any re-ranking happens here + VectorFloat v0 = (VectorFloat) VTS.createFloatVector(ArrayUtils.toPrimitive( + (Float[])similarDocuments.get(0).getEmbedding().stream().map(Double::floatValue).toArray())); + promptProperties.put("documents", similarDocuments); return super.createPrompt(message, promptProperties);