@@ -13,16 +13,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import inspect
 import logging
 from functools import wraps
-from typing import Any, List, Optional
+from typing import Any, Dict, List, Optional

 from langchain_core.callbacks.manager import CallbackManagerForLLMRun
 from langchain_core.language_models.chat_models import generate_from_stream
 from langchain_core.messages import BaseMessage
 from langchain_core.outputs import ChatResult
 from langchain_nvidia_ai_endpoints import ChatNVIDIA as ChatNVIDIAOriginal
-from pydantic.v1 import Field
+from pydantic import Field

 log = logging.getLogger(__name__)

@@ -55,6 +56,46 @@ class ChatNVIDIA(ChatNVIDIAOriginal):
     streaming: bool = Field(
         default=False, description="Whether to use streaming or not"
     )
+    custom_headers: Optional[Dict[str, str]] = Field(
+        default=None, description="Custom HTTP headers to send with requests"
+    )
+
+    def __init__(self, **kwargs: Any):
+        super().__init__(**kwargs)
+        if self.custom_headers:
+            if not hasattr(self._client, "get_req"):
+                raise RuntimeError(
+                    "custom_headers requires langchain-nvidia-ai-endpoints >= 0.2.1. "
+                    "Your version uses a nested client structure that is not supported. "
+                    "Please upgrade: pip install --upgrade langchain-nvidia-ai-endpoints>=0.3.0"
+                )
+
+            sig = inspect.signature(self._client.get_req)
+            if "extra_headers" not in sig.parameters:
+                raise RuntimeError(
+                    "custom_headers requires langchain-nvidia-ai-endpoints >= 0.3.0. "
+                    "Your version does not support the extra_headers parameter. "
+                    "Please upgrade: pip install --upgrade langchain-nvidia-ai-endpoints>=0.3.0"
+                )
+
+            self._wrap_client_methods()
+
+    def _wrap_client_methods(self):
+        original_get_req = self._client.get_req
+        original_get_req_stream = self._client.get_req_stream
+
+        def wrapped_get_req(payload: dict = {}, extra_headers: dict = {}):
+            merged_headers = {**extra_headers, **self.custom_headers}
+            return original_get_req(payload=payload, extra_headers=merged_headers)
+
+        def wrapped_get_req_stream(payload: dict, extra_headers: dict = {}):
+            merged_headers = {**extra_headers, **self.custom_headers}
+            return original_get_req_stream(
+                payload=payload, extra_headers=merged_headers
+            )
+
+        object.__setattr__(self._client, "get_req", wrapped_get_req)
+        object.__setattr__(self._client, "get_req_stream", wrapped_get_req_stream)

     @stream_decorator
     def _generate(
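For context, a minimal usage sketch of the new field, assuming the patched ChatNVIDIA class above is in scope; the model name and header values are illustrative and not taken from this PR:

# Minimal usage sketch, not part of the diff; illustrative values only.
llm = ChatNVIDIA(
    model="meta/llama-3.1-8b-instruct",
    custom_headers={"X-Request-Id": "abc-123"},
)
result = llm.invoke("Hello")  # custom_headers ride along on the wrapped get_req call

Because the wrappers merge headers as {**extra_headers, **self.custom_headers}, the instance-level custom_headers override any per-call extra_headers that share a key, and the same headers are applied to both get_req and get_req_stream, so streaming and non-streaming requests are covered.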