Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,14 @@
from __future__ import annotations

import asyncio
import inspect
import os
from collections.abc import Awaitable, Callable
from functools import lru_cache
from typing import Any
from typing import Any, cast

import openai
from pydantic import SecretStr


class _SyncHttpxClientWrapper(openai.DefaultHttpxClient):
Expand Down Expand Up @@ -107,3 +110,33 @@ def _get_default_async_httpx_client(
return _build_async_httpx_client(base_url, timeout)
else:
return _cached_async_httpx_client(base_url, timeout)


def _resolve_sync_and_async_api_keys(
api_key: SecretStr | Callable[[], str] | Callable[[], Awaitable[str]],
) -> tuple[str | None | Callable[[], str], str | Callable[[], Awaitable[str]]]:
"""Resolve sync and async API key values.

Because OpenAI and AsyncOpenAI clients support either sync or async callables for
the API key, we need to resolve separate values here.
"""
if isinstance(api_key, SecretStr):
sync_api_key_value: str | None | Callable[[], str] = api_key.get_secret_value()
async_api_key_value: str | Callable[[], Awaitable[str]] = (
api_key.get_secret_value()
)
elif callable(api_key):
if inspect.iscoroutinefunction(api_key):
async_api_key_value = api_key
sync_api_key_value = None
else:
sync_api_key_value = cast(Callable, api_key)

async def async_api_key_wrapper() -> str:
return await asyncio.get_event_loop().run_in_executor(
None, cast(Callable, api_key)
)

async_api_key_value = async_api_key_wrapper

return sync_api_key_value, async_api_key_value
133 changes: 110 additions & 23 deletions libs/partners/openai/langchain_openai/chat_models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,14 @@
import ssl
import sys
import warnings
from collections.abc import AsyncIterator, Callable, Iterator, Mapping, Sequence
from collections.abc import (
AsyncIterator,
Awaitable,
Callable,
Iterator,
Mapping,
Sequence,
)
from functools import partial
from io import BytesIO
from json import JSONDecodeError
Expand Down Expand Up @@ -109,6 +116,7 @@
from langchain_openai.chat_models._client_utils import (
_get_default_async_httpx_client,
_get_default_httpx_client,
_resolve_sync_and_async_api_keys,
)
from langchain_openai.chat_models._compat import (
_convert_from_v1_to_chat_completions,
Expand Down Expand Up @@ -465,9 +473,57 @@ class BaseChatOpenAI(BaseChatModel):
"""What sampling temperature to use."""
model_kwargs: dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified."""
openai_api_key: SecretStr | None = Field(
openai_api_key: (
SecretStr | None | Callable[[], str] | Callable[[], Awaitable[str]]
) = Field(
alias="api_key", default_factory=secret_from_env("OPENAI_API_KEY", default=None)
)
"""API key to use.

Can be inferred from the `OPENAI_API_KEY` environment variable, or specified as a
string, or sync or async callable that returns a string.

??? example "Specify with environment variable"

```bash
export OPENAI_API_KEY=...
```
```python
from langchain_openai import ChatOpenAI

model = ChatOpenAI(model="gpt-5-nano")
```

??? example "Specify with a string"

```python
from langchain_openai import ChatOpenAI

model = ChatOpenAI(model="gpt-5-nano", api_key="...")
```

??? example "Specify with a sync callable"
```python
from langchain_openai import ChatOpenAI

def get_api_key() -> str:
# Custom logic to retrieve API key
return "..."

model = ChatOpenAI(model="gpt-5-nano", api_key=get_api_key)
```

??? example "Specify with an async callable"
```python
from langchain_openai import ChatOpenAI

async def get_api_key() -> str:
# Custom async logic to retrieve API key
return "..."

model = ChatOpenAI(model="gpt-5-nano", api_key=get_api_key)
```
"""
openai_api_base: str | None = Field(default=None, alias="base_url")
"""Base URL path for API requests, leave blank if not using a proxy or service emulator.""" # noqa: E501
openai_organization: str | None = Field(default=None, alias="organization")
Expand Down Expand Up @@ -776,10 +832,18 @@ def validate_environment(self) -> Self:
):
self.stream_usage = True

# Resolve API key from SecretStr or Callable
sync_api_key_value: str | Callable[[], str] | None = None
async_api_key_value: str | Callable[[], Awaitable[str]] | None = None

if self.openai_api_key is not None:
# Because OpenAI and AsyncOpenAI clients support either sync or async
# callables for the API key, we need to resolve separate values here.
sync_api_key_value, async_api_key_value = _resolve_sync_and_async_api_keys(
self.openai_api_key
)

client_params: dict = {
"api_key": (
self.openai_api_key.get_secret_value() if self.openai_api_key else None
),
"organization": self.openai_organization,
"base_url": self.openai_api_base,
"timeout": self.request_timeout,
Expand All @@ -800,24 +864,33 @@ def validate_environment(self) -> Self:
)
raise ValueError(msg)
if not self.client:
if self.openai_proxy and not self.http_client:
try:
import httpx
except ImportError as e:
msg = (
"Could not import httpx python package. "
"Please install it with `pip install httpx`."
if sync_api_key_value is None:
# No valid sync API key, leave client as None and raise informative
# error on invocation.
self.client = None
self.root_client = None
else:
if self.openai_proxy and not self.http_client:
try:
import httpx
except ImportError as e:
msg = (
"Could not import httpx python package. "
"Please install it with `pip install httpx`."
)
raise ImportError(msg) from e
self.http_client = httpx.Client(
proxy=self.openai_proxy, verify=global_ssl_context
)
raise ImportError(msg) from e
self.http_client = httpx.Client(
proxy=self.openai_proxy, verify=global_ssl_context
)
sync_specific = {
"http_client": self.http_client
or _get_default_httpx_client(self.openai_api_base, self.request_timeout)
}
self.root_client = openai.OpenAI(**client_params, **sync_specific) # type: ignore[arg-type]
self.client = self.root_client.chat.completions
sync_specific = {
"http_client": self.http_client
or _get_default_httpx_client(
self.openai_api_base, self.request_timeout
),
"api_key": sync_api_key_value,
}
self.root_client = openai.OpenAI(**client_params, **sync_specific) # type: ignore[arg-type]
self.client = self.root_client.chat.completions
if not self.async_client:
if self.openai_proxy and not self.http_async_client:
try:
Expand All @@ -835,7 +908,8 @@ def validate_environment(self) -> Self:
"http_client": self.http_async_client
or _get_default_async_httpx_client(
self.openai_api_base, self.request_timeout
)
),
"api_key": async_api_key_value,
}
self.root_async_client = openai.AsyncOpenAI(
**client_params,
Expand Down Expand Up @@ -965,13 +1039,24 @@ def _convert_chunk_to_generation_chunk(
message=message_chunk, generation_info=generation_info or None
)

def _ensure_sync_client_available(self) -> None:
"""Check that sync client is available, raise error if not."""
if self.client is None:
msg = (
"Sync client is not available. This happens when an async callable "
"was provided for the API key. Use async methods (ainvoke, astream) "
"instead, or provide a string or sync callable for the API key."
)
raise ValueError(msg)

def _stream_responses(
self,
messages: list[BaseMessage],
stop: list[str] | None = None,
run_manager: CallbackManagerForLLMRun | None = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
self._ensure_sync_client_available()
kwargs["stream"] = True
payload = self._get_request_payload(messages, stop=stop, **kwargs)
if self.include_response_headers:
Expand Down Expand Up @@ -1101,6 +1186,7 @@ def _stream(
stream_usage: bool | None = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
self._ensure_sync_client_available()
kwargs["stream"] = True
stream_usage = self._should_stream_usage(stream_usage, **kwargs)
if stream_usage:
Expand Down Expand Up @@ -1169,6 +1255,7 @@ def _generate(
run_manager: CallbackManagerForLLMRun | None = None,
**kwargs: Any,
) -> ChatResult:
self._ensure_sync_client_available()
payload = self._get_request_payload(messages, stop=stop, **kwargs)
generation_info = None
raw_response = None
Expand Down
72 changes: 54 additions & 18 deletions libs/partners/openai/langchain_openai/embeddings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import logging
import warnings
from collections.abc import Iterable, Mapping, Sequence
from collections.abc import Awaitable, Callable, Iterable, Mapping, Sequence
from typing import Any, Literal, cast

import openai
Expand All @@ -15,6 +15,8 @@
from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
from typing_extensions import Self

from langchain_openai.chat_models._client_utils import _resolve_sync_and_async_api_keys

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -189,7 +191,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
)
embedding_ctx_length: int = 8191
"""The maximum number of tokens to embed at once."""
openai_api_key: SecretStr | None = Field(
openai_api_key: (
SecretStr | None | Callable[[], str] | Callable[[], Awaitable[str]]
) = Field(
alias="api_key", default_factory=secret_from_env("OPENAI_API_KEY", default=None)
)
"""Automatically inferred from env var `OPENAI_API_KEY` if not provided."""
Expand Down Expand Up @@ -292,10 +296,19 @@ def validate_environment(self) -> Self:
"If you are using Azure, please use the `AzureOpenAIEmbeddings` class."
)
raise ValueError(msg)

# Resolve API key from SecretStr or Callable
sync_api_key_value: str | Callable[[], str] | None = None
async_api_key_value: str | Callable[[], Awaitable[str]] | None = None

if self.openai_api_key is not None:
# Because OpenAI and AsyncOpenAI clients support either sync or async
# callables for the API key, we need to resolve separate values here.
sync_api_key_value, async_api_key_value = _resolve_sync_and_async_api_keys(
self.openai_api_key
)

client_params: dict = {
"api_key": (
self.openai_api_key.get_secret_value() if self.openai_api_key else None
),
"organization": self.openai_organization,
"base_url": self.openai_api_base,
"timeout": self.request_timeout,
Expand All @@ -315,18 +328,26 @@ def validate_environment(self) -> Self:
)
raise ValueError(msg)
if not self.client:
if self.openai_proxy and not self.http_client:
try:
import httpx
except ImportError as e:
msg = (
"Could not import httpx python package. "
"Please install it with `pip install httpx`."
)
raise ImportError(msg) from e
self.http_client = httpx.Client(proxy=self.openai_proxy)
sync_specific = {"http_client": self.http_client}
self.client = openai.OpenAI(**client_params, **sync_specific).embeddings # type: ignore[arg-type]
if sync_api_key_value is None:
# No valid sync API key, leave client as None and raise informative
# error on invocation.
self.client = None
else:
if self.openai_proxy and not self.http_client:
try:
import httpx
except ImportError as e:
msg = (
"Could not import httpx python package. "
"Please install it with `pip install httpx`."
)
raise ImportError(msg) from e
self.http_client = httpx.Client(proxy=self.openai_proxy)
sync_specific = {
"http_client": self.http_client,
"api_key": sync_api_key_value,
}
self.client = openai.OpenAI(**client_params, **sync_specific).embeddings # type: ignore[arg-type]
if not self.async_client:
if self.openai_proxy and not self.http_async_client:
try:
Expand All @@ -338,7 +359,10 @@ def validate_environment(self) -> Self:
)
raise ImportError(msg) from e
self.http_async_client = httpx.AsyncClient(proxy=self.openai_proxy)
async_specific = {"http_client": self.http_async_client}
async_specific = {
"http_client": self.http_async_client,
"api_key": async_api_key_value,
}
self.async_client = openai.AsyncOpenAI(
**client_params,
**async_specific, # type: ignore[arg-type]
Expand All @@ -352,6 +376,16 @@ def _invocation_params(self) -> dict[str, Any]:
params["dimensions"] = self.dimensions
return params

def _ensure_sync_client_available(self) -> None:
"""Check that sync client is available, raise error if not."""
if self.client is None:
msg = (
"Sync client is not available. This happens when an async callable "
"was provided for the API key. Use async methods (ainvoke, astream) "
"instead, or provide a string or sync callable for the API key."
)
raise ValueError(msg)

def _tokenize(
self, texts: list[str], chunk_size: int
) -> tuple[Iterable[int], list[list[int] | str], list[int]]:
Expand Down Expand Up @@ -571,6 +605,7 @@ def embed_documents(
Returns:
List of embeddings, one for each text.
"""
self._ensure_sync_client_available()
chunk_size_ = chunk_size or self.chunk_size
client_kwargs = {**self._invocation_params, **kwargs}
if not self.check_embedding_ctx_length:
Expand Down Expand Up @@ -635,6 +670,7 @@ def embed_query(self, text: str, **kwargs: Any) -> list[float]:
Returns:
Embedding for the text.
"""
self._ensure_sync_client_available()
return self.embed_documents([text], **kwargs)[0]

async def aembed_query(self, text: str, **kwargs: Any) -> list[float]:
Expand Down
Loading