Skip to content

Commit

Permalink
Now doing solr queries over dense vectors. Almost ready to start crea…
Browse files Browse the repository at this point in the history
…ting the dense vector queries in the grpc client. Slowly building up to this.
  • Loading branch information
krickert committed Sep 26, 2024
1 parent 71894d5 commit c78541e
Show file tree
Hide file tree
Showing 5 changed files with 274 additions and 15 deletions.
4 changes: 4 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,10 @@
<artifactId>solr</artifactId>
<version>1.20.1</version>
</dependency>
<dependency>
<groupId>com.google.protobuf</groupId>
<artifactId>protobuf-java-util</artifactId>
</dependency>
</dependencies>
<build>
<plugins>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
package com.krickert.search.api;

import com.krickert.search.api.solr.ProtobufToSolrDocument;
import com.krickert.search.api.solr.SolrHelper;
import com.krickert.search.api.solr.SolrTest;
import com.krickert.search.model.pipe.PipeDocument;
import com.krickert.search.model.test.util.TestDataHelper;
import com.krickert.search.service.EmbeddingServiceGrpc;
import com.krickert.search.service.EmbeddingsVectorReply;
import com.krickert.search.service.EmbeddingsVectorRequest;
Expand All @@ -12,6 +16,7 @@
import jakarta.inject.Inject;
import jakarta.inject.Singleton;
import org.apache.solr.client.solrj.SolrQuery;
import org.apache.solr.client.solrj.SolrServerException;
import org.apache.solr.client.solrj.response.QueryResponse;
import org.apache.solr.client.solrj.response.UpdateResponse;
import org.apache.solr.common.SolrDocumentList;
Expand All @@ -23,11 +28,13 @@
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.utility.DockerImageName;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static com.krickert.search.api.solr.SolrHelper.buildVectorQuery;
import static org.junit.jupiter.api.Assertions.*;

@TestInstance(TestInstance.Lifecycle.PER_CLASS)
@MicronautTest(environments = "test") // Ensure the correct environment is used
Expand All @@ -52,6 +59,9 @@ EmbeddingServiceGrpc.EmbeddingServiceBlockingStub reactiveStub(
private static String vectorizerHost;
private static Integer vectorizerPort;

@Inject
ProtobufToSolrDocument protobufToSolrDocument;

@Inject
ApplicationContext context;

Expand Down Expand Up @@ -270,20 +280,12 @@ public void testDenseVectorSearch() throws Exception {
assertNotNull(queryVectorReply);
assertEquals(384, queryVectorReply.getEmbeddingsList().size());

// Create the dense vector query
StringBuilder vectorQueryBuilder = new StringBuilder();
vectorQueryBuilder.append("{!knn f=title-vector topK=10}[");
for (int i = 0; i < queryVectorReply.getEmbeddingsList().size(); i++) {
vectorQueryBuilder.append(queryVectorReply.getEmbeddingsList().get(i));
if (i < queryVectorReply.getEmbeddingsList().size() - 1) {
vectorQueryBuilder.append(",");
}
}
vectorQueryBuilder.append("]");
// Create the dense vector query using the utility function
String vectorQuery = buildVectorQuery("title-vector", queryVectorReply.getEmbeddingsList(), 10);

// Execute the dense vector search
SolrQuery solrQuery = new SolrQuery();
solrQuery.setQuery(vectorQueryBuilder.toString());
solrQuery.setQuery(vectorQuery);
QueryResponse queryResponse = solrClient.query(DEFAULT_COLLECTION, solrQuery);

// Validate the query response
Expand All @@ -292,6 +294,85 @@ public void testDenseVectorSearch() throws Exception {
assertEquals("Test Title 1", documents.get(0).getFieldValue("title")); // Assuming the closest match is returned first
}

/**
 * Indexes the sample pipe documents together with title/body embeddings and verifies that a
 * KNN query against each dense-vector field returns the requested number of hits.
 *
 * @throws Exception if Solr indexing or querying fails; declared (rather than caught and passed
 *                   to {@code fail}) to match the other tests in this class.
 */
@Test
public void sampleWikiDocumentDenseVectorSearchTest() throws Exception {
    final int topK = 30;

    EmbeddingServiceGrpc.EmbeddingServiceBlockingStub gRPCClient =
            context.getBean(EmbeddingServiceGrpc.EmbeddingServiceBlockingStub.class);
    Collection<PipeDocument> docs = TestDataHelper.getFewHunderedPipeDocuments();
    // The result-count assertions below only make sense if the fixture is large enough.
    assertTrue(docs.size() >= topK, "Test data must contain at least " + topK + " documents");

    Collection<SolrInputDocument> solrDocs = new ArrayList<>(docs.size());
    for (PipeDocument doc : docs) {
        SolrInputDocument inputDocument = protobufToSolrDocument.convertProtobufToSolrDocument(doc);

        EmbeddingsVectorRequest titleRequest =
                EmbeddingsVectorRequest.newBuilder().setText(doc.getTitle()).build();
        EmbeddingsVectorRequest bodyRequest =
                EmbeddingsVectorRequest.newBuilder().setText(doc.getBody()).build();
        inputDocument.addField("title-vector", gRPCClient.createEmbeddingsVector(titleRequest).getEmbeddingsList());
        inputDocument.addField("body-vector", gRPCClient.createEmbeddingsVector(bodyRequest).getEmbeddingsList());

        solrDocs.add(inputDocument);
    }

    // Index everything in one request and make sure Solr accepted it before querying.
    UpdateResponse updateResponse = solrClient.add(DEFAULT_COLLECTION, solrDocs);
    assertEquals(0, updateResponse.getStatus(), "Solr rejected the document batch");
    solrClient.commit(DEFAULT_COLLECTION);

    // The target query text for KNN search
    String queryText = "maintaining computers in large organizations";

    // Both fields are searched with the same text, so a single embedding call is enough.
    EmbeddingsVectorReply queryVector = gRPCClient.createEmbeddingsVector(
            EmbeddingsVectorRequest.newBuilder().setText(queryText).build());

    // Confirm that the embedding is generated correctly
    assertNotNull(queryVector);
    assertEquals(384, queryVector.getEmbeddingsList().size());

    // Create vector queries using utility function
    String titleVectorQuery = SolrHelper.buildVectorQuery("title-vector", queryVector.getEmbeddingsList(), topK);
    String bodyVectorQuery = SolrHelper.buildVectorQuery("body-vector", queryVector.getEmbeddingsList(), topK);

    // Execute KNN search for title vector
    SolrQuery solrTitleQuery = new SolrQuery();
    solrTitleQuery.setQuery(titleVectorQuery);
    solrTitleQuery.setRows(topK); // Solr's row limit must be raised to return all topK hits
    QueryResponse titleQueryResponse = solrClient.query(DEFAULT_COLLECTION, solrTitleQuery);

    // Execute KNN search for body vector
    SolrQuery solrBodyQuery = new SolrQuery();
    solrBodyQuery.setQuery(bodyVectorQuery);
    solrBodyQuery.setRows(topK); // Solr's row limit must be raised to return all topK hits
    QueryResponse bodyQueryResponse = solrClient.query(DEFAULT_COLLECTION, solrBodyQuery);

    // Validate the query responses
    assertNotNull(titleQueryResponse);
    assertNotNull(bodyQueryResponse);

    SolrDocumentList titleDocuments = titleQueryResponse.getResults();
    SolrDocumentList bodyDocuments = bodyQueryResponse.getResults();

    assertEquals(topK, titleDocuments.size()); // Ensure that topK documents are returned
    assertEquals(topK, bodyDocuments.size()); // Ensure that topK documents are returned

    // Log results for debugging purposes (optional)
    log.info("Title KNN Search Results:");
    titleDocuments.forEach(doc -> log.info("Doc ID: {}, Title: {}", doc.getFieldValue("id"), doc.getFieldValue("title")));

    log.info("Body KNN Search Results:");
    bodyDocuments.forEach(doc -> log.info("Doc ID: {}, Title: {}", doc.getFieldValue("id"), doc.getFieldValue("title")));

    // Additional assertions can be added based on expected results
}
}
156 changes: 156 additions & 0 deletions src/test/java/com/krickert/search/api/solr/ProtobufToSolrDocument.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
package com.krickert.search.api.solr;

import com.google.protobuf.*;
import com.google.protobuf.util.JsonFormat;
import jakarta.inject.Singleton;
import org.apache.solr.common.SolrInputDocument;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Date;
import java.util.List;
import java.util.Map;

@Singleton
public class ProtobufToSolrDocument {
private static final Logger log = LoggerFactory.getLogger(ProtobufToSolrDocument.class);

public ProtobufToSolrDocument() {
log.info("created ProtobufToSolrDocument");
}

public SolrInputDocument convertProtobufToSolrDocument(Message protobuf) {
try {
log.debug(JsonFormat.printer().print(protobuf));
} catch (InvalidProtocolBufferException e) {
throw new RuntimeException(e);
}
SolrInputDocument solrDoc = new SolrInputDocument();
addFieldsToSolrDoc(protobuf, solrDoc, "");
return solrDoc;
}

private void addFieldsToSolrDoc(Message message, SolrInputDocument solrDoc, String prefix) {
Map<Descriptors.FieldDescriptor, Object> allFields = message.getAllFields();
for (Map.Entry<Descriptors.FieldDescriptor, Object> entry : allFields.entrySet()) {
handleField(solrDoc, prefix, entry);
}
}

private void handleField(SolrInputDocument solrDoc, String prefix, Map.Entry<Descriptors.FieldDescriptor, Object> entry) {
String fieldName = prefix.isEmpty() ? entry.getKey().getName() : prefix + "_" + entry.getKey().getName();
if (entry.getValue() instanceof Message) {
handleMessageField(solrDoc, entry, fieldName);
} else if (entry.getKey().isMapField()) {
handleMapField(solrDoc, entry, fieldName);
} else if (entry.getKey().isRepeated()) {
handleRepeatedField(solrDoc, entry, fieldName);
} else {
solrDoc.addField(fieldName, entry.getValue());
}
}

private void handleMessageField(SolrInputDocument solrDoc, Map.Entry<Descriptors.FieldDescriptor, Object> entry, String fieldName) {
if (entry.getValue() instanceof Struct) {
extractFieldsFromStruct((Struct) entry.getValue(), solrDoc, fieldName);
} else if (entry.getValue() instanceof Timestamp timestamp) {
handleTimestampType(solrDoc, timestamp, fieldName);
} else if (entry.getValue() instanceof Duration duration) {
handleDurationType(solrDoc, duration, fieldName);
} else if (entry.getValue() instanceof BytesValue bytesValue) {
handleBytesType(solrDoc, bytesValue, fieldName);
} else if (entry.getValue() instanceof FloatValue floatValue) {
handleFloatType(solrDoc, floatValue, fieldName);
} else if (entry.getValue() instanceof Empty) {
handleEmptyType(solrDoc, fieldName);
} else if (entry.getValue() instanceof FieldMask fieldMask) {
handleFieldMaskType(solrDoc, fieldMask, fieldName);
} else {
addFieldsToSolrDoc((Message) entry.getValue(), solrDoc, fieldName);
}
}

private static void handleRepeatedField(SolrInputDocument solrDoc, Map.Entry<Descriptors.FieldDescriptor, Object> entry, String fieldName) {
@SuppressWarnings("unchecked") List<Object> listValue = (List<Object>) entry.getValue();
for (Object item : listValue) {
solrDoc.addField(fieldName, item);
}
}

private static void handleMapField(SolrInputDocument solrDoc, Map.Entry<Descriptors.FieldDescriptor, Object> entry, String fieldName) {
@SuppressWarnings("unchecked") Map<Object, Object> mapValue = (Map<Object, Object>) entry.getValue();
for (Map.Entry<Object, Object> mapEntry : mapValue.entrySet()) {
solrDoc.addField(fieldName + "_" + mapEntry.getKey(), mapEntry.getValue());
}
}

private static void handleFieldMaskType(SolrInputDocument solrDoc, FieldMask fieldMask, String fieldName) {
// Convert paths in FieldMask to a comma-separated string
String paths = String.join(", ", fieldMask.getPathsList());
solrDoc.addField(fieldName, paths);
}

private static void handleEmptyType(SolrInputDocument solrDoc, String fieldName) {
// No actual data to add, but we can acknowledge its existence.
solrDoc.addField(fieldName, "__EMPTY__");
}

private static void handleFloatType(SolrInputDocument solrDoc, FloatValue floatValue, String fieldName) {
// Convert protobuf FloatValue to a Java float
float javaFloat = floatValue.getValue();
solrDoc.addField(fieldName, javaFloat);
}

private static void handleBytesType(SolrInputDocument solrDoc, BytesValue bytesValue, String fieldName) {
// Convert protobuf BytesValue to String
String byteString = bytesValue.getValue().toStringUtf8();
solrDoc.addField(fieldName, byteString);
}

private static void handleDurationType(SolrInputDocument solrDoc, Duration duration, String fieldName) {
// Convert protobuf Duration to java.time.Duration
java.time.Duration javaDuration = java.time.Duration.ofSeconds(duration.getSeconds(), duration.getNanos());
solrDoc.addField(fieldName, javaDuration.toString());
}

private static void handleTimestampType(SolrInputDocument solrDoc, Timestamp timestamp, String fieldName) {
// Handle Timestamp fields
// Convert to java.util.Date then add to solrDoc
long milliseconds = timestamp.getSeconds() * 1000L + timestamp.getNanos() / 1000000;
Date javaDate = new Date(milliseconds);
solrDoc.addField(fieldName, javaDate);
}

private void extractFieldsFromStruct(Struct struct, SolrInputDocument solrDoc, String prefix) {
Map<String, Value> fields = struct.getFieldsMap();

for (Map.Entry<String, Value> entry : fields.entrySet()) {
String newFieldKey = prefix + "_" + entry.getKey();
Value.KindCase type = entry.getValue().getKindCase();

switch (type) {
case BOOL_VALUE:
solrDoc.addField(newFieldKey, entry.getValue().getBoolValue());
break;
case NUMBER_VALUE:
solrDoc.addField(newFieldKey, entry.getValue().getNumberValue());
break;
case STRING_VALUE:
solrDoc.addField(newFieldKey, entry.getValue().getStringValue());
break;
case LIST_VALUE:
ListValue listValue = entry.getValue().getListValue();
for (Value listItem : listValue.getValuesList()) {
solrDoc.addField(newFieldKey, listItem.toString());
}
break;
case STRUCT_VALUE:
extractFieldsFromStruct(entry.getValue().getStructValue(), solrDoc, newFieldKey);
break;
case NULL_VALUE:
solrDoc.addField(newFieldKey, null);
break;
}
}
}
}
18 changes: 18 additions & 0 deletions src/test/java/com/krickert/search/api/solr/SolrHelper.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
import org.apache.solr.client.solrj.response.schema.SchemaResponse;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

public class SolrHelper {

Expand Down Expand Up @@ -46,4 +48,20 @@ private static void addVectorFieldType(SolrClient solrClient, String collectionN
throw new RuntimeException("Failed to add vector field type: " + response.getStatus() + ", " + response.getResponse());
}
}

/**
 * Builds a Solr KNN query string of the form {@code {!knn f=<field> topK=<topK>}[v1,v2,...]}.
 *
 * <p>The string is assembled with plain concatenation rather than {@code String.format} so the
 * digits of {@code topK} do not depend on the JVM's default locale.
 *
 * @param field      The Solr dense-vector field to search against.
 * @param embeddings The query vector; must contain at least one component.
 * @param topK       The number of nearest neighbours to fetch; must be positive.
 * @return The vector query string.
 * @throws NullPointerException     if {@code embeddings} is {@code null}
 * @throws IllegalArgumentException if {@code embeddings} is empty or {@code topK} is not positive
 */
public static String buildVectorQuery(String field, List<Float> embeddings, int topK) {
    if (embeddings == null) {
        throw new NullPointerException("embeddings must not be null");
    }
    if (embeddings.isEmpty()) {
        throw new IllegalArgumentException("embeddings must not be empty");
    }
    if (topK <= 0) {
        throw new IllegalArgumentException("topK must be positive but was " + topK);
    }

    String vectorString = embeddings.stream()
            .map(String::valueOf)
            .collect(Collectors.joining(","));

    return "{!knn f=" + field + " topK=" + topK + "}[" + vectorString + "]";
}
}
2 changes: 1 addition & 1 deletion src/test/java/com/krickert/search/api/solr/SolrTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ public void beforeEach() throws Exception {
log.info("Creating temporary collection: {}", DEFAULT_COLLECTION);
// Define schema for the collection
addField("title", "string", false);
addField("body", "string", false);
addField("body", "text_general", false);
SolrHelper.addDenseVectorField(solrClient, "documents", "title-vector", 384);
SolrHelper.addDenseVectorField(solrClient, "documents", "body-vector", 384);

Expand Down

0 comments on commit c78541e

Please sign in to comment.