-
Notifications
You must be signed in to change notification settings - Fork 5.1k
Migrate AzureOpenAI constructors to standard OpenAI client #2752
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 13 commits
6d3d0fc
ee85256
7fd57a6
76e5cb6
5d23055
4bc8d18
5d428c8
617ba74
d4180ad
f91fad1
27ec374
b47ea1b
1c40e29
5e8723e
0c7aa51
bacf427
01b56f9
ee14c51
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,24 +2,21 @@ | |
| import asyncio | ||
| import logging | ||
| import os | ||
| from collections.abc import Awaitable, Callable | ||
| from enum import Enum | ||
| from typing import Optional | ||
|
|
||
| import aiohttp | ||
| from azure.core.credentials import AzureKeyCredential | ||
| from azure.core.credentials_async import AsyncTokenCredential | ||
| from azure.identity.aio import AzureDeveloperCliCredential, get_bearer_token_provider | ||
| from openai import AsyncAzureOpenAI, AsyncOpenAI | ||
| from openai import AsyncOpenAI | ||
| from rich.logging import RichHandler | ||
|
|
||
| from load_azd_env import load_azd_env | ||
| from prepdocslib.blobmanager import BlobManager | ||
| from prepdocslib.csvparser import CsvParser | ||
| from prepdocslib.embeddings import ( | ||
| AzureOpenAIEmbeddingService, | ||
| ImageEmbeddings, | ||
| OpenAIEmbeddingService, | ||
| ) | ||
| from prepdocslib.embeddings import ImageEmbeddings, OpenAIEmbeddings | ||
| from prepdocslib.fileprocessor import FileProcessor | ||
| from prepdocslib.filestrategy import FileStrategy | ||
| from prepdocslib.htmlparser import LocalHTMLParser | ||
|
|
@@ -160,17 +157,12 @@ class OpenAIHost(str, Enum): | |
|
|
||
|
|
||
| def setup_embeddings_service( | ||
| azure_credential: AsyncTokenCredential, | ||
| open_ai_client: AsyncOpenAI, | ||
| openai_host: OpenAIHost, | ||
| emb_model_name: str, | ||
| emb_model_dimensions: int, | ||
| azure_openai_service: Optional[str], | ||
| azure_openai_custom_url: Optional[str], | ||
| azure_openai_deployment: Optional[str], | ||
| azure_openai_key: Optional[str], | ||
| azure_openai_api_version: str, | ||
| openai_key: Optional[str], | ||
| openai_org: Optional[str], | ||
| azure_openai_deployment: str | None, | ||
| azure_openai_endpoint: str | None, | ||
| disable_vectors: bool = False, | ||
| disable_batch_vectors: bool = False, | ||
| ): | ||
|
|
@@ -179,70 +171,59 @@ def setup_embeddings_service( | |
| return None | ||
|
|
||
| if openai_host in [OpenAIHost.AZURE, OpenAIHost.AZURE_CUSTOM]: | ||
| azure_open_ai_credential: AsyncTokenCredential | AzureKeyCredential = ( | ||
| azure_credential if azure_openai_key is None else AzureKeyCredential(azure_openai_key) | ||
| ) | ||
| return AzureOpenAIEmbeddingService( | ||
| open_ai_service=azure_openai_service, | ||
| open_ai_custom_url=azure_openai_custom_url, | ||
| open_ai_deployment=azure_openai_deployment, | ||
| open_ai_model_name=emb_model_name, | ||
| open_ai_dimensions=emb_model_dimensions, | ||
| open_ai_api_version=azure_openai_api_version, | ||
| credential=azure_open_ai_credential, | ||
| disable_batch=disable_batch_vectors, | ||
| ) | ||
| else: | ||
| if openai_key is None: | ||
| raise ValueError("OpenAI key is required when using the non-Azure OpenAI API") | ||
| return OpenAIEmbeddingService( | ||
| open_ai_model_name=emb_model_name, | ||
| open_ai_dimensions=emb_model_dimensions, | ||
| credential=openai_key, | ||
| organization=openai_org, | ||
| disable_batch=disable_batch_vectors, | ||
| ) | ||
| if azure_openai_endpoint is None: | ||
| raise ValueError("Azure OpenAI endpoint must be provided when using Azure OpenAI embeddings") | ||
| if azure_openai_deployment is None: | ||
| raise ValueError("Azure OpenAI deployment must be provided when using Azure OpenAI embeddings") | ||
|
|
||
| return OpenAIEmbeddings( | ||
| open_ai_client=open_ai_client, | ||
| open_ai_model_name=emb_model_name, | ||
| open_ai_dimensions=emb_model_dimensions, | ||
| disable_batch=disable_batch_vectors, | ||
| azure_deployment_name=azure_openai_deployment, | ||
| azure_endpoint=azure_openai_endpoint, | ||
| ) | ||
|
|
||
|
|
||
| def setup_openai_client( | ||
| openai_host: OpenAIHost, | ||
| azure_credential: AsyncTokenCredential, | ||
| azure_openai_api_key: Optional[str] = None, | ||
| azure_openai_api_version: Optional[str] = None, | ||
| azure_openai_service: Optional[str] = None, | ||
| azure_openai_custom_url: Optional[str] = None, | ||
| openai_api_key: Optional[str] = None, | ||
| openai_organization: Optional[str] = None, | ||
| ): | ||
| if openai_host not in OpenAIHost: | ||
| raise ValueError(f"Invalid OPENAI_HOST value: {openai_host}. Must be one of {[h.value for h in OpenAIHost]}.") | ||
|
|
||
| ) -> tuple[AsyncOpenAI, Optional[str]]: | ||
| openai_client: AsyncOpenAI | ||
| azure_openai_endpoint: Optional[str] = None | ||
|
|
||
| if openai_host in [OpenAIHost.AZURE, OpenAIHost.AZURE_CUSTOM]: | ||
| base_url: Optional[str] = None | ||
| api_key_or_token: Optional[str | Callable[[], Awaitable[str]]] = None | ||
| if openai_host == OpenAIHost.AZURE_CUSTOM: | ||
| logger.info("OPENAI_HOST is azure_custom, setting up Azure OpenAI custom client") | ||
| if not azure_openai_custom_url: | ||
| raise ValueError("AZURE_OPENAI_CUSTOM_URL must be set when OPENAI_HOST is azure_custom") | ||
| endpoint = azure_openai_custom_url | ||
| base_url = azure_openai_custom_url | ||
| else: | ||
| logger.info("OPENAI_HOST is azure, setting up Azure OpenAI client") | ||
| if not azure_openai_service: | ||
| raise ValueError("AZURE_OPENAI_SERVICE must be set when OPENAI_HOST is azure") | ||
| endpoint = f"https://{azure_openai_service}.openai.azure.com" | ||
| azure_openai_endpoint = f"https://{azure_openai_service}.openai.azure.com" | ||
| base_url = f"{azure_openai_endpoint}/openai/v1" | ||
| if azure_openai_api_key: | ||
| logger.info("AZURE_OPENAI_API_KEY_OVERRIDE found, using as api_key for Azure OpenAI client") | ||
| openai_client = AsyncAzureOpenAI( | ||
| api_version=azure_openai_api_version, azure_endpoint=endpoint, api_key=azure_openai_api_key | ||
| ) | ||
| api_key_or_token = azure_openai_api_key | ||
| else: | ||
| logger.info("Using Azure credential (passwordless authentication) for Azure OpenAI client") | ||
| token_provider = get_bearer_token_provider(azure_credential, "https://cognitiveservices.azure.com/.default") | ||
| openai_client = AsyncAzureOpenAI( | ||
| api_version=azure_openai_api_version, | ||
| azure_endpoint=endpoint, | ||
| azure_ad_token_provider=token_provider, | ||
| api_key_or_token = get_bearer_token_provider( | ||
| azure_credential, "https://cognitiveservices.azure.com/.default" | ||
| ) | ||
| openai_client = AsyncOpenAI( | ||
| base_url=base_url, | ||
| api_key=api_key_or_token, # type: ignore[arg-type] | ||
| ) | ||
| elif openai_host == OpenAIHost.LOCAL: | ||
| logger.info("OPENAI_HOST is local, setting up local OpenAI client for OPENAI_BASE_URL with no key") | ||
| openai_client = AsyncOpenAI( | ||
|
|
@@ -259,7 +240,7 @@ def setup_openai_client( | |
| api_key=openai_api_key, | ||
| organization=openai_organization, | ||
| ) | ||
| return openai_client | ||
| return openai_client, azure_openai_endpoint | ||
|
|
||
|
|
||
| def setup_file_processors( | ||
|
|
@@ -368,7 +349,7 @@ async def main(strategy: Strategy, setup_index: bool = True): | |
| await strategy.run() | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| if __name__ == "__main__": # pragma: no cover | ||
| parser = argparse.ArgumentParser( | ||
| description="Prepare documents by extracting content from PDFs, splitting content into sections, uploading to blob storage, and indexing in a search index." | ||
| ) | ||
|
|
@@ -516,41 +497,34 @@ async def main(strategy: Strategy, setup_index: bool = True): | |
| enable_global_documents=enable_global_documents, | ||
| ) | ||
|
|
||
| # https://learn.microsoft.com/azure/ai-services/openai/api-version-deprecation#latest-ga-api-release | ||
| azure_openai_api_version = os.getenv("AZURE_OPENAI_API_VERSION") or "2024-06-01" | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @copilot Isn't AZURE_OPENAI_API_VERSION in other files as well, like the GitHub actions workflows? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You're absolutely right! Found and removed AZURE_OPENAI_API_VERSION references from GitHub Actions workflows (.github/workflows/azure-dev.yml, .github/workflows/evaluate.yaml), Azure DevOps pipeline (.azdo/pipelines/azure-dev.yml), and documentation (docs/reasoning.md). The only remaining reference is in evals/generate_ground_truth.py which uses LangChain wrappers, not the OpenAI SDK directly. Commit: 5d23055 |
||
| emb_model_dimensions = 1536 | ||
| if os.getenv("AZURE_OPENAI_EMB_DIMENSIONS"): | ||
| emb_model_dimensions = int(os.environ["AZURE_OPENAI_EMB_DIMENSIONS"]) | ||
| openai_embeddings_service = setup_embeddings_service( | ||
| azure_credential=azd_credential, | ||
| openai_host=OPENAI_HOST, | ||
| emb_model_name=os.environ["AZURE_OPENAI_EMB_MODEL_NAME"], | ||
| emb_model_dimensions=emb_model_dimensions, | ||
| azure_openai_service=os.getenv("AZURE_OPENAI_SERVICE"), | ||
| azure_openai_custom_url=os.getenv("AZURE_OPENAI_CUSTOM_URL"), | ||
| azure_openai_deployment=os.getenv("AZURE_OPENAI_EMB_DEPLOYMENT"), | ||
| azure_openai_api_version=azure_openai_api_version, | ||
| azure_openai_key=os.getenv("AZURE_OPENAI_API_KEY_OVERRIDE"), | ||
| openai_key=clean_key_if_exists(os.getenv("OPENAI_API_KEY")), | ||
| openai_org=os.getenv("OPENAI_ORGANIZATION"), | ||
| disable_vectors=dont_use_vectors, | ||
| disable_batch_vectors=args.disablebatchvectors, | ||
| ) | ||
| openai_client = setup_openai_client( | ||
|
|
||
| openai_client, azure_openai_endpoint = setup_openai_client( | ||
| openai_host=OPENAI_HOST, | ||
| azure_credential=azd_credential, | ||
| azure_openai_api_version=azure_openai_api_version, | ||
| azure_openai_service=os.getenv("AZURE_OPENAI_SERVICE"), | ||
| azure_openai_custom_url=os.getenv("AZURE_OPENAI_CUSTOM_URL"), | ||
| azure_openai_api_key=os.getenv("AZURE_OPENAI_API_KEY_OVERRIDE"), | ||
| openai_api_key=clean_key_if_exists(os.getenv("OPENAI_API_KEY")), | ||
| openai_organization=os.getenv("OPENAI_ORGANIZATION"), | ||
| ) | ||
| openai_embeddings_service = setup_embeddings_service( | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Moved after, since its now dependent on the client There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. makes sense, and it's simpler at the end of the day |
||
| open_ai_client=openai_client, | ||
| openai_host=OPENAI_HOST, | ||
| emb_model_name=os.environ["AZURE_OPENAI_EMB_MODEL_NAME"], | ||
| emb_model_dimensions=emb_model_dimensions, | ||
| azure_openai_deployment=os.getenv("AZURE_OPENAI_EMB_DEPLOYMENT"), | ||
| azure_openai_endpoint=azure_openai_endpoint, | ||
| disable_vectors=dont_use_vectors, | ||
| disable_batch_vectors=args.disablebatchvectors, | ||
| ) | ||
|
|
||
| ingestion_strategy: Strategy | ||
| if use_int_vectorization: | ||
|
|
||
| if not openai_embeddings_service or not isinstance(openai_embeddings_service, AzureOpenAIEmbeddingService): | ||
| if not openai_embeddings_service or OPENAI_HOST not in [OpenAIHost.AZURE, OpenAIHost.AZURE_CUSTOM]: | ||
| raise Exception("Integrated vectorization strategy requires an Azure OpenAI embeddings service") | ||
|
|
||
| ingestion_strategy = IntegratedVectorizerStrategy( | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@copilot Remove this line that sets AZURE_OPENAI_ENDPOINT, and pass azure_openai_endpoint from
setup_openai_client to setup_embeddings instead. There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed in commit ee14c51 - removed the redundant AZURE_OPENAI_ENDPOINT computation and now pass the
azure_openai_endpoint value returned from setup_openai_client directly to setup_embeddings_service.