diff --git a/.gitignore b/.gitignore index 16be36b7155..6e299f5f8e6 100644 --- a/.gitignore +++ b/.gitignore @@ -45,3 +45,6 @@ target/ # Rapid test data testdata + +# Examples test data +examples/gemini*/chroma_storage/ diff --git a/examples/gemini/load_data.py b/examples/gemini/load_data.py index bac0b19e324..ef0f40899ce 100644 --- a/examples/gemini/load_data.py +++ b/examples/gemini/load_data.py @@ -14,6 +14,7 @@ def main( persist_directory: str = ".", ) -> None: # Read all files in the data directory + ids = [] documents = [] metadatas = [] files = os.listdir(documents_directory) @@ -27,6 +28,7 @@ def main( # Skip empty lines if len(line) == 0: continue + ids.append(str(len(ids))) # unique ID for the document documents.append(line) metadatas.append({"filename": filename, "line_number": line_number}) @@ -43,8 +45,9 @@ def main( google_api_key = os.environ["GOOGLE_API_KEY"] # create embedding function - embedding_function = embedding_functions.GoogleGenerativeAIEmbeddingFunction( - api_key=google_api_key + embedding_function = embedding_functions.GoogleGenerativeAiEmbeddingFunction( + api_key=google_api_key, + model_name='gemini-embedding-001', ) # If the collection already exists, we just return it. This allows us to add more @@ -56,16 +59,16 @@ def main( # Create ids from the current count count = collection.count() print(f"Collection already contains {count} documents") - ids = [str(i) for i in range(count, count + len(documents))] - # Load the documents in batches of 100 + # Load the documents in batches + batch_size = 5 # Using small batch size to work better with Free Tier for i in tqdm( - range(0, len(documents), 100), desc="Adding documents", unit_scale=100 + range(count, len(documents), batch_size), desc="Adding documents", unit_scale=batch_size ): collection.add( - ids=ids[i : i + 100], - documents=documents[i : i + 100], - metadatas=metadatas[i : i + 100], # type: ignore + ids=ids[i : i + batch_size], + documents=documents[i : i + batch_size], + metadatas=metadatas[i : i + batch_size], # type: ignore ) new_count = collection.count() diff --git a/examples/gemini/main.py b/examples/gemini/main.py index c2d44bbad9b..4396ba65cb8 100644 --- a/examples/gemini/main.py +++ b/examples/gemini/main.py @@ -6,7 +6,7 @@ import chromadb from chromadb.utils import embedding_functions -model = genai.GenerativeModel("gemini-pro") +model = genai.GenerativeModel('gemini-2.5-flash') def build_prompt(query: str, context: List[str]) -> str: @@ -78,8 +78,9 @@ def main( client = chromadb.PersistentClient(path=persist_directory) # create embedding function - embedding_function = embedding_functions.GoogleGenerativeAIEmbeddingFunction( - api_key=google_api_key, task_type="RETRIEVAL_QUERY" + embedding_function = embedding_functions.GoogleGenerativeAiEmbeddingFunction( + api_key=google_api_key, + model_name='gemini-embedding-001', ) # Get the collection.