Adds Milvus 2.5.3 BM25 support while keeping previous support for Milvus Lite #347

Merged
2 commits
Jan 30, 2025
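
The crux of the change: against a Milvus 2.5.x server the sparse field is produced by a server-side BM25 function, so queries can be issued as raw text, while a Milvus Lite URI (a local .db file) keeps the existing client-side BM25 statistics path. A minimal sketch of the server-side behavior this enables (URI and collection name are illustrative, not part of this PR):

from pymilvus import MilvusClient

client = MilvusClient("http://localhost:19530")

# With a BM25 Function on the schema, raw query text is passed as search data;
# Milvus analyzes it and scores documents server-side with metric_type="BM25".
hits = client.search(
    collection_name="nv_ingest_collection",
    data=["How does nv-ingest store chart text?"],
    anns_field="sparse",
    search_params={"metric_type": "BM25"},
    limit=5,
    output_fields=["text"],
)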
60 changes: 49 additions & 11 deletions client/src/nv_ingest_client/util/milvus.py
@@ -4,6 +4,8 @@
DataType,
CollectionSchema,
connections,
Function,
FunctionType,
utility,
BulkInsertState,
AnnSearchRequest,
@@ -101,7 +103,7 @@ def run(self, records):
raise ValueError(f"Unsupported type for collection_name detected: {type(collection_name)}")


def create_nvingest_schema(dense_dim: int = 1024, sparse: bool = False) -> CollectionSchema:
def create_nvingest_schema(dense_dim: int = 1024, sparse: bool = False, local_index: bool = False) -> CollectionSchema:
"""
Creates a schema for the nv-ingest produced data. This is currently set up to follow
the default expected schema fields in nv-ingest. You can see more about the declared fields
@@ -125,13 +127,32 @@ def create_nvingest_schema(dense_dim: int = 1024, sparse: bool = False) -> CollectionSchema:
"""
schema = MilvusClient.create_schema(auto_id=True, enable_dynamic_field=True)
schema.add_field(field_name="pk", datatype=DataType.INT64, is_primary=True, auto_id=True)
schema.add_field(field_name="text", datatype=DataType.VARCHAR, max_length=65535)
schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=dense_dim)
schema.add_field(field_name="source", datatype=DataType.JSON)
schema.add_field(field_name="content_metadata", datatype=DataType.JSON)
if sparse:
if sparse and local_index:
schema.add_field(field_name="sparse", datatype=DataType.SPARSE_FLOAT_VECTOR)
elif sparse:
schema.add_field(field_name="sparse", datatype=DataType.SPARSE_FLOAT_VECTOR)
schema.add_field(
field_name="text",
datatype=DataType.VARCHAR,
max_length=65535,
enable_analyzer=True,
analyzer_params={"type": "english"},
enable_match=True,
)
schema.add_function(
Function(
name="bm25",
function_type=FunctionType.BM25,
input_field_names=["text"],
output_field_names="sparse",
)
)

else:
schema.add_field(field_name="text", datatype=DataType.VARCHAR, max_length=65535)
return schema
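
One consequence of the function-based path: writers no longer provide values for the sparse field at all, because Milvus 2.5 derives them from the text field at insert time. A minimal sketch of what a row looks like under the new schema (client, collection name, and field values are placeholders):

# Assumes a MilvusClient connected to a 2.5.x server and a collection created
# from create_nvingest_schema(dense_dim=1024, sparse=True, local_index=False).
client.insert(
    collection_name="nv_ingest_collection",
    data=[
        {
            "text": "Example chunk extracted by nv-ingest.",
            "vector": [0.0] * 1024,
            "source": {"source_name": "example.pdf"},
            "content_metadata": {"type": "text"},
            # no "sparse" value: the bm25 Function populates it server-side
        }
    ],
)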


@@ -190,14 +211,20 @@ def create_nvingest_index_params(
metric_type="L2",
params={"M": 64, "efConstruction": 512},
)
if sparse:
if sparse and local_index:
index_params.add_index(
field_name="sparse",
index_name="sparse_index",
index_type="SPARSE_INVERTED_INDEX", # Index type for sparse vectors
metric_type="IP", # Currently, only IP (Inner Product) is supported for sparse vectors
params={"drop_ratio_build": 0.2}, # The ratio of small vector values to be dropped during indexing
)
elif sparse:
index_params.add_index(
field_name="sparse",
index_type="SPARSE_INVERTED_INDEX",
metric_type="BM25",
)
return index_params
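
The index side has to agree with whichever schema branch was taken: BM25 as the metric for the server-managed sparse field, IP with a drop ratio for the Milvus Lite path. A rough sketch of wiring the two helpers together (URI and collection name are assumptions, not from this PR):

from pymilvus import MilvusClient

client = MilvusClient("http://localhost:19530")
schema = create_nvingest_schema(dense_dim=1024, sparse=True, local_index=False)
index_params = create_nvingest_index_params(
    sparse=True, gpu_index=False, gpu_search=False, local_index=False
)
client.create_collection(
    collection_name="nv_ingest_collection",
    schema=schema,
    index_params=index_params,
)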


@@ -281,7 +308,7 @@ def create_nvingest_collection(
local_index = True

client = MilvusClient(milvus_uri)
schema = create_nvingest_schema(dense_dim=dense_dim, sparse=sparse)
schema = create_nvingest_schema(dense_dim=dense_dim, sparse=sparse, local_index=local_index)
index_params = create_nvingest_index_params(
sparse=sparse, gpu_index=gpu_index, gpu_search=gpu_search, local_index=local_index
)
@@ -539,20 +566,23 @@ def write_to_nvingest_collection(
Minio bucket name.
"""
stream = False
local_index = False
connections.connect(uri=milvus_uri)
if urlparse(milvus_uri).scheme:
server_version = utility.get_server_version()
if "lite" in server_version:
stream = True
else:
stream = True
if milvus_uri.endswith(".db"):
local_index = True
bm25_ef = None
if sparse and compute_bm25_stats:
if local_index and sparse and compute_bm25_stats:
bm25_ef = create_bm25_model(
records, enable_text=enable_text, enable_charts=enable_charts, enable_tables=enable_tables
)
bm25_ef.save(bm25_save_path)
elif sparse and not compute_bm25_stats:
elif local_index and sparse:
bm25_ef = BM25EmbeddingFunction(build_default_analyzer(language="en"))
bm25_ef.load(bm25_save_path)
client = MilvusClient(milvus_uri)
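
On the Milvus Lite branch nothing changes conceptually: BM25 statistics are still fit client-side and saved so that retrieval can reload the same model. A minimal sketch of that flow, with placeholder corpus and path:

from pymilvus.model.sparse import BM25EmbeddingFunction
from pymilvus.model.sparse.bm25.tokenizers import build_default_analyzer

corpus = ["first extracted chunk", "second extracted chunk"]  # placeholder text
bm25_ef = BM25EmbeddingFunction(build_default_analyzer(language="en"))
bm25_ef.fit(corpus)              # learn IDF statistics from the ingested text
bm25_ef.save("bm25_model.json")  # reloaded later via bm25_save_path / sparse_model_filepath
doc_vectors = bm25_ef.encode_documents(corpus)  # sparse vectors written to the "sparse" field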
@@ -691,7 +721,10 @@ def hybrid_retrieval(
sparse_embeddings = []
for query in queries:
dense_embeddings.append(dense_model.get_query_embedding(query))
sparse_embeddings.append(_format_sparse_embedding(sparse_model.encode_queries([query])))
if sparse_model:
sparse_embeddings.append(_format_sparse_embedding(sparse_model.encode_queries([query])))
else:
sparse_embeddings.append(query)

s_param_1 = {
"metric_type": "L2",
@@ -708,11 +741,14 @@
}

dense_req = AnnSearchRequest(**search_param_1)
s_param_2 = {"metric_type": "BM25"}
if local_index:
s_param_2 = {"metric_type": "IP", "params": {"drop_ratio_build": 0.0}}

search_param_2 = {
"data": sparse_embeddings,
"anns_field": sparse_field,
"param": {"metric_type": "IP", "params": {"drop_ratio_build": 0.2}},
"param": s_param_2,
"limit": top_k * 2,
}
sparse_req = AnnSearchRequest(**search_param_2)
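
Downstream, the dense and sparse requests are fused in a single hybrid search; on the 2.5 BM25 path the sparse request now carries the raw query strings instead of client-encoded vectors. A rough sketch of the fusion step, assuming an RRF reranker and an already loaded Collection:

from pymilvus import Collection, RRFRanker

collection = Collection("nv_ingest_collection")  # name is illustrative
hits = collection.hybrid_search(
    reqs=[dense_req, sparse_req],  # dense_req / sparse_req built as above
    rerank=RRFRanker(),            # reciprocal-rank fusion of the two result lists
    limit=top_k,
    output_fields=["text", "source", "content_metadata"],
)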
@@ -779,8 +815,10 @@ def nvingest_retrieval(
if milvus_uri.endswith(".db"):
local_index = True
if hybrid:
bm25_ef = BM25EmbeddingFunction(build_default_analyzer(language="en"))
bm25_ef.load(sparse_model_filepath)
bm25_ef = None
if local_index:
bm25_ef = BM25EmbeddingFunction(build_default_analyzer(language="en"))
bm25_ef.load(sparse_model_filepath)
results = hybrid_retrieval(
queries,
collection_name,
2 changes: 1 addition & 1 deletion docker-compose.yaml
@@ -296,7 +296,7 @@ services:
# Turn on to leverage the `vdb_upload` task
restart: always
container_name: milvus-standalone
image: milvusdb/milvus:v2.4.17-gpu
image: milvusdb/milvus:v2.5.3-gpu
command: [ "milvus", "run", "standalone" ]
hostname: milvus
security_opt: