Skip to content

Commit

Permalink
Merge pull request #367 from NVIDIA/fix/qa-fixes-6
Browse files Browse the repository at this point in the history
Fix LangChain warnings and bug affecting Llama-2 example.
  • Loading branch information
drazvan authored Feb 28, 2024
2 parents 0be7a72 + fedd0ed commit 8bb50af
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 5 deletions.
5 changes: 3 additions & 2 deletions nemoguardrails/llm/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,9 @@ def _modify_instance_kwargs(self):
"""

if hasattr(llm_instance, "model_kwargs"):
llm_instance.model_kwargs["temperature"] = self.temperature
llm_instance.model_kwargs["streaming"] = self.streaming
if isinstance(llm_instance.model_kwargs, dict):
llm_instance.model_kwargs["temperature"] = self.temperature
llm_instance.model_kwargs["streaming"] = self.streaming

def _call(
self,
Expand Down
17 changes: 14 additions & 3 deletions nemoguardrails/llm/providers/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def _call(
)

# Streaming for NeMo Guardrails is not supported in sync calls.
if self.model_kwargs.get("streaming"):
if self.model_kwargs and self.model_kwargs.get("streaming"):
raise Exception(
"Streaming mode not supported for HuggingFacePipeline in NeMo Guardrails!"
)
Expand Down Expand Up @@ -100,7 +100,7 @@ async def _acall(
)

# Handle streaming, if the flag is set
if self.model_kwargs.get("streaming"):
if self.model_kwargs and self.model_kwargs.get("streaming"):
# Retrieve the streamer object, needs to be set in model_kwargs
streamer = self.model_kwargs.get("streamer")
if not streamer:
Expand Down Expand Up @@ -153,7 +153,18 @@ async def _acall(self, *args, **kwargs):

def discover_langchain_providers():
"""Automatically discover all LLM providers from LangChain."""
_providers.update(llms.type_to_cls_dict)
# To deal with deprecated stuff and avoid warnings, we compose the type_to_cls_dict here
if hasattr(llms, "get_type_to_cls_dict"):
type_to_cls_dict = {
k: v()
for k, v in llms.get_type_to_cls_dict().items()
# Exclude deprecated ones
if k not in ["mlflow-chat", "databricks-chat"]
}
else:
type_to_cls_dict = llms.type_to_cls_dict

_providers.update(type_to_cls_dict)

# We make sure we have OpenAI from the right package.
if "openai" in _providers:
Expand Down

0 comments on commit 8bb50af

Please sign in to comment.