Skip to content

Commit

Permalink
adds 2.5.3 bm25 support while keeping previous support for milvus lite (
Browse files Browse the repository at this point in the history
  • Loading branch information
jperez999 authored Jan 30, 2025
1 parent 115a256 commit 7cf2a4d
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 12 deletions.
60 changes: 49 additions & 11 deletions client/src/nv_ingest_client/util/milvus.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
DataType,
CollectionSchema,
connections,
Function,
FunctionType,
utility,
BulkInsertState,
AnnSearchRequest,
Expand Down Expand Up @@ -103,7 +105,7 @@ def run(self, records):
raise ValueError(f"Unsupported type for collection_name detected: {type(collection_name)}")


def create_nvingest_schema(dense_dim: int = 1024, sparse: bool = False) -> CollectionSchema:
def create_nvingest_schema(dense_dim: int = 1024, sparse: bool = False, local_index: bool = False) -> CollectionSchema:
"""
Creates a schema for the nv-ingest produced data. This is currently setup to follow
the default expected schema fields in nv-ingest. You can see more about the declared fields
Expand All @@ -127,13 +129,32 @@ def create_nvingest_schema(dense_dim: int = 1024, sparse: bool = False) -> Colle
"""
schema = MilvusClient.create_schema(auto_id=True, enable_dynamic_field=True)
schema.add_field(field_name="pk", datatype=DataType.INT64, is_primary=True, auto_id=True)
schema.add_field(field_name="text", datatype=DataType.VARCHAR, max_length=65535)
schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=dense_dim)
schema.add_field(field_name="source", datatype=DataType.JSON)
schema.add_field(field_name="content_metadata", datatype=DataType.JSON)
if sparse:
if sparse and local_index:
schema.add_field(field_name="sparse", datatype=DataType.SPARSE_FLOAT_VECTOR)
elif sparse:
schema.add_field(field_name="sparse", datatype=DataType.SPARSE_FLOAT_VECTOR)
schema.add_field(
field_name="text",
datatype=DataType.VARCHAR,
max_length=65535,
enable_analyzer=True,
analyzer_params={"type": "english"},
enable_match=True,
)
schema.add_function(
Function(
name="bm25",
function_type=FunctionType.BM25,
input_field_names=["text"],
output_field_names="sparse",
)
)

else:
schema.add_field(field_name="text", datatype=DataType.VARCHAR, max_length=65535)
return schema


Expand Down Expand Up @@ -192,14 +213,20 @@ def create_nvingest_index_params(
metric_type="L2",
params={"M": 64, "efConstruction": 512},
)
if sparse:
if sparse and local_index:
index_params.add_index(
field_name="sparse",
index_name="sparse_index",
index_type="SPARSE_INVERTED_INDEX", # Index type for sparse vectors
metric_type="IP", # Currently, only IP (Inner Product) is supported for sparse vectors
params={"drop_ratio_build": 0.2}, # The ratio of small vector values to be dropped during indexing
)
elif sparse:
index_params.add_index(
field_name="sparse",
index_type="SPARSE_INVERTED_INDEX",
metric_type="BM25",
)
return index_params


Expand Down Expand Up @@ -283,7 +310,7 @@ def create_nvingest_collection(
local_index = True

client = MilvusClient(milvus_uri)
schema = create_nvingest_schema(dense_dim=dense_dim, sparse=sparse)
schema = create_nvingest_schema(dense_dim=dense_dim, sparse=sparse, local_index=local_index)
index_params = create_nvingest_index_params(
sparse=sparse, gpu_index=gpu_index, gpu_search=gpu_search, local_index=local_index
)
Expand Down Expand Up @@ -558,15 +585,18 @@ def write_to_nvingest_collection(
Minio bucket name.
"""
stream = False
local_index = False
connections.connect(uri=milvus_uri)
if urlparse(milvus_uri).scheme:
server_version = utility.get_server_version()
if "lite" in server_version:
stream = True
else:
stream = True
if milvus_uri.endswith(".db"):
local_index = True
bm25_ef = None
if sparse and compute_bm25_stats:
if local_index and sparse and compute_bm25_stats:
bm25_ef = create_bm25_model(
records,
enable_text=enable_text,
Expand All @@ -575,7 +605,7 @@ def write_to_nvingest_collection(
enable_images=enable_images,
)
bm25_ef.save(bm25_save_path)
elif sparse and not compute_bm25_stats:
elif local_index and sparse:
bm25_ef = BM25EmbeddingFunction(build_default_analyzer(language="en"))
bm25_ef.load(bm25_save_path)
client = MilvusClient(milvus_uri)
Expand Down Expand Up @@ -716,7 +746,10 @@ def hybrid_retrieval(
sparse_embeddings = []
for query in queries:
dense_embeddings.append(dense_model.get_query_embedding(query))
sparse_embeddings.append(_format_sparse_embedding(sparse_model.encode_queries([query])))
if sparse_model:
sparse_embeddings.append(_format_sparse_embedding(sparse_model.encode_queries([query])))
else:
sparse_embeddings.append(query)

s_param_1 = {
"metric_type": "L2",
Expand All @@ -733,11 +766,14 @@ def hybrid_retrieval(
}

dense_req = AnnSearchRequest(**search_param_1)
s_param_2 = {"metric_type": "BM25"}
if local_index:
s_param_2 = {"metric_type": "IP", "params": {"drop_ratio_build": 0.0}}

search_param_2 = {
"data": sparse_embeddings,
"anns_field": sparse_field,
"param": {"metric_type": "IP", "params": {"drop_ratio_build": 0.2}},
"param": s_param_2,
"limit": top_k * 2,
}
sparse_req = AnnSearchRequest(**search_param_2)
Expand Down Expand Up @@ -804,8 +840,10 @@ def nvingest_retrieval(
if milvus_uri.endswith(".db"):
local_index = True
if hybrid:
bm25_ef = BM25EmbeddingFunction(build_default_analyzer(language="en"))
bm25_ef.load(sparse_model_filepath)
bm25_ef = None
if local_index:
bm25_ef = BM25EmbeddingFunction(build_default_analyzer(language="en"))
bm25_ef.load(sparse_model_filepath)
results = hybrid_retrieval(
queries,
collection_name,
Expand Down
2 changes: 1 addition & 1 deletion docker-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ services:
# Turn on to leverage the `vdb_upload` task
restart: always
container_name: milvus-standalone
image: milvusdb/milvus:v2.4.17-gpu
image: milvusdb/milvus:v2.5.3-gpu
command: [ "milvus", "run", "standalone" ]
hostname: milvus
security_opt:
Expand Down

0 comments on commit 7cf2a4d

Please sign in to comment.