
Commit 9259ab9

feat(llm): Add custom HTTP headers support to ChatNVIDIA provider
Add custom HTTP headers support to the ChatNVIDIA class patch, enabling users to pass custom headers (authentication tokens, request IDs, billing information, etc.) with all requests to NVIDIA AI endpoints.

Implementation approach:

- Added an optional custom_headers field to the ChatNVIDIA class, with Pydantic v2 compatibility
- Implemented runtime method wrapping that intercepts _client.get_req() and _client.get_req_stream() to merge custom headers with existing headers
- Included automatic version detection to ensure compatibility with langchain-nvidia-ai-endpoints >= 0.3.0, with clear error messages for older versions
- Works with both synchronous invoke() and streaming requests, and is fully compatible with VLMs (Vision Language Models)
1 parent ce7b866 commit 9259ab9
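
For reference, a minimal usage sketch of the new field. The import path is the patched module touched by this commit; the model name and header values are illustrative assumptions, not part of the change:

    # Sketch: configuring the patched ChatNVIDIA with custom headers.
    # Model name and header values below are placeholders.
    from nemoguardrails.llm.providers._langchain_nvidia_ai_endpoints_patch import (
        ChatNVIDIA,
    )

    llm = ChatNVIDIA(
        model="meta/llama-3.1-8b-instruct",  # assumed model id, for illustration
        custom_headers={
            "X-Request-Id": "req-12345",        # e.g. request tracing
            "X-Billing-Project": "team-alpha",  # e.g. billing metadata
        },
    )

    # Both the synchronous and streaming paths carry the extra headers, because
    # get_req() and get_req_stream() are wrapped when the instance is created.
    print(llm.invoke("Hello").content)
    for chunk in llm.stream("Hello"):
        print(chunk.content, end="")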

File tree

2 files changed: +499 −2 lines changed


nemoguardrails/llm/providers/_langchain_nvidia_ai_endpoints_patch.py

Lines changed: 43 additions & 2 deletions
@@ -13,16 +13,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import inspect
 import logging
 from functools import wraps
-from typing import Any, List, Optional
+from typing import Any, Dict, List, Optional
 
 from langchain_core.callbacks.manager import CallbackManagerForLLMRun
 from langchain_core.language_models.chat_models import generate_from_stream
 from langchain_core.messages import BaseMessage
 from langchain_core.outputs import ChatResult
 from langchain_nvidia_ai_endpoints import ChatNVIDIA as ChatNVIDIAOriginal
-from pydantic.v1 import Field
+from pydantic import Field
 
 log = logging.getLogger(__name__)
 
@@ -55,6 +56,46 @@ class ChatNVIDIA(ChatNVIDIAOriginal):
     streaming: bool = Field(
         default=False, description="Whether to use streaming or not"
     )
+    custom_headers: Optional[Dict[str, str]] = Field(
+        default=None, description="Custom HTTP headers to send with requests"
+    )
+
+    def __init__(self, **kwargs: Any):
+        super().__init__(**kwargs)
+        if self.custom_headers:
+            if not hasattr(self._client, "get_req"):
+                raise RuntimeError(
+                    "custom_headers requires langchain-nvidia-ai-endpoints >= 0.2.1. "
+                    "Your version uses a nested client structure that is not supported. "
+                    "Please upgrade: pip install --upgrade langchain-nvidia-ai-endpoints>=0.3.0"
+                )
+
+            sig = inspect.signature(self._client.get_req)
+            if "extra_headers" not in sig.parameters:
+                raise RuntimeError(
+                    "custom_headers requires langchain-nvidia-ai-endpoints >= 0.3.0. "
+                    "Your version does not support the extra_headers parameter. "
+                    "Please upgrade: pip install --upgrade langchain-nvidia-ai-endpoints>=0.3.0"
+                )
+
+            self._wrap_client_methods()
+
+    def _wrap_client_methods(self):
+        original_get_req = self._client.get_req
+        original_get_req_stream = self._client.get_req_stream
+
+        def wrapped_get_req(payload: dict = {}, extra_headers: dict = {}):
+            merged_headers = {**extra_headers, **self.custom_headers}
+            return original_get_req(payload=payload, extra_headers=merged_headers)
+
+        def wrapped_get_req_stream(payload: dict, extra_headers: dict = {}):
+            merged_headers = {**extra_headers, **self.custom_headers}
+            return original_get_req_stream(
+                payload=payload, extra_headers=merged_headers
+            )
+
+        object.__setattr__(self._client, "get_req", wrapped_get_req)
+        object.__setattr__(self._client, "get_req_stream", wrapped_get_req_stream)
 
     @stream_decorator
     def _generate(
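
One detail worth noting in the wrappers above: because self.custom_headers is unpacked after extra_headers, a configured header wins over any per-request header with the same name. A standalone sketch of that merge order (header names are illustrative):

    # Standalone illustration of the header-merge order used in the wrappers.
    # extra_headers stands in for headers the client passes per request;
    # custom_headers is the user-configured field added by this commit.
    extra_headers = {"X-Request-Id": "client-generated", "Accept": "application/json"}
    custom_headers = {"X-Request-Id": "user-override", "X-Billing-Project": "team-alpha"}

    merged = {**extra_headers, **custom_headers}

    # The later unpacking wins on key conflicts, so the configured value is kept
    # while unrelated per-request headers pass through unchanged.
    assert merged["X-Request-Id"] == "user-override"
    assert merged["Accept"] == "application/json"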
