Skip to content

Commit

Permalink
WORK IN PROGRESS Implement manual reranking
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelsembwever committed Apr 25, 2024
1 parent d7a2dae commit 4439002
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 3 deletions.
5 changes: 5 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@
<artifactId>java-uuid-generator</artifactId>
<version>4.0.1</version>
</dependency>
<dependency>
<groupId>io.github.jbellis</groupId>
<artifactId>jvector</artifactId>
<version>3.0.0-beta.3</version>
</dependency>
<!-- test scope -->
<dependency>
<groupId>org.springframework.boot</groupId>
Expand Down
13 changes: 10 additions & 3 deletions src/main/java/com/datastax/ai/agent/vector/AiAgentVector.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@

import com.datastax.ai.agent.base.AiAgent;
import com.datastax.ai.agent.base.AiAgentDelegator;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import io.github.jbellis.jvector.vector.VectorizationProvider;
import io.github.jbellis.jvector.vector.types.VectorFloat;
import io.github.jbellis.jvector.vector.types.VectorTypeSupport;
import org.apache.commons.lang3.ArrayUtils;

import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
Expand All @@ -35,6 +36,9 @@

public class AiAgentVector extends AiAgentDelegator {


private static final VectorTypeSupport VTS = VectorizationProvider.getInstance().getVectorTypeSupport();

private static final int CHAT_DOCUMENTS_SIZE = 3;

private final CassandraVectorStore store;
Expand All @@ -56,6 +60,9 @@ public Prompt createPrompt(UserMessage message, Map<String,Object> promptPropert
promptProperties = promptProperties(promptProperties);

// any re-ranking happens here
VectorFloat<float[]> v0 = (VectorFloat<float[]>) VTS.createFloatVector(ArrayUtils.toPrimitive(
(Float[])similarDocuments.get(0).getEmbedding().stream().map(Double::floatValue).toArray()));


promptProperties.put("documents", similarDocuments);
return super.createPrompt(message, promptProperties);
Expand Down

0 comments on commit 4439002

Please sign in to comment.