22import asyncio
33import logging
44import os
5+ from collections .abc import Awaitable , Callable
56from enum import Enum
67from typing import Optional
78
89import aiohttp
910from azure .core .credentials import AzureKeyCredential
1011from azure .core .credentials_async import AsyncTokenCredential
1112from azure .identity .aio import AzureDeveloperCliCredential , get_bearer_token_provider
12- from openai import AsyncAzureOpenAI , AsyncOpenAI
13+ from openai import AsyncOpenAI
1314from rich .logging import RichHandler
1415
1516from load_azd_env import load_azd_env
1617from prepdocslib .blobmanager import BlobManager
1718from prepdocslib .csvparser import CsvParser
18- from prepdocslib .embeddings import (
19- AzureOpenAIEmbeddingService ,
20- ImageEmbeddings ,
21- OpenAIEmbeddingService ,
22- )
19+ from prepdocslib .embeddings import ImageEmbeddings , OpenAIEmbeddings
2320from prepdocslib .fileprocessor import FileProcessor
2421from prepdocslib .filestrategy import FileStrategy
2522from prepdocslib .htmlparser import LocalHTMLParser
@@ -160,17 +157,12 @@ class OpenAIHost(str, Enum):
160157
161158
162159def setup_embeddings_service (
163- azure_credential : AsyncTokenCredential ,
160+ open_ai_client : AsyncOpenAI ,
164161 openai_host : OpenAIHost ,
165162 emb_model_name : str ,
166163 emb_model_dimensions : int ,
167- azure_openai_service : Optional [str ],
168- azure_openai_custom_url : Optional [str ],
169- azure_openai_deployment : Optional [str ],
170- azure_openai_key : Optional [str ],
171- azure_openai_api_version : str ,
172- openai_key : Optional [str ],
173- openai_org : Optional [str ],
164+ azure_openai_deployment : str | None ,
165+ azure_openai_endpoint : str | None ,
174166 disable_vectors : bool = False ,
175167 disable_batch_vectors : bool = False ,
176168):
@@ -179,70 +171,59 @@ def setup_embeddings_service(
179171 return None
180172
181173 if openai_host in [OpenAIHost .AZURE , OpenAIHost .AZURE_CUSTOM ]:
182- azure_open_ai_credential : AsyncTokenCredential | AzureKeyCredential = (
183- azure_credential if azure_openai_key is None else AzureKeyCredential (azure_openai_key )
184- )
185- return AzureOpenAIEmbeddingService (
186- open_ai_service = azure_openai_service ,
187- open_ai_custom_url = azure_openai_custom_url ,
188- open_ai_deployment = azure_openai_deployment ,
189- open_ai_model_name = emb_model_name ,
190- open_ai_dimensions = emb_model_dimensions ,
191- open_ai_api_version = azure_openai_api_version ,
192- credential = azure_open_ai_credential ,
193- disable_batch = disable_batch_vectors ,
194- )
195- else :
196- if openai_key is None :
197- raise ValueError ("OpenAI key is required when using the non-Azure OpenAI API" )
198- return OpenAIEmbeddingService (
199- open_ai_model_name = emb_model_name ,
200- open_ai_dimensions = emb_model_dimensions ,
201- credential = openai_key ,
202- organization = openai_org ,
203- disable_batch = disable_batch_vectors ,
204- )
174+ if azure_openai_endpoint is None :
175+ raise ValueError ("Azure OpenAI endpoint must be provided when using Azure OpenAI embeddings" )
176+ if azure_openai_deployment is None :
177+ raise ValueError ("Azure OpenAI deployment must be provided when using Azure OpenAI embeddings" )
178+
179+ return OpenAIEmbeddings (
180+ open_ai_client = open_ai_client ,
181+ open_ai_model_name = emb_model_name ,
182+ open_ai_dimensions = emb_model_dimensions ,
183+ disable_batch = disable_batch_vectors ,
184+ azure_deployment_name = azure_openai_deployment ,
185+ azure_endpoint = azure_openai_endpoint ,
186+ )
205187
206188
207189def setup_openai_client (
208190 openai_host : OpenAIHost ,
209191 azure_credential : AsyncTokenCredential ,
210192 azure_openai_api_key : Optional [str ] = None ,
211- azure_openai_api_version : Optional [str ] = None ,
212193 azure_openai_service : Optional [str ] = None ,
213194 azure_openai_custom_url : Optional [str ] = None ,
214195 openai_api_key : Optional [str ] = None ,
215196 openai_organization : Optional [str ] = None ,
216- ):
217- if openai_host not in OpenAIHost :
218- raise ValueError (f"Invalid OPENAI_HOST value: { openai_host } . Must be one of { [h .value for h in OpenAIHost ]} ." )
219-
197+ ) -> tuple [AsyncOpenAI , Optional [str ]]:
220198 openai_client : AsyncOpenAI
199+ azure_openai_endpoint : Optional [str ] = None
221200
222201 if openai_host in [OpenAIHost .AZURE , OpenAIHost .AZURE_CUSTOM ]:
202+ base_url : Optional [str ] = None
203+ api_key_or_token : Optional [str | Callable [[], Awaitable [str ]]] = None
223204 if openai_host == OpenAIHost .AZURE_CUSTOM :
224205 logger .info ("OPENAI_HOST is azure_custom, setting up Azure OpenAI custom client" )
225206 if not azure_openai_custom_url :
226207 raise ValueError ("AZURE_OPENAI_CUSTOM_URL must be set when OPENAI_HOST is azure_custom" )
227- endpoint = azure_openai_custom_url
208+ base_url = azure_openai_custom_url
228209 else :
229210 logger .info ("OPENAI_HOST is azure, setting up Azure OpenAI client" )
230211 if not azure_openai_service :
231212 raise ValueError ("AZURE_OPENAI_SERVICE must be set when OPENAI_HOST is azure" )
232- endpoint = f"https://{ azure_openai_service } .openai.azure.com"
213+ azure_openai_endpoint = f"https://{ azure_openai_service } .openai.azure.com"
214+ base_url = f"{ azure_openai_endpoint } /openai/v1"
233215 if azure_openai_api_key :
234216 logger .info ("AZURE_OPENAI_API_KEY_OVERRIDE found, using as api_key for Azure OpenAI client" )
235- openai_client = AsyncAzureOpenAI (
236- api_version = azure_openai_api_version , azure_endpoint = endpoint , api_key = azure_openai_api_key
237- )
217+ api_key_or_token = azure_openai_api_key
238218 else :
239219 logger .info ("Using Azure credential (passwordless authentication) for Azure OpenAI client" )
240- token_provider = get_bearer_token_provider (azure_credential , "https://cognitiveservices.azure.com/.default" )
241- openai_client = AsyncAzureOpenAI (
242- api_version = azure_openai_api_version ,
243- azure_endpoint = endpoint ,
244- azure_ad_token_provider = token_provider ,
220+ api_key_or_token = get_bearer_token_provider (
221+ azure_credential , "https://cognitiveservices.azure.com/.default"
245222 )
223+ openai_client = AsyncOpenAI (
224+ base_url = base_url ,
225+ api_key = api_key_or_token , # type: ignore[arg-type]
226+ )
246227 elif openai_host == OpenAIHost .LOCAL :
247228 logger .info ("OPENAI_HOST is local, setting up local OpenAI client for OPENAI_BASE_URL with no key" )
248229 openai_client = AsyncOpenAI (
@@ -259,7 +240,7 @@ def setup_openai_client(
259240 api_key = openai_api_key ,
260241 organization = openai_organization ,
261242 )
262- return openai_client
243+ return openai_client , azure_openai_endpoint
263244
264245
265246def setup_file_processors (
@@ -368,7 +349,7 @@ async def main(strategy: Strategy, setup_index: bool = True):
368349 await strategy .run ()
369350
370351
371- if __name__ == "__main__" :
352+ if __name__ == "__main__" : # pragma: no cover
372353 parser = argparse .ArgumentParser (
373354 description = "Prepare documents by extracting content from PDFs, splitting content into sections, uploading to blob storage, and indexing in a search index."
374355 )
@@ -516,41 +497,34 @@ async def main(strategy: Strategy, setup_index: bool = True):
516497 enable_global_documents = enable_global_documents ,
517498 )
518499
519- # https://learn.microsoft.com/azure/ai-services/openai/api-version-deprecation#latest-ga-api-release
520- azure_openai_api_version = os .getenv ("AZURE_OPENAI_API_VERSION" ) or "2024-06-01"
521500 emb_model_dimensions = 1536
522501 if os .getenv ("AZURE_OPENAI_EMB_DIMENSIONS" ):
523502 emb_model_dimensions = int (os .environ ["AZURE_OPENAI_EMB_DIMENSIONS" ])
524- openai_embeddings_service = setup_embeddings_service (
525- azure_credential = azd_credential ,
526- openai_host = OPENAI_HOST ,
527- emb_model_name = os .environ ["AZURE_OPENAI_EMB_MODEL_NAME" ],
528- emb_model_dimensions = emb_model_dimensions ,
529- azure_openai_service = os .getenv ("AZURE_OPENAI_SERVICE" ),
530- azure_openai_custom_url = os .getenv ("AZURE_OPENAI_CUSTOM_URL" ),
531- azure_openai_deployment = os .getenv ("AZURE_OPENAI_EMB_DEPLOYMENT" ),
532- azure_openai_api_version = azure_openai_api_version ,
533- azure_openai_key = os .getenv ("AZURE_OPENAI_API_KEY_OVERRIDE" ),
534- openai_key = clean_key_if_exists (os .getenv ("OPENAI_API_KEY" )),
535- openai_org = os .getenv ("OPENAI_ORGANIZATION" ),
536- disable_vectors = dont_use_vectors ,
537- disable_batch_vectors = args .disablebatchvectors ,
538- )
539- openai_client = setup_openai_client (
503+
504+ openai_client , azure_openai_endpoint = setup_openai_client (
540505 openai_host = OPENAI_HOST ,
541506 azure_credential = azd_credential ,
542- azure_openai_api_version = azure_openai_api_version ,
543507 azure_openai_service = os .getenv ("AZURE_OPENAI_SERVICE" ),
544508 azure_openai_custom_url = os .getenv ("AZURE_OPENAI_CUSTOM_URL" ),
545509 azure_openai_api_key = os .getenv ("AZURE_OPENAI_API_KEY_OVERRIDE" ),
546510 openai_api_key = clean_key_if_exists (os .getenv ("OPENAI_API_KEY" )),
547511 openai_organization = os .getenv ("OPENAI_ORGANIZATION" ),
548512 )
513+ openai_embeddings_service = setup_embeddings_service (
514+ open_ai_client = openai_client ,
515+ openai_host = OPENAI_HOST ,
516+ emb_model_name = os .environ ["AZURE_OPENAI_EMB_MODEL_NAME" ],
517+ emb_model_dimensions = emb_model_dimensions ,
518+ azure_openai_deployment = os .getenv ("AZURE_OPENAI_EMB_DEPLOYMENT" ),
519+ azure_openai_endpoint = azure_openai_endpoint ,
520+ disable_vectors = dont_use_vectors ,
521+ disable_batch_vectors = args .disablebatchvectors ,
522+ )
549523
550524 ingestion_strategy : Strategy
551525 if use_int_vectorization :
552526
553- if not openai_embeddings_service or not isinstance ( openai_embeddings_service , AzureOpenAIEmbeddingService ) :
527+ if not openai_embeddings_service or OPENAI_HOST not in [ OpenAIHost . AZURE , OpenAIHost . AZURE_CUSTOM ] :
554528 raise Exception ("Integrated vectorization strategy requires an Azure OpenAI embeddings service" )
555529
556530 ingestion_strategy = IntegratedVectorizerStrategy (
0 commit comments