diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/VectorIndexTest.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/VectorIndexTest.java index 60429a598b1a..70de3fc15978 100644 --- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/VectorIndexTest.java +++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/VectorIndexTest.java @@ -26,6 +26,7 @@ import com.azure.cosmos.models.PartitionKeyDefinition; import com.azure.cosmos.models.CosmosVectorIndexSpec; import com.azure.cosmos.models.CosmosVectorIndexType; +import com.azure.cosmos.models.QuantizerType; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; import org.slf4j.Logger; @@ -79,29 +80,26 @@ public void afterClass() { @Test(groups = {"emulator"}, timeOut = TIMEOUT*10000) public void shouldCreateVectorEmbeddingPolicy() { - PartitionKeyDefinition partitionKeyDef = new PartitionKeyDefinition(); - ArrayList paths = new ArrayList(); - paths.add("/mypk"); - partitionKeyDef.setPaths(paths); - - CosmosContainerProperties collectionDefinition = new CosmosContainerProperties(UUID.randomUUID().toString(), partitionKeyDef); + ArrayList paths = new ArrayList<>(Arrays.asList("/mypk")); + PartitionKeyDefinition partitionKeyDef = new PartitionKeyDefinition() + .setPaths(paths); - IndexingPolicy indexingPolicy = new IndexingPolicy(); - indexingPolicy.setIndexingMode(IndexingMode.CONSISTENT); ExcludedPath excludedPath = new ExcludedPath("/*"); - indexingPolicy.setExcludedPaths(Collections.singletonList(excludedPath)); - IncludedPath includedPath1 = new IncludedPath("/name/?"); IncludedPath includedPath2 = new IncludedPath("/description/?"); - indexingPolicy.setIncludedPaths(ImmutableList.of(includedPath1, includedPath2)); - indexingPolicy.setVectorIndexes(populateVectorIndexes()); + IndexingPolicy indexingPolicy = new IndexingPolicy() + .setIndexingMode(IndexingMode.CONSISTENT) + 
.setExcludedPaths(Collections.singletonList(excludedPath)) + .setIncludedPaths(ImmutableList.of(includedPath1, includedPath2)) + .setVectorIndexes(populateVectorIndexes()); CosmosVectorEmbeddingPolicy cosmosVectorEmbeddingPolicy = new CosmosVectorEmbeddingPolicy(); cosmosVectorEmbeddingPolicy.setCosmosVectorEmbeddings(populateEmbeddings()); - collectionDefinition.setIndexingPolicy(indexingPolicy); - collectionDefinition.setVectorEmbeddingPolicy(cosmosVectorEmbeddingPolicy); + CosmosContainerProperties collectionDefinition = new CosmosContainerProperties(UUID.randomUUID().toString(), partitionKeyDef) + .setIndexingPolicy(indexingPolicy) + .setVectorEmbeddingPolicy(cosmosVectorEmbeddingPolicy); database.createContainer(collectionDefinition).block(); CosmosAsyncContainer createdCollection = database.getContainer(collectionDefinition.getId()); @@ -279,6 +277,22 @@ public void shouldValidateVectorEmbeddingPolicySerializationAndDeserialization() validateVectorEmbeddingPolicy(cosmosVectorEmbeddingPolicy, expectedCosmosVectorEmbeddingPolicy); } + @Test(groups = {"unit"}, timeOut = TIMEOUT) + public void shouldValidateVectorIndexesSerializationAndDeserialization() throws JsonProcessingException { + IndexingPolicy indexingPolicy = new IndexingPolicy(); + indexingPolicy.setVectorIndexes(populateVectorIndexes()); + List expectedVectorIndexes = indexingPolicy.getVectorIndexes(); + + // Validate Vector Indexes Serialization + String actualVectorIndexesJSON = simpleObjectMapper.writeValueAsString(expectedVectorIndexes); + String expectedVectorIndexesJSON = getVectorIndexesAsString(); + assertThat(actualVectorIndexesJSON).isEqualTo(expectedVectorIndexesJSON); + + // Validate Vector Indexes Deserialization + List actualVectorIndexes = Arrays.asList(simpleObjectMapper.readValue(actualVectorIndexesJSON, CosmosVectorIndexSpec[].class)); + validateVectorIndexes(actualVectorIndexes, expectedVectorIndexes); + } + private void validateCollectionProperties(CosmosContainerProperties 
collectionDefinition, CosmosContainerProperties collectionProperties) { assertThat(collectionProperties.getVectorEmbeddingPolicy()).isNotNull(); assertThat(collectionProperties.getVectorEmbeddingPolicy().getVectorEmbeddings()).isNotNull(); @@ -286,7 +300,7 @@ private void validateCollectionProperties(CosmosContainerProperties collectionDe collectionDefinition.getVectorEmbeddingPolicy()); assertThat(collectionProperties.getIndexingPolicy().getVectorIndexes()).isNotNull(); - validateVectorIndexes(collectionDefinition.getIndexingPolicy().getVectorIndexes(), collectionProperties.getIndexingPolicy().getVectorIndexes()); + validateVectorIndexes(collectionProperties.getIndexingPolicy().getVectorIndexes(), collectionDefinition.getIndexingPolicy().getVectorIndexes()); } private void validateVectorEmbeddingPolicy(CosmosVectorEmbeddingPolicy actual, CosmosVectorEmbeddingPolicy expected) { @@ -302,74 +316,91 @@ private void validateVectorEmbeddingPolicy(CosmosVectorEmbeddingPolicy actual, C } private void validateVectorIndexes(List actual, List expected) { - assertThat(expected).hasSameSizeAs(actual); - for (int i = 0; i < expected.size(); i++) { - assertThat(expected.get(i).getPath()).isEqualTo(actual.get(i).getPath()); - assertThat(expected.get(i).getType()).isEqualTo(actual.get(i).getType()); - if (Objects.equals(expected.get(i).getType(), CosmosVectorIndexType.QUANTIZED_FLAT.toString()) || - Objects.equals(expected.get(i).getType(), CosmosVectorIndexType.DISK_ANN.toString())) { - assertThat(expected.get(i).getQuantizationSizeInBytes()).isEqualTo(actual.get(i).getQuantizationSizeInBytes()); - assertThat(expected.get(i).getVectorIndexShardKeys()).isEqualTo(actual.get(i).getVectorIndexShardKeys()); + assertThat(actual).hasSameSizeAs(expected); + for (int i = 0; i < actual.size(); i++) { + assertThat(actual.get(i).getPath()).isEqualTo(expected.get(i).getPath()); + assertThat(actual.get(i).getType()).isEqualTo(expected.get(i).getType()); + if 
(Objects.equals(actual.get(i).getType(), CosmosVectorIndexType.QUANTIZED_FLAT.toString()) || + Objects.equals(actual.get(i).getType(), CosmosVectorIndexType.DISK_ANN.toString())) { + assertThat(actual.get(i).getQuantizerType()).isEqualTo(expected.get(i).getQuantizerType()); + assertThat(actual.get(i).getQuantizationSizeInBytes()).isEqualTo(expected.get(i).getQuantizationSizeInBytes()); + assertThat(actual.get(i).getVectorIndexShardKeys()).isEqualTo(expected.get(i).getVectorIndexShardKeys()); } - if (Objects.equals(expected.get(i).getType(), CosmosVectorIndexType.DISK_ANN.toString())) { - assertThat(expected.get(i).getIndexingSearchListSize()).isEqualTo(actual.get(i).getIndexingSearchListSize()); + if (Objects.equals(actual.get(i).getType(), CosmosVectorIndexType.DISK_ANN.toString())) { + assertThat(actual.get(i).getIndexingSearchListSize()).isEqualTo(expected.get(i).getIndexingSearchListSize()); } } } private List populateVectorIndexes() { - CosmosVectorIndexSpec cosmosVectorIndexSpec1 = new CosmosVectorIndexSpec(); - cosmosVectorIndexSpec1.setPath("/vector1"); - cosmosVectorIndexSpec1.setType(CosmosVectorIndexType.FLAT.toString()); - - CosmosVectorIndexSpec cosmosVectorIndexSpec2 = new CosmosVectorIndexSpec(); - cosmosVectorIndexSpec2.setPath("/vector2"); - cosmosVectorIndexSpec2.setType(CosmosVectorIndexType.QUANTIZED_FLAT.toString()); - cosmosVectorIndexSpec2.setQuantizationSizeInBytes(2); - cosmosVectorIndexSpec2.setVectorIndexShardKeys(Arrays.asList("/zipCode")); - - CosmosVectorIndexSpec cosmosVectorIndexSpec3 = new CosmosVectorIndexSpec(); - cosmosVectorIndexSpec3.setPath("/vector3"); - cosmosVectorIndexSpec3.setType(CosmosVectorIndexType.DISK_ANN.toString()); - cosmosVectorIndexSpec3.setQuantizationSizeInBytes(2); - cosmosVectorIndexSpec3.setIndexingSearchListSize(30); - cosmosVectorIndexSpec3.setVectorIndexShardKeys(Arrays.asList("/country/city")); - - CosmosVectorIndexSpec cosmosVectorIndexSpec4 = new CosmosVectorIndexSpec(); - 
cosmosVectorIndexSpec4.setPath("/vector4"); - cosmosVectorIndexSpec4.setType(CosmosVectorIndexType.QUANTIZED_FLAT.toString()); - cosmosVectorIndexSpec4.setQuantizationSizeInBytes(2); - cosmosVectorIndexSpec4.setVectorIndexShardKeys(Arrays.asList("/zipCode")); - - return Arrays.asList(cosmosVectorIndexSpec1, cosmosVectorIndexSpec2, cosmosVectorIndexSpec3, cosmosVectorIndexSpec4); + CosmosVectorIndexSpec cosmosVectorIndexSpec1 = new CosmosVectorIndexSpec() + .setPath("/vector1") + .setType(CosmosVectorIndexType.FLAT.toString()); + + CosmosVectorIndexSpec cosmosVectorIndexSpec2 = new CosmosVectorIndexSpec() + .setPath("/vector2") + .setType(CosmosVectorIndexType.QUANTIZED_FLAT.toString()) + .setQuantizerType(QuantizerType.PRODUCT) + .setQuantizationSizeInBytes(2) + .setVectorIndexShardKeys(Arrays.asList("/zipCode")); + + CosmosVectorIndexSpec cosmosVectorIndexSpec3 = new CosmosVectorIndexSpec() + .setPath("/vector3") + .setType(CosmosVectorIndexType.DISK_ANN.toString()) + .setQuantizerType(QuantizerType.PRODUCT) + .setQuantizationSizeInBytes(2) + .setIndexingSearchListSize(30) + .setVectorIndexShardKeys(Arrays.asList("/country/city")); + + CosmosVectorIndexSpec cosmosVectorIndexSpec4 = new CosmosVectorIndexSpec() + .setPath("/vector4") + .setType(CosmosVectorIndexType.QUANTIZED_FLAT.toString()) + .setQuantizerType(QuantizerType.PRODUCT) + .setQuantizationSizeInBytes(2) + .setVectorIndexShardKeys(Arrays.asList("/zipCode")); + + CosmosVectorIndexSpec cosmosVectorIndexSpec5 = new CosmosVectorIndexSpec() + .setPath("/vector5") + .setType(CosmosVectorIndexType.DISK_ANN.toString()) + .setQuantizerType(QuantizerType.SPHERICAL) + .setIndexingSearchListSize(30); + + return Arrays.asList(cosmosVectorIndexSpec1, cosmosVectorIndexSpec2, cosmosVectorIndexSpec3, cosmosVectorIndexSpec4, cosmosVectorIndexSpec5); } private List populateEmbeddings() { - CosmosVectorEmbedding embedding1 = new CosmosVectorEmbedding(); - embedding1.setPath("/vector1"); - 
embedding1.setDataType(CosmosVectorDataType.INT8); - embedding1.setEmbeddingDimensions(3); - embedding1.setDistanceFunction(CosmosVectorDistanceFunction.COSINE); - - CosmosVectorEmbedding embedding2 = new CosmosVectorEmbedding(); - embedding2.setPath("/vector2"); - embedding2.setDataType(CosmosVectorDataType.FLOAT32); - embedding2.setEmbeddingDimensions(3); - embedding2.setDistanceFunction(CosmosVectorDistanceFunction.DOT_PRODUCT); - - CosmosVectorEmbedding embedding3 = new CosmosVectorEmbedding(); - embedding3.setPath("/vector3"); - embedding3.setDataType(CosmosVectorDataType.UINT8); - embedding3.setEmbeddingDimensions(3); - embedding3.setDistanceFunction(CosmosVectorDistanceFunction.EUCLIDEAN); - - CosmosVectorEmbedding embedding4 = new CosmosVectorEmbedding(); - embedding4.setPath("/vector4"); - embedding4.setDataType(CosmosVectorDataType.FLOAT16); - embedding4.setEmbeddingDimensions(3); - embedding4.setDistanceFunction(CosmosVectorDistanceFunction.DOT_PRODUCT); - return Arrays.asList(embedding1, embedding2, embedding3, embedding4); + CosmosVectorEmbedding embedding1 = new CosmosVectorEmbedding() + .setPath("/vector1") + .setDataType(CosmosVectorDataType.INT8) + .setEmbeddingDimensions(3) + .setDistanceFunction(CosmosVectorDistanceFunction.COSINE); + + CosmosVectorEmbedding embedding2 = new CosmosVectorEmbedding() + .setPath("/vector2") + .setDataType(CosmosVectorDataType.FLOAT32) + .setEmbeddingDimensions(3) + .setDistanceFunction(CosmosVectorDistanceFunction.DOT_PRODUCT); + + CosmosVectorEmbedding embedding3 = new CosmosVectorEmbedding() + .setPath("/vector3") + .setDataType(CosmosVectorDataType.UINT8) + .setEmbeddingDimensions(3) + .setDistanceFunction(CosmosVectorDistanceFunction.EUCLIDEAN); + + CosmosVectorEmbedding embedding4 = new CosmosVectorEmbedding() + .setPath("/vector4") + .setDataType(CosmosVectorDataType.FLOAT16) + .setEmbeddingDimensions(3) + .setDistanceFunction(CosmosVectorDistanceFunction.DOT_PRODUCT); + + CosmosVectorEmbedding embedding5 = 
new CosmosVectorEmbedding() + .setPath("/vector5") + .setDataType(CosmosVectorDataType.UINT8) + .setEmbeddingDimensions(3) + .setDistanceFunction(CosmosVectorDistanceFunction.EUCLIDEAN); + + return Arrays.asList(embedding1, embedding2, embedding3, embedding4, embedding5); } private String getVectorEmbeddingPolicyAsString() { @@ -377,7 +408,20 @@ private String getVectorEmbeddingPolicyAsString() { "{\"path\":\"/vector1\",\"dataType\":\"int8\",\"dimensions\":3,\"distanceFunction\":\"cosine\"}," + "{\"path\":\"/vector2\",\"dataType\":\"float32\",\"dimensions\":3,\"distanceFunction\":\"dotproduct\"}," + "{\"path\":\"/vector3\",\"dataType\":\"uint8\",\"dimensions\":3,\"distanceFunction\":\"euclidean\"}," + - "{\"path\":\"/vector4\",\"dataType\":\"float16\",\"dimensions\":3,\"distanceFunction\":\"dotproduct\"}" + + "{\"path\":\"/vector4\",\"dataType\":\"float16\",\"dimensions\":3,\"distanceFunction\":\"dotproduct\"}," + + "{\"path\":\"/vector5\",\"dataType\":\"uint8\",\"dimensions\":3,\"distanceFunction\":\"euclidean\"}" + "]}"; } + + // Expected JSON for the specs built by populateVectorIndexes(): one object per spec, in list order. + // NOTE(review): key order of the /vector4 and /vector5 objects mirrors the entries above; confirm against serializer output. + private String getVectorIndexesAsString() { + return "[" + + "{\"type\":\"flat\",\"path\":\"/vector1\"}," + + "{\"type\":\"quantizedFlat\",\"vectorIndexShardKeys\":[\"/zipCode\"],\"quantizerType\":\"product\",\"path\":\"/vector2\",\"quantizationByteSize\":2}," + + "{\"type\":\"diskANN\",\"indexingSearchListSize\":30,\"vectorIndexShardKeys\":[\"/country/city\"],\"quantizerType\":\"product\",\"path\":\"/vector3\",\"quantizationByteSize\":2}," + + "{\"type\":\"quantizedFlat\",\"vectorIndexShardKeys\":[\"/zipCode\"],\"quantizerType\":\"product\",\"path\":\"/vector4\",\"quantizationByteSize\":2}," + + "{\"type\":\"diskANN\",\"indexingSearchListSize\":30,\"quantizerType\":\"spherical\",\"path\":\"/vector5\"}" + + "]"; + } } diff --git a/sdk/cosmos/azure-cosmos/CHANGELOG.md b/sdk/cosmos/azure-cosmos/CHANGELOG.md index ce3b76a5d722..628aa8b578c7 100644 --- a/sdk/cosmos/azure-cosmos/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos/CHANGELOG.md @@ -3,6 +3,7 @@ ### 4.77.0-beta.1 (Unreleased) #### Features Added +* Added the `QuantizerType` to the vectorIndexSpec: 
`product`/`spherical`. - [PR 47566](https://github.com/Azure/azure-sdk-for-java/pull/47566) #### Breaking Changes diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/Constants.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/Constants.java index 3ea927bf9b72..42d1d4a40bf2 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/Constants.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/Constants.java @@ -3,7 +3,6 @@ package com.azure.cosmos.implementation; -import java.util.List; /** * Used internally. Constants in the Azure Cosmos DB database service Java SDK. @@ -121,6 +120,7 @@ public static final class Properties { public static final String ORDER = "order"; public static final String SPATIAL_INDEXES = "spatialIndexes"; public static final String TYPES = "types"; + public static final String QUANTIZER_TYPE = "quantizerType"; // Full text search public static final String FULL_TEXT_INDEXES = "fullTextIndexes"; diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/query/IndexProperty.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/query/IndexProperty.java index e595d2212d80..1e1a4c1cca3d 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/query/IndexProperty.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/query/IndexProperty.java @@ -5,5 +5,6 @@ public enum IndexProperty { INDEXING_SEARCH_LIST_SIZE, QUANTIZATION_SIZE_IN_BYTES, - VECTOR_INDEX_SHARD_KEYS; + VECTOR_INDEX_SHARD_KEYS, + QUANTIZER_TYPE, } diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosVectorIndexSpec.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosVectorIndexSpec.java index 42f1ba663c16..870897595438 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosVectorIndexSpec.java +++ 
b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosVectorIndexSpec.java @@ -26,6 +26,8 @@ public final class CosmosVectorIndexSpec { private Integer indexingSearchListSize; @JsonInclude(JsonInclude.Include.NON_NULL) private List vectorIndexShardKeys; + @JsonInclude(JsonInclude.Include.NON_NULL) + private QuantizerType quantizerType; private final JsonSerializable jsonSerializable; /** @@ -84,6 +86,35 @@ public CosmosVectorIndexSpec setType(String type) { return this; } + /** + * Gets the quantizer type. + * + * @return the quantizer type. + */ + public QuantizerType getQuantizerType() { + if (this.quantizerType == null) { + this.quantizerType = this.jsonSerializable.getObject(Constants.Properties.QUANTIZER_TYPE, QuantizerType.class); + } + return this.quantizerType; + } + + /** + * Set the quantizer type. + * + * @param quantizerType The quantizer type + * @return the CosmosVectorIndexSpec. + */ + public CosmosVectorIndexSpec setQuantizerType(QuantizerType quantizerType) { + if (validateIndexType(IndexProperty.QUANTIZER_TYPE) && quantizerType != null) { + this.quantizerType = quantizerType; + this.jsonSerializable.set(Constants.Properties.QUANTIZER_TYPE, quantizerType); + } else { + this.quantizerType = null; + this.jsonSerializable.remove(Constants.Properties.QUANTIZER_TYPE); + } + return this; + } + /** * Gets the quantization byte size * @@ -193,7 +224,9 @@ JsonSerializable getJsonSerializable() { private Boolean validateIndexType(IndexProperty indexProperty) { String vectorIndexType = this.jsonSerializable.getString(Constants.Properties.VECTOR_INDEX_TYPE); - if (indexProperty.equals(IndexProperty.QUANTIZATION_SIZE_IN_BYTES) || (indexProperty.equals(IndexProperty.VECTOR_INDEX_SHARD_KEYS))) { + if (indexProperty.equals(IndexProperty.QUANTIZATION_SIZE_IN_BYTES) || + (indexProperty.equals(IndexProperty.VECTOR_INDEX_SHARD_KEYS)) || + (indexProperty.equals(IndexProperty.QUANTIZER_TYPE))) { return 
vectorIndexType.equals(CosmosVectorIndexType.QUANTIZED_FLAT.toString()) || vectorIndexType.equals(CosmosVectorIndexType.DISK_ANN.toString()); } diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/QuantizerType.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/QuantizerType.java new file mode 100644 index 000000000000..03ca19ea893b --- /dev/null +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/QuantizerType.java @@ -0,0 +1,35 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.azure.cosmos.models; + +import com.fasterxml.jackson.annotation.JsonValue; + +/** + * Defines quantizer types for vector index specifications in the Azure Cosmos DB service. + */ +public enum QuantizerType { + /** + * Represents a product quantizer type. + */ + PRODUCT("product"), + + /** + * Represents a spherical quantizer type. + */ + SPHERICAL("spherical"); + + + QuantizerType(String overWireValue) { + this.overWireValue = overWireValue; + } + + private final String overWireValue; + + @JsonValue + @Override + public String toString() { + return this.overWireValue; + } +} +