|
13 | 13 | # See the License for the specific language governing permissions and |
14 | 14 | # limitations under the License. |
15 | 15 |
|
| 16 | +import inspect |
16 | 17 | import logging |
17 | 18 | from functools import wraps |
18 | | -from typing import Any, List, Optional |
| 19 | +from typing import Any, Dict, List, Optional |
19 | 20 |
|
20 | 21 | from langchain_core.callbacks.manager import CallbackManagerForLLMRun |
21 | 22 | from langchain_core.language_models.chat_models import generate_from_stream |
22 | 23 | from langchain_core.messages import BaseMessage |
23 | 24 | from langchain_core.outputs import ChatResult |
24 | 25 | from langchain_nvidia_ai_endpoints import ChatNVIDIA as ChatNVIDIAOriginal |
25 | | -from pydantic.v1 import Field |
| 26 | +from pydantic import Field |
26 | 27 |
|
27 | 28 | log = logging.getLogger(__name__) |
28 | 29 |
|
@@ -55,6 +56,46 @@ class ChatNVIDIA(ChatNVIDIAOriginal): |
55 | 56 | streaming: bool = Field( |
56 | 57 | default=False, description="Whether to use streaming or not" |
57 | 58 | ) |
| 59 | + custom_headers: Optional[Dict[str, str]] = Field( |
| 60 | + default=None, description="Custom HTTP headers to send with requests" |
| 61 | + ) |
| 62 | + |
def __init__(self, **kwargs: Any):
    """Initialize the chat model.

    If ``custom_headers`` is set, verify that the installed
    ``langchain-nvidia-ai-endpoints`` exposes a compatible low-level
    client API, then wrap the client's request methods so the headers
    are injected into every request.

    Raises:
        RuntimeError: if the installed ``langchain-nvidia-ai-endpoints``
            version cannot support per-request ``extra_headers``.
    """
    super().__init__(**kwargs)
    if not self.custom_headers:
        return

    # Older releases hide the request helpers behind a nested client
    # object with no ``get_req`` attribute; nothing we can safely wrap.
    if not hasattr(self._client, "get_req"):
        raise RuntimeError(
            "custom_headers requires langchain-nvidia-ai-endpoints >= 0.3.0. "
            "Your version uses a nested client structure that is not supported. "
            "Please upgrade: pip install --upgrade langchain-nvidia-ai-endpoints>=0.3.0"
        )

    # ``extra_headers`` support landed in 0.3.0; without it the wrapped
    # methods would pass an unexpected keyword argument at request time.
    sig = inspect.signature(self._client.get_req)
    if "extra_headers" not in sig.parameters:
        raise RuntimeError(
            "custom_headers requires langchain-nvidia-ai-endpoints >= 0.3.0. "
            "Your version does not support the extra_headers parameter. "
            "Please upgrade: pip install --upgrade langchain-nvidia-ai-endpoints>=0.3.0"
        )

    self._wrap_client_methods()
| 82 | + |
| 83 | + def _wrap_client_methods(self): |
| 84 | + original_get_req = self._client.get_req |
| 85 | + original_get_req_stream = self._client.get_req_stream |
| 86 | + |
| 87 | + def wrapped_get_req(payload: dict = {}, extra_headers: dict = {}): |
| 88 | + merged_headers = {**extra_headers, **self.custom_headers} |
| 89 | + return original_get_req(payload=payload, extra_headers=merged_headers) |
| 90 | + |
| 91 | + def wrapped_get_req_stream(payload: dict, extra_headers: dict = {}): |
| 92 | + merged_headers = {**extra_headers, **self.custom_headers} |
| 93 | + return original_get_req_stream( |
| 94 | + payload=payload, extra_headers=merged_headers |
| 95 | + ) |
| 96 | + |
| 97 | + object.__setattr__(self._client, "get_req", wrapped_get_req) |
| 98 | + object.__setattr__(self._client, "get_req_stream", wrapped_get_req_stream) |
58 | 99 |
|
59 | 100 | @stream_decorator |
60 | 101 | def _generate( |
|
0 commit comments