Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bugfix: update Azure AI Search Config #2380

Merged
merged 8 commits into from
Mar 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 35 additions & 9 deletions mem0/configs/vector_stores/azure_ai_search.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,53 @@
from typing import Any, Dict

from typing import Any, Dict, Optional
from pydantic import BaseModel, Field, model_validator


class AzureAISearchConfig(BaseModel):
collection_name: str = Field("mem0", description="Name of the collection")
service_name: str = Field(None, description="Azure Cognitive Search service name")
api_key: str = Field(None, description="API key for the Azure Cognitive Search service")
service_name: str = Field(None, description="Azure AI Search service name")
api_key: str = Field(None, description="API key for the Azure AI Search service")
embedding_model_dims: int = Field(None, description="Dimension of the embedding vector")
use_compression: bool = Field(False, description="Whether to use scalar quantization vector compression.")

compression_type: Optional[str] = Field(
None,
description="Type of vector compression to use. Options: 'scalar', 'binary', or None"
)
use_float16: bool = Field(
False,
description="Whether to store vectors in half precision (Edm.Half) instead of full precision (Edm.Single)"
)

@model_validator(mode="before")
@classmethod
def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
allowed_fields = set(cls.model_fields.keys())
input_fields = set(values.keys())
extra_fields = input_fields - allowed_fields

# Check for use_compression to provide a helpful error
if "use_compression" in extra_fields:
raise ValueError(
"The parameter 'use_compression' is no longer supported. "
"Please use 'compression_type=\"scalar\"' instead of 'use_compression=True' "
"or 'compression_type=None' instead of 'use_compression=False'."
)

if extra_fields:
raise ValueError(
f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}"
f"Extra fields not allowed: {', '.join(extra_fields)}. "
f"Please input only the following fields: {', '.join(allowed_fields)}"
)

# Validate compression_type values
if "compression_type" in values and values["compression_type"] is not None:
valid_types = ["scalar", "binary"]
if values["compression_type"].lower() not in valid_types:
raise ValueError(
f"Invalid compression_type: {values['compression_type']}. "
f"Must be one of: {', '.join(valid_types)}, or None"
)

return values

model_config = {
"arbitrary_types_allowed": True,
}
}
Loading