Skip to content
Open
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import com.azure.cosmos.models.PartitionKeyDefinition;
import com.azure.cosmos.models.CosmosVectorIndexSpec;
import com.azure.cosmos.models.CosmosVectorIndexType;
import com.azure.cosmos.models.QuantizerType;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.slf4j.Logger;
Expand Down Expand Up @@ -79,29 +80,26 @@ public void afterClass() {

@Test(groups = {"emulator"}, timeOut = TIMEOUT*10000)
public void shouldCreateVectorEmbeddingPolicy() {
PartitionKeyDefinition partitionKeyDef = new PartitionKeyDefinition();
ArrayList<String> paths = new ArrayList<String>();
paths.add("/mypk");
partitionKeyDef.setPaths(paths);

CosmosContainerProperties collectionDefinition = new CosmosContainerProperties(UUID.randomUUID().toString(), partitionKeyDef);
ArrayList<String> paths = new ArrayList<>(Arrays.asList("/mypk"));
PartitionKeyDefinition partitionKeyDef = new PartitionKeyDefinition()
.setPaths(paths);

IndexingPolicy indexingPolicy = new IndexingPolicy();
indexingPolicy.setIndexingMode(IndexingMode.CONSISTENT);
ExcludedPath excludedPath = new ExcludedPath("/*");
indexingPolicy.setExcludedPaths(Collections.singletonList(excludedPath));

IncludedPath includedPath1 = new IncludedPath("/name/?");
IncludedPath includedPath2 = new IncludedPath("/description/?");
indexingPolicy.setIncludedPaths(ImmutableList.of(includedPath1, includedPath2));

indexingPolicy.setVectorIndexes(populateVectorIndexes());
IndexingPolicy indexingPolicy = new IndexingPolicy()
.setIndexingMode(IndexingMode.CONSISTENT)
.setExcludedPaths(Collections.singletonList(excludedPath))
.setIncludedPaths(ImmutableList.of(includedPath1, includedPath2))
.setVectorIndexes(populateVectorIndexes());

CosmosVectorEmbeddingPolicy cosmosVectorEmbeddingPolicy = new CosmosVectorEmbeddingPolicy();
cosmosVectorEmbeddingPolicy.setCosmosVectorEmbeddings(populateEmbeddings());

collectionDefinition.setIndexingPolicy(indexingPolicy);
collectionDefinition.setVectorEmbeddingPolicy(cosmosVectorEmbeddingPolicy);
CosmosContainerProperties collectionDefinition = new CosmosContainerProperties(UUID.randomUUID().toString(), partitionKeyDef)
.setIndexingPolicy(indexingPolicy)
.setVectorEmbeddingPolicy(cosmosVectorEmbeddingPolicy);

database.createContainer(collectionDefinition).block();
CosmosAsyncContainer createdCollection = database.getContainer(collectionDefinition.getId());
Expand Down Expand Up @@ -279,14 +277,30 @@ public void shouldValidateVectorEmbeddingPolicySerializationAndDeserialization()
validateVectorEmbeddingPolicy(cosmosVectorEmbeddingPolicy, expectedCosmosVectorEmbeddingPolicy);
}

@Test(groups = {"unit"}, timeOut = TIMEOUT)
public void shouldValidateVectorIndexesSerializationAndDeserialization() throws JsonProcessingException {
    // Build the reference specs through an IndexingPolicy so they pass through its setter/getter pair.
    List<CosmosVectorIndexSpec> expectedVectorIndexes = new IndexingPolicy()
        .setVectorIndexes(populateVectorIndexes())
        .getVectorIndexes();

    // Serialization: the JSON produced by the mapper must match the hand-written fixture exactly.
    String serializedVectorIndexes = simpleObjectMapper.writeValueAsString(expectedVectorIndexes);
    assertThat(serializedVectorIndexes).isEqualTo(getVectorIndexesAsString());

    // Deserialization: reading that JSON back must yield specs equivalent to the reference ones.
    CosmosVectorIndexSpec[] parsedVectorIndexes =
        simpleObjectMapper.readValue(serializedVectorIndexes, CosmosVectorIndexSpec[].class);
    validateVectorIndexes(Arrays.asList(parsedVectorIndexes), expectedVectorIndexes);
}

private void validateCollectionProperties(CosmosContainerProperties collectionDefinition, CosmosContainerProperties collectionProperties) {
assertThat(collectionProperties.getVectorEmbeddingPolicy()).isNotNull();
assertThat(collectionProperties.getVectorEmbeddingPolicy().getVectorEmbeddings()).isNotNull();
validateVectorEmbeddingPolicy(collectionProperties.getVectorEmbeddingPolicy(),
collectionDefinition.getVectorEmbeddingPolicy());

assertThat(collectionProperties.getIndexingPolicy().getVectorIndexes()).isNotNull();
validateVectorIndexes(collectionDefinition.getIndexingPolicy().getVectorIndexes(), collectionProperties.getIndexingPolicy().getVectorIndexes());
validateVectorIndexes(collectionProperties.getIndexingPolicy().getVectorIndexes(), collectionDefinition.getIndexingPolicy().getVectorIndexes());
}

private void validateVectorEmbeddingPolicy(CosmosVectorEmbeddingPolicy actual, CosmosVectorEmbeddingPolicy expected) {
Expand All @@ -302,69 +316,96 @@ private void validateVectorEmbeddingPolicy(CosmosVectorEmbeddingPolicy actual, C
}

private void validateVectorIndexes(List<CosmosVectorIndexSpec> actual, List<CosmosVectorIndexSpec> expected) {
assertThat(expected).hasSameSizeAs(actual);
for (int i = 0; i < expected.size(); i++) {
assertThat(expected.get(i).getPath()).isEqualTo(actual.get(i).getPath());
assertThat(expected.get(i).getType()).isEqualTo(actual.get(i).getType());
if (Objects.equals(expected.get(i).getType(), CosmosVectorIndexType.QUANTIZED_FLAT.toString()) ||
Objects.equals(expected.get(i).getType(), CosmosVectorIndexType.DISK_ANN.toString())) {
assertThat(expected.get(i).getQuantizationSizeInBytes()).isEqualTo(actual.get(i).getQuantizationSizeInBytes());
assertThat(expected.get(i).getVectorIndexShardKeys()).isEqualTo(actual.get(i).getVectorIndexShardKeys());
assertThat(actual).hasSameSizeAs(expected);
for (int i = 0; i < actual.size(); i++) {
assertThat(actual.get(i).getPath()).isEqualTo(expected.get(i).getPath());
assertThat(actual.get(i).getType()).isEqualTo(expected.get(i).getType());
if (Objects.equals(actual.get(i).getType(), CosmosVectorIndexType.QUANTIZED_FLAT.toString()) ||
Objects.equals(actual.get(i).getType(), CosmosVectorIndexType.DISK_ANN.toString())) {
assertThat(actual.get(i).getQuantizerType()).isEqualTo(expected.get(i).getQuantizerType());
assertThat(actual.get(i).getQuantizationSizeInBytes()).isEqualTo(expected.get(i).getQuantizationSizeInBytes());
assertThat(actual.get(i).getVectorIndexShardKeys()).isEqualTo(expected.get(i).getVectorIndexShardKeys());
}
if (Objects.equals(expected.get(i).getType(), CosmosVectorIndexType.DISK_ANN.toString())) {
assertThat(expected.get(i).getIndexingSearchListSize()).isEqualTo(actual.get(i).getIndexingSearchListSize());
if (Objects.equals(actual.get(i).getType(), CosmosVectorIndexType.DISK_ANN.toString())) {
assertThat(actual.get(i).getIndexingSearchListSize()).isEqualTo(expected.get(i).getIndexingSearchListSize());
}

}
}

private List<CosmosVectorIndexSpec> populateVectorIndexes() {
CosmosVectorIndexSpec cosmosVectorIndexSpec1 = new CosmosVectorIndexSpec();
cosmosVectorIndexSpec1.setPath("/vector1");
cosmosVectorIndexSpec1.setType(CosmosVectorIndexType.FLAT.toString());

CosmosVectorIndexSpec cosmosVectorIndexSpec2 = new CosmosVectorIndexSpec();
cosmosVectorIndexSpec2.setPath("/vector2");
cosmosVectorIndexSpec2.setType(CosmosVectorIndexType.QUANTIZED_FLAT.toString());
cosmosVectorIndexSpec2.setQuantizationSizeInBytes(2);
cosmosVectorIndexSpec2.setVectorIndexShardKeys(Arrays.asList("/zipCode"));

CosmosVectorIndexSpec cosmosVectorIndexSpec3 = new CosmosVectorIndexSpec();
cosmosVectorIndexSpec3.setPath("/vector3");
cosmosVectorIndexSpec3.setType(CosmosVectorIndexType.DISK_ANN.toString());
cosmosVectorIndexSpec3.setQuantizationSizeInBytes(2);
cosmosVectorIndexSpec3.setIndexingSearchListSize(30);
cosmosVectorIndexSpec3.setVectorIndexShardKeys(Arrays.asList("/country/city"));

return Arrays.asList(cosmosVectorIndexSpec1, cosmosVectorIndexSpec2, cosmosVectorIndexSpec3);
CosmosVectorIndexSpec cosmosVectorIndexSpec1 = new CosmosVectorIndexSpec()
.setPath("/vector1")
.setType(CosmosVectorIndexType.FLAT.toString());

CosmosVectorIndexSpec cosmosVectorIndexSpec2 = new CosmosVectorIndexSpec()
.setPath("/vector2")
.setType(CosmosVectorIndexType.QUANTIZED_FLAT.toString())
.setQuantizerType(QuantizerType.PRODUCT)
.setQuantizationSizeInBytes(2)
.setVectorIndexShardKeys(Arrays.asList("/zipCode"));

CosmosVectorIndexSpec cosmosVectorIndexSpec3 = new CosmosVectorIndexSpec()
.setPath("/vector3")
.setType(CosmosVectorIndexType.DISK_ANN.toString())
.setQuantizerType(QuantizerType.PRODUCT)
.setQuantizationSizeInBytes(2)
.setIndexingSearchListSize(30)
.setVectorIndexShardKeys(Arrays.asList("/country/city"));

CosmosVectorIndexSpec cosmosVectorIndexSpec4 = new CosmosVectorIndexSpec()
.setPath("/vector4")
.setType(CosmosVectorIndexType.DISK_ANN.toString())
.setQuantizerType(QuantizerType.SPHERICAL)
.setIndexingSearchListSize(30)
.setVectorIndexShardKeys(Arrays.asList("/country/city"));

return Arrays.asList(cosmosVectorIndexSpec1, cosmosVectorIndexSpec2, cosmosVectorIndexSpec3, cosmosVectorIndexSpec4);
}

private List<CosmosVectorEmbedding> populateEmbeddings() {
CosmosVectorEmbedding embedding1 = new CosmosVectorEmbedding();
embedding1.setPath("/vector1");
embedding1.setDataType(CosmosVectorDataType.INT8);
embedding1.setEmbeddingDimensions(3);
embedding1.setDistanceFunction(CosmosVectorDistanceFunction.COSINE);

CosmosVectorEmbedding embedding2 = new CosmosVectorEmbedding();
embedding2.setPath("/vector2");
embedding2.setDataType(CosmosVectorDataType.FLOAT32);
embedding2.setEmbeddingDimensions(3);
embedding2.setDistanceFunction(CosmosVectorDistanceFunction.DOT_PRODUCT);

CosmosVectorEmbedding embedding3 = new CosmosVectorEmbedding();
embedding3.setPath("/vector3");
embedding3.setDataType(CosmosVectorDataType.UINT8);
embedding3.setEmbeddingDimensions(3);
embedding3.setDistanceFunction(CosmosVectorDistanceFunction.EUCLIDEAN);
return Arrays.asList(embedding1, embedding2, embedding3);
CosmosVectorEmbedding embedding1 = new CosmosVectorEmbedding()
.setPath("/vector1")
.setDataType(CosmosVectorDataType.INT8)
.setEmbeddingDimensions(3)
.setDistanceFunction(CosmosVectorDistanceFunction.COSINE);

CosmosVectorEmbedding embedding2 = new CosmosVectorEmbedding()
.setPath("/vector2")
.setDataType(CosmosVectorDataType.FLOAT32)
.setEmbeddingDimensions(3)
.setDistanceFunction(CosmosVectorDistanceFunction.DOT_PRODUCT);

CosmosVectorEmbedding embedding3 = new CosmosVectorEmbedding()
.setPath("/vector3")
.setDataType(CosmosVectorDataType.UINT8)
.setEmbeddingDimensions(3)
.setDistanceFunction(CosmosVectorDistanceFunction.EUCLIDEAN);

CosmosVectorEmbedding embedding4 = new CosmosVectorEmbedding()
.setPath("/vector4")
.setDataType(CosmosVectorDataType.UINT8)
.setEmbeddingDimensions(3)
.setDistanceFunction(CosmosVectorDistanceFunction.EUCLIDEAN);

return Arrays.asList(embedding1, embedding2, embedding3, embedding4);
}

private String getVectorEmbeddingPolicyAsString() {
return "{\"vectorEmbeddings\":[" +
"{\"path\":\"/vector1\",\"dataType\":\"int8\",\"dimensions\":3,\"distanceFunction\":\"cosine\"}," +
"{\"path\":\"/vector2\",\"dataType\":\"float32\",\"dimensions\":3,\"distanceFunction\":\"dotproduct\"}," +
"{\"path\":\"/vector3\",\"dataType\":\"uint8\",\"dimensions\":3,\"distanceFunction\":\"euclidean\"}" +
"{\"path\":\"/vector3\",\"dataType\":\"uint8\",\"dimensions\":3,\"distanceFunction\":\"euclidean\"}," +
"{\"path\":\"/vector4\",\"dataType\":\"uint8\",\"dimensions\":3,\"distanceFunction\":\"euclidean\"}" +
"]}";
}

// Expected JSON for the four specs built by populateVectorIndexes(), one array element per spec.
private String getVectorIndexesAsString() {
    String flatIndex =
        "{\"type\":\"flat\",\"path\":\"/vector1\"}";
    String quantizedFlatIndex =
        "{\"type\":\"quantizedFlat\",\"vectorIndexShardKeys\":[\"/zipCode\"],\"quantizerType\":\"product\",\"path\":\"/vector2\",\"quantizationByteSize\":2}";
    String diskAnnProductIndex =
        "{\"type\":\"diskANN\",\"indexingSearchListSize\":30,\"vectorIndexShardKeys\":[\"/country/city\"],\"quantizerType\":\"product\",\"path\":\"/vector3\",\"quantizationByteSize\":2}";
    String diskAnnSphericalIndex =
        "{\"type\":\"diskANN\",\"indexingSearchListSize\":30,\"vectorIndexShardKeys\":[\"/country/city\"],\"quantizerType\":\"spherical\",\"path\":\"/vector4\"}";

    // Elements are comma-separated with no whitespace, matching the mapper's compact output.
    return "["
        + String.join(",", flatIndex, quantizedFlatIndex, diskAnnProductIndex, diskAnnSphericalIndex)
        + "]";
}
}
1 change: 1 addition & 0 deletions sdk/cosmos/azure-cosmos/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
### 4.77.0-beta.1 (Unreleased)

#### Features Added
* Added `QuantizerType` (`product` or `spherical`) to `CosmosVectorIndexSpec`. - [PR 47566](https://github.com/Azure/azure-sdk-for-java/pull/47566)

#### Breaking Changes

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

package com.azure.cosmos.implementation;

import java.util.List;

/**
* Used internally. Constants in the Azure Cosmos DB database service Java SDK.
Expand Down Expand Up @@ -121,6 +120,7 @@ public static final class Properties {
public static final String ORDER = "order";
public static final String SPATIAL_INDEXES = "spatialIndexes";
public static final String TYPES = "types";
public static final String QUANTIZER_TYPE = "quantizerType";

// Full text search
public static final String FULL_TEXT_INDEXES = "fullTextIndexes";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,6 @@
public enum IndexProperty {
INDEXING_SEARCH_LIST_SIZE,
QUANTIZATION_SIZE_IN_BYTES,
VECTOR_INDEX_SHARD_KEYS;
VECTOR_INDEX_SHARD_KEYS,
QUANTIZER_TYPE,
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ public final class CosmosVectorIndexSpec {
private Integer indexingSearchListSize;
@JsonInclude(JsonInclude.Include.NON_NULL)
private List<String> vectorIndexShardKeys;
@JsonInclude(JsonInclude.Include.NON_NULL)
private QuantizerType quantizerType;
private final JsonSerializable jsonSerializable;

/**
Expand Down Expand Up @@ -84,6 +86,35 @@ public CosmosVectorIndexSpec setType(String type) {
return this;
}

/**
 * Gets the quantizer type.
 * <p>
 * When no value is cached on this instance, it is read from the underlying JSON
 * payload and cached for subsequent calls.
 *
 * @return the quantizer type.
 */
public QuantizerType getQuantizerType() {
    if (this.quantizerType != null) {
        return this.quantizerType;
    }

    this.quantizerType =
        this.jsonSerializable.getObject(Constants.Properties.QUANTIZER_TYPE, QuantizerType.class);
    return this.quantizerType;
}

/**
 * Sets the quantizer type.
 * <p>
 * The value is stored only when it is non-null and the index type check for
 * {@code QUANTIZER_TYPE} passes; otherwise any existing quantizer type is
 * cleared from both this instance and the underlying JSON payload.
 *
 * @param quantizerType The quantizer type
 * @return the CosmosVectorIndexSpec.
 */
public CosmosVectorIndexSpec setQuantizerType(QuantizerType quantizerType) {
    // The index type is validated first, before the null check, regardless of the argument.
    final boolean accepted = validateIndexType(IndexProperty.QUANTIZER_TYPE) && quantizerType != null;

    if (!accepted) {
        this.quantizerType = null;
        this.jsonSerializable.remove(Constants.Properties.QUANTIZER_TYPE);
        return this;
    }

    this.quantizerType = quantizerType;
    this.jsonSerializable.set(Constants.Properties.QUANTIZER_TYPE, quantizerType);
    return this;
}

/**
* Gets the quantization byte size
*
Expand Down Expand Up @@ -193,7 +224,9 @@ JsonSerializable getJsonSerializable() {

private Boolean validateIndexType(IndexProperty indexProperty) {
String vectorIndexType = this.jsonSerializable.getString(Constants.Properties.VECTOR_INDEX_TYPE);
if (indexProperty.equals(IndexProperty.QUANTIZATION_SIZE_IN_BYTES) || (indexProperty.equals(IndexProperty.VECTOR_INDEX_SHARD_KEYS))) {
if (indexProperty.equals(IndexProperty.QUANTIZATION_SIZE_IN_BYTES) ||
(indexProperty.equals(IndexProperty.VECTOR_INDEX_SHARD_KEYS)) ||
(indexProperty.equals(IndexProperty.QUANTIZER_TYPE))) {
return vectorIndexType.equals(CosmosVectorIndexType.QUANTIZED_FLAT.toString()) ||
vectorIndexType.equals(CosmosVectorIndexType.DISK_ANN.toString());
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package com.azure.cosmos.models;

import com.fasterxml.jackson.annotation.JsonValue;

/**
 * Quantizer types that can be set on a vector index specification in the Azure Cosmos DB service.
 */
public enum QuantizerType {
    /**
     * Product quantizer; serialized as {@code "product"}.
     */
    PRODUCT("product"),

    /**
     * Spherical quantizer; serialized as {@code "spherical"}.
     */
    SPHERICAL("spherical");

    private final String overWireValue;

    QuantizerType(String overWireValue) {
        this.overWireValue = overWireValue;
    }

    /**
     * Returns the over-the-wire value of this quantizer type, which Jackson also
     * uses as the JSON representation.
     *
     * @return the over-the-wire value.
     */
    @JsonValue
    @Override
    public String toString() {
        return overWireValue;
    }
}