11"""Tests for operation utilities."""
22
33import os
4- from unittest .mock import Mock
4+ from unittest .mock import Mock , patch
55
66import pytest
77from bson import ObjectId
88from pymongo import MongoClient
99from pymongo .collection import Collection
1010
11- from pymongo_vectorsearch_utils .operation import bulk_embed_and_insert_texts
11+ from pymongo_vectorsearch_utils import drop_vector_search_index
12+ from pymongo_vectorsearch_utils .index import create_vector_search_index , wait_for_docs_in_index
13+ from pymongo_vectorsearch_utils .operation import bulk_embed_and_insert_texts , execute_search_query
1214
# Database and collection names used by the fixtures in this module.
DB_NAME = "vectorsearch_utils_test"
COLLECTION_NAME = "test_operation"
# Name of the vector search index created and dropped by the search-query tests.
VECTOR_INDEX_NAME = "operation_vector_index"
1518
1619
1720@pytest .fixture (scope = "module" )
@@ -22,6 +25,17 @@ def client():
2225 client .close ()
2326
2427
@pytest.fixture(scope="module")
def preserved_collection(client):
    """Yield the test collection once per module, emptied before and after use.

    The collection is created when the database does not list it yet;
    otherwise the existing one is reused.
    """
    database = client[DB_NAME]
    already_present = COLLECTION_NAME in database.list_collection_names()
    target = (
        database[COLLECTION_NAME]
        if already_present
        else database.create_collection(COLLECTION_NAME)
    )
    # Start from an empty collection regardless of what earlier runs left behind.
    target.delete_many({})
    yield target
    # Teardown: leave nothing behind for whoever uses the collection next.
    target.delete_many({})
37+
38+
2539@pytest .fixture
2640def collection (client ):
2741 if COLLECTION_NAME not in client [DB_NAME ].list_collection_names ():
@@ -266,3 +280,182 @@ def test_custom_field_names(self, collection: Collection, mock_embedding_func):
266280 assert "vector" in doc
267281 assert doc ["content" ] == texts [0 ]
268282 assert doc ["vector" ] == [0.0 , 0.0 , 0.0 ]
283+
284+
class TestExecuteSearchQuery:
    """Tests for ``execute_search_query`` against a small indexed corpus."""

    @pytest.fixture(scope="class", autouse=True)
    def vector_search_index(self, client):
        """Create the vector search index if it is absent; drop it on teardown."""
        target = client[DB_NAME][COLLECTION_NAME]
        existing = target.list_search_indexes(VECTOR_INDEX_NAME).to_list()
        if not existing:
            create_vector_search_index(
                collection=target,
                index_name=VECTOR_INDEX_NAME,
                dimensions=3,
                path="embedding",
                similarity="cosine",
                filters=["category", "color", "wheels"],
                wait_until_complete=120,
            )
        yield
        drop_vector_search_index(collection=target, index_name=VECTOR_INDEX_NAME)

    @pytest.fixture(scope="class", autouse=True)
    def sample_docs(self, preserved_collection: Collection, vector_search_index):
        """Insert four text documents plus one text-less document and wait for indexing.

        Returns the populated collection so tests can query it directly.
        """
        # Each row is (text, metadata, embedding vector).
        corpus = [
            ("apple fruit", {"category": "fruit", "color": "red"}, [1.0, 0.5, 0.0]),
            ("banana fruit", {"category": "fruit", "color": "yellow"}, [0.5, 0.5, 0.0]),
            ("car vehicle", {"category": "vehicle", "wheels": 4}, [0.0, 0.5, 1.0]),
            ("bike vehicle", {"category": "vehicle", "wheels": 2}, [0.0, 1.0, 0.5]),
        ]
        vector_for = {text: vector for text, _, vector in corpus}

        def lookup_embeddings(batch):
            # Deterministic stand-in for an embedding model; hands out fresh lists.
            return [list(vector_for[text]) for text in batch]

        bulk_embed_and_insert_texts(
            texts=[text for text, _, _ in corpus],
            metadatas=[metadata for _, metadata, _ in corpus],
            embedding_func=lookup_embeddings,
            collection=preserved_collection,
            text_key="text",
            embedding_key="embedding",
        )

        # Add a document that should not be returned in searches
        textless_doc = {
            "category": "fruit",
            "color": "red",
            "embedding": [1.0, 1.0, 1.0],
        }
        preserved_collection.insert_one(textless_doc)

        wait_for_docs_in_index(preserved_collection, VECTOR_INDEX_NAME, n_docs=5)
        return preserved_collection

    @staticmethod
    def _search(collection, query_vector, k, **options):
        """Call ``execute_search_query`` with this class's default field and index names."""
        return execute_search_query(
            query_vector=query_vector,
            collection=collection,
            embedding_key="embedding",
            text_key="text",
            index_name=VECTOR_INDEX_NAME,
            k=k,
            **options,
        )

    def test_basic_search_query(self, sample_docs: Collection):
        """The two nearest documents come back in order, each carrying a score."""
        hits = self._search(sample_docs, [1.0, 0.5, 0.0], k=2)

        assert len(hits) == 2
        assert [doc["text"] for doc in hits] == ["apple fruit", "banana fruit"]
        assert all("score" in doc for doc in hits)

    def test_search_with_pre_filter(self, sample_docs: Collection):
        """A ``pre_filter`` on category restricts results to matching documents."""
        hits = self._search(
            sample_docs,
            [1.0, 0.5, 1.0],
            k=4,
            pre_filter={"category": "fruit"},
        )

        assert len(hits) == 2
        assert all(doc["category"] == "fruit" for doc in hits)

    def test_search_with_post_filter_pipeline(self, sample_docs: Collection):
        """Stages in ``post_filter_pipeline`` are applied to the search results."""
        stages = [
            {"$match": {"score": {"$gte": 0.99}}},
            {"$sort": {"score": -1}},
        ]

        hits = self._search(
            sample_docs,
            [1.0, 0.5, 0.0],
            k=2,
            post_filter_pipeline=stages,
        )

        assert len(hits) == 1

    def test_search_with_embeddings_included(self, sample_docs: Collection):
        """``include_embeddings=True`` keeps the stored vector in each result."""
        hits = self._search(
            sample_docs,
            [1.0, 0.5, 0.0],
            k=1,
            include_embeddings=True,
        )

        assert len(hits) == 1
        top = hits[0]
        assert "embedding" in top
        assert top["embedding"] == [1.0, 0.5, 0.0]

    def test_search_with_custom_field_names(self, sample_docs: Collection):
        """Custom text/embedding keys are reflected in the generated pipeline."""
        canned_results = [
            {
                "_id": ObjectId(),
                "content": "apple fruit",
                "vector": [1.0, 0.5, 0.25],
                "score": 0.9,
            }
        ]

        # Stub out ``aggregate`` so the pipeline can be inspected without a real search.
        with patch.object(
            sample_docs, "aggregate", return_value=canned_results
        ) as mock_aggregate:
            hits = execute_search_query(
                query_vector=[1.0, 0.5, 0.25],
                collection=sample_docs,
                embedding_key="vector",
                text_key="content",
                index_name=VECTOR_INDEX_NAME,
                k=1,
            )

        assert len(hits) == 1
        assert "content" in hits[0]
        assert hits[0]["content"] == "apple fruit"

        # The pipeline is the first positional argument handed to ``aggregate``.
        pipeline = mock_aggregate.call_args.args[0]
        assert pipeline[0]["$vectorSearch"]["path"] == "vector"
        assert {"$project": {"vector": 0}} in pipeline

    def test_search_filters_documents_without_text_key(self, sample_docs: Collection):
        """Documents lacking the text field are excluded from the results."""
        hits = self._search(sample_docs, [1.0, 0.5, 0.0], k=3)

        # Should only return documents with text field
        assert len(hits) == 2
        assert all("text" in doc for doc in hits)
        assert [doc["text"] for doc in hits] == ["apple fruit", "banana fruit"]
0 commit comments