Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import logging
from functools import wraps
from typing import Any, List, Optional
from typing import Any, Dict, List, Optional

from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import generate_from_stream
from langchain_core.messages import BaseMessage
from langchain_core.outputs import ChatResult
from langchain_nvidia_ai_endpoints import ChatNVIDIA as ChatNVIDIAOriginal
from pydantic.v1 import Field
from pydantic import Field

log = logging.getLogger(__name__)
log = logging.getLogger(__name__) # pragma: no cover


def stream_decorator(func):
def stream_decorator(func): # pragma: no cover
@wraps(func)
def wrapper(
self,
Expand All @@ -51,10 +52,54 @@ def wrapper(

# NOTE: this needs to have the same name as the original class,
# otherwise, there's a check inside `langchain-nvidia-ai-endpoints` that will fail.
class ChatNVIDIA(ChatNVIDIAOriginal):
class ChatNVIDIA(ChatNVIDIAOriginal): # pragma: no cover
streaming: bool = Field(
default=False, description="Whether to use streaming or not"
)
custom_headers: Optional[Dict[str, str]] = Field(
default=None, description="Custom HTTP headers to send with requests"
)

def __init__(self, **kwargs: Any):
    """Initialise the model and, when requested, enable custom HTTP headers.

    All keyword arguments are forwarded unchanged to the parent
    constructor. If ``custom_headers`` is set (non-empty), the underlying
    client is checked for the request methods this class patches, and
    those methods are then wrapped so the headers accompany every request.

    Raises:
        RuntimeError: If ``custom_headers`` is set but the client lacks
            ``get_req`` / ``get_req_stream`` or either method does not
            accept an ``extra_headers`` parameter.
    """
    # Initialise the parent exactly once.
    super().__init__(**kwargs)

    # Nothing to validate or patch unless custom headers were supplied.
    if not self.custom_headers:
        return

    custom_headers_error = (
        "custom_headers requires langchain-nvidia-ai-endpoints >= 0.3.0. "
        "Your version does not support the required client structure or "
        "extra_headers parameter. Please upgrade: "
        "pip install --upgrade langchain-nvidia-ai-endpoints>=0.3.0"
    )

    # Both methods are replaced by _wrap_client_methods and both are called
    # with `extra_headers=...`, so validate both up front rather than
    # failing later with an AttributeError/TypeError at request time.
    for method_name in ("get_req", "get_req_stream"):
        method = getattr(self._client, method_name, None)
        if method is None:
            raise RuntimeError(custom_headers_error)
        if "extra_headers" not in inspect.signature(method).parameters:
            raise RuntimeError(custom_headers_error)

    self._wrap_client_methods()

def _wrap_client_methods(self):
    """Patch the client's request methods so custom headers are always sent.

    Replaces ``get_req`` and ``get_req_stream`` on ``self._client`` with
    wrappers that merge ``self.custom_headers`` into whatever
    ``extra_headers`` the caller passes and then delegate to the original
    method. On a key conflict the custom header value wins.
    """

    def _with_custom_headers(original):
        # Single factory for both methods: the wrappers differ only in
        # which original callable they delegate to.
        def wrapped(
            payload: Optional[dict] = None,
            extra_headers: Optional[dict] = None,
        ):
            # custom_headers is looked up at call time, so later changes to
            # the attribute are honoured; `or {}` keeps the merge valid if
            # it has been reset to None in the meantime.
            merged_headers = {
                **(extra_headers or {}),
                **(self.custom_headers or {}),
            }
            return original(payload=payload or {}, extra_headers=merged_headers)

        return wrapped

    client = self._client
    # Resolve both originals before patching anything, so a missing method
    # leaves the client untouched.
    replacements = {
        method_name: _with_custom_headers(getattr(client, method_name))
        for method_name in ("get_req", "get_req_stream")
    }
    for method_name, wrapper in replacements.items():
        # object.__setattr__ skips any __setattr__ override on the client's
        # class (presumably a pydantic model that restricts assignment —
        # confirm against the installed client).
        object.__setattr__(client, method_name, wrapper)

@stream_decorator
def _generate(
Expand Down
Loading