72 changes: 69 additions & 3 deletions nemoguardrails/actions/llm/utils.py
@@ -13,9 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import re
from typing import Any, Dict, List, Optional, Sequence, Union

from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackManager
from langchain_core.runnables import RunnableConfig

logger = logging.getLogger(__name__)
@@ -238,15 +241,78 @@ def _convert_messages_to_langchain_format(prompt: List[dict]) -> List:


def _store_reasoning_traces(response) -> None:
    """Store reasoning traces from the response in the context variable.

    Extracts reasoning content from ``response.additional_kwargs["reasoning_content"]``
    if available. Otherwise, falls back to extracting from <think> tags in the
    response content (and removes the tags from the content).

    Args:
        response: The LLM response object.
    """
    reasoning_content = _extract_reasoning_content(response)

    if not reasoning_content:
        # Some LLM providers (e.g., certain NVIDIA models) embed reasoning in <think> tags
        # instead of populating reasoning_content in additional_kwargs, so both extraction
        # methods are needed to support different provider implementations.
        reasoning_content = _extract_and_remove_think_tags(response)

    if reasoning_content:
        reasoning_trace_var.set(reasoning_content)


def _extract_reasoning_content(response) -> Optional[str]:
    """Return the reasoning content from ``response.additional_kwargs``, if present."""
    if hasattr(response, "additional_kwargs"):
        additional_kwargs = response.additional_kwargs
        if (
            isinstance(additional_kwargs, dict)
            and "reasoning_content" in additional_kwargs
        ):
            return additional_kwargs["reasoning_content"]
    return None


def _extract_and_remove_think_tags(response) -> Optional[str]:
    """Extract reasoning from <think> tags and remove them from ``response.content``.

    This function looks for <think>...</think> tags in the response content and,
    if found, extracts the reasoning content inside the tags. It has a side effect:
    it removes the full reasoning trace and tags from ``response.content``.

    Args:
        response: The LLM response object.

    Returns:
        The extracted reasoning content, or None if no <think> tags are found.
    """
    if not hasattr(response, "content"):
        return None

    content = response.content
    has_opening_tag = "<think>" in content
    has_closing_tag = "</think>" in content

    if not has_opening_tag and not has_closing_tag:
        return None

    if has_opening_tag != has_closing_tag:
        logger.warning(
            "Malformed <think> tags detected: missing %s tag. "
            "Skipping reasoning extraction to prevent corrupted content.",
            "closing" if has_opening_tag else "opening",
        )
        return None

    match = re.search(r"<think>(.*?)</think>", content, re.DOTALL)
    if match:
        reasoning_content = match.group(1).strip()
        response.content = re.sub(
            r"<think>.*?</think>", "", content, flags=re.DOTALL
        ).strip()
        return reasoning_content
    return None
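For context, a minimal standalone sketch of the <think>-tag fallback path, using the same regular expressions as the code above (an illustration, not part of the diff):

import re

content = "<think>The user is asking about the refund policy.</think>Here is our refund policy..."

match = re.search(r"<think>(.*?)</think>", content, re.DOTALL)
reasoning = match.group(1).strip() if match else None
visible = re.sub(r"<think>.*?</think>", "", content, flags=re.DOTALL).strip()

print(reasoning)  # "The user is asking about the refund policy."
print(visible)    # "Here is our refund policy..."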


def _store_tool_calls(response) -> None:
87 changes: 45 additions & 42 deletions nemoguardrails/embeddings/basic.py
@@ -15,9 +15,9 @@

import asyncio
import logging
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union, cast

from annoy import AnnoyIndex
from annoy import AnnoyIndex # type: ignore

from nemoguardrails.embeddings.cache import cache_embeddings
from nemoguardrails.embeddings.index import EmbeddingsIndex, IndexItem
@@ -45,62 +45,51 @@ class BasicEmbeddingsIndex(EmbeddingsIndex):
max_batch_hold: The maximum time a batch is held before being processed
"""

embedding_model: str
embedding_engine: str
embedding_params: Dict[str, Any]
index: AnnoyIndex
embedding_size: int
cache_config: EmbeddingsCacheConfig
embeddings: List[List[float]]
search_threshold: float
use_batching: bool
max_batch_size: int
max_batch_hold: float

def __init__(
self,
embedding_model=None,
embedding_engine=None,
embedding_params=None,
index=None,
cache_config: Union[EmbeddingsCacheConfig, Dict[str, Any]] = None,
search_threshold: float = None,
embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2",
embedding_engine: str = "SentenceTransformers",
embedding_params: Optional[Dict[str, Any]] = None,
index: Optional[AnnoyIndex] = None,
cache_config: Optional[Union[EmbeddingsCacheConfig, Dict[str, Any]]] = None,
search_threshold: float = float("inf"),
use_batching: bool = False,
max_batch_size: int = 10,
max_batch_hold: float = 0.01,
):
"""Initialize the BasicEmbeddingsIndex.

Args:
embedding_model (str, optional): The model for computing embeddings. Defaults to None.
embedding_engine (str, optional): The engine for computing embeddings. Defaults to None.
index (AnnoyIndex, optional): The pre-existing index. Defaults to None.
cache_config (EmbeddingsCacheConfig | Dict[str, Any], optional): The cache configuration. Defaults to None.
embedding_model: The model for computing embeddings.
embedding_engine: The engine for computing embeddings.
index: The pre-existing index.
cache_config: The cache configuration.
search_threshold: The threshold for filtering search results.
use_batching: Whether to batch requests when computing the embeddings.
max_batch_size: The maximum size of a batch.
max_batch_hold: The maximum time a batch is held before being processed
"""
self._model: Optional[EmbeddingModel] = None
self._items = []
self._embeddings = []
self._items: List[IndexItem] = []
self._embeddings: List[List[float]] = []
self.embedding_model = embedding_model
self.embedding_engine = embedding_engine
self.embedding_params = embedding_params or {}
self._embedding_size = 0
self.search_threshold = search_threshold or float("inf")
self.search_threshold = search_threshold
if isinstance(cache_config, Dict):
self._cache_config = EmbeddingsCacheConfig(**cache_config)
else:
self._cache_config = cache_config or EmbeddingsCacheConfig()
self._index = index

# Data structures for batching embedding requests
self._req_queue = {}
self._req_results = {}
self._req_idx = 0
self._current_batch_finished_event = None
self._current_batch_full_event = None
self._current_batch_submitted = asyncio.Event()
self._req_queue: Dict[int, str] = {}
self._req_results: Dict[int, List[float]] = {}
self._req_idx: int = 0
self._current_batch_finished_event: Optional[asyncio.Event] = None
self._current_batch_full_event: Optional[asyncio.Event] = None
self._current_batch_submitted: asyncio.Event = asyncio.Event()

# Initialize the batching configuration
self.use_batching = use_batching
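With these defaults in place, the index can be constructed without any arguments; a hedged usage sketch based on the signature above (not part of the diff):

# Uses the new defaults: all-MiniLM-L6-v2 via SentenceTransformers, infinite search threshold (no filtering).
index = BasicEmbeddingsIndex()

# Or with batching enabled (parameter names taken from the signature above):
batched_index = BasicEmbeddingsIndex(
    embedding_model="sentence-transformers/all-MiniLM-L6-v2",
    embedding_engine="SentenceTransformers",
    use_batching=True,
    max_batch_size=10,
    max_batch_hold=0.01,
)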
@@ -112,6 +101,11 @@ def embeddings_index(self):
"""Get the current embedding index"""
return self._index

@embeddings_index.setter
def embeddings_index(self, index):
"""Setter to allow replacing the index dynamically."""
self._index = index

@property
def cache_config(self):
"""Get the cache configuration."""
@@ -127,16 +121,14 @@ def embeddings(self):
"""Get the computed embeddings."""
return self._embeddings

@embeddings_index.setter
def embeddings_index(self, index):
"""Setter to allow replacing the index dynamically."""
self._index = index

def _init_model(self):
"""Initialize the model used for computing the embeddings."""
model = self.embedding_model
engine = self.embedding_engine

self._model = init_embedding_model(
embedding_model=self.embedding_model,
embedding_engine=self.embedding_engine,
embedding_model=model,
embedding_engine=engine,
embedding_params=self.embedding_params,
)

@@ -153,7 +145,9 @@ async def _get_embeddings(self, texts: List[str]) -> List[List[float]]:
if self._model is None:
self._init_model()

embeddings = await self._model.encode_async(texts)
# self._model cannot be None here: _init_model() either sets it or raises a ValueError.
model: EmbeddingModel = cast(EmbeddingModel, self._model)
embeddings = await model.encode_async(texts)
return embeddings

async def add_item(self, item: IndexItem):
@@ -199,6 +193,12 @@ async def _run_batch(self):
"""Runs the current batch of embeddings."""

# Wait up to `max_batch_hold` time or until `max_batch_size` is reached.
if (
self._current_batch_full_event is None
or self._current_batch_finished_event is None
):
raise RuntimeError("Batch events not initialized. This should not happen.")

done, pending = await asyncio.wait(
[
asyncio.create_task(asyncio.sleep(self.max_batch_hold)),
@@ -244,7 +244,10 @@ async def _batch_get_embeddings(self, text: str) -> List[float]:
self._req_idx += 1
self._req_queue[req_id] = text

if self._current_batch_finished_event is None:
if (
self._current_batch_finished_event is None
or self._current_batch_full_event is None
):
self._current_batch_finished_event = asyncio.Event()
self._current_batch_full_event = asyncio.Event()
self._current_batch_submitted.clear()
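The batching flow above queues each request and flushes the batch when either max_batch_size requests have accumulated or max_batch_hold seconds have elapsed. A simplified sketch of the wait-for-whichever-comes-first step (an illustration of the pattern, not the library's exact code):

import asyncio

async def wait_for_batch(full_event: asyncio.Event, max_batch_hold: float) -> None:
    # Wake up when the batch is declared full or the hold time expires, whichever comes first.
    done, pending = await asyncio.wait(
        [
            asyncio.create_task(asyncio.sleep(max_batch_hold)),
            asyncio.create_task(full_event.wait()),
        ],
        return_when=asyncio.FIRST_COMPLETED,
    )
    # Cancel whichever task did not finish.
    for task in pending:
        task.cancel()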
38 changes: 26 additions & 12 deletions nemoguardrails/embeddings/cache.py
@@ -20,7 +20,12 @@
from abc import ABC, abstractmethod
from functools import singledispatchmethod
from pathlib import Path
from typing import Dict, List
from typing import Dict, List, Optional

try:
import redis # type: ignore
except ImportError:
redis = None # type: ignore

from nemoguardrails.rails.llm.config import EmbeddingsCacheConfig

@@ -30,6 +35,8 @@
class KeyGenerator(ABC):
"""Abstract class for key generators."""

name: str # Class attribute that should be defined in subclasses

@abstractmethod
def generate_key(self, text: str) -> str:
pass
@@ -76,6 +83,8 @@ def generate_key(self, text: str) -> str:
class CacheStore(ABC):
"""Abstract class for cache stores."""

name: str

@abstractmethod
def get(self, key):
"""Get a value from the cache."""
@@ -147,7 +156,7 @@ class FilesystemCacheStore(CacheStore):

name = "filesystem"

def __init__(self, cache_dir: str = None):
def __init__(self, cache_dir: Optional[str] = None):
self._cache_dir = Path(cache_dir or ".cache/embeddings")
self._cache_dir.mkdir(parents=True, exist_ok=True)

Expand Down Expand Up @@ -190,8 +199,10 @@ class RedisCacheStore(CacheStore):
name = "redis"

def __init__(self, host: str = "localhost", port: int = 6379, db: int = 0):
import redis

if redis is None:
raise ImportError(
"Could not import redis, please install it with `pip install redis`."
)
self._redis = redis.Redis(host=host, port=port, db=db)

def get(self, key):
@@ -207,9 +218,9 @@ def clear(self):
class EmbeddingsCache:
def __init__(
self,
key_generator: KeyGenerator = None,
cache_store: CacheStore = None,
store_config: dict = None,
key_generator: KeyGenerator,
cache_store: CacheStore,
store_config: Optional[dict] = None,
):
self._key_generator = key_generator
self._cache_store = cache_store
@@ -218,7 +229,10 @@ def __init__(self):
@classmethod
def from_dict(cls, d: Dict[str, str]):
key_generator = KeyGenerator.from_name(d.get("key_generator"))()
store_config = d.get("store_config")
store_config_raw = d.get("store_config")
store_config: dict = (
store_config_raw if isinstance(store_config_raw, dict) else {}
)
cache_store = CacheStore.from_name(d.get("store"))(**store_config)

return cls(key_generator=key_generator, cache_store=cache_store)
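As a usage sketch of from_dict (the store names "filesystem" and "redis" and the cache_dir default appear above; the key generator name here is assumed, not shown in the diff):

cache = EmbeddingsCache.from_dict(
    {
        "key_generator": "sha256",  # assumed name of a registered KeyGenerator
        "store": "filesystem",  # FilesystemCacheStore.name, shown above
        "store_config": {"cache_dir": ".cache/embeddings"},
    }
)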
@@ -239,7 +253,7 @@ def get_config(self):
def get(self, texts):
raise NotImplementedError

@get.register
@get.register(str)
def _(self, text: str):
key = self._key_generator.generate_key(text)
log.info(f"Fetching key {key} for text '{text[:20]}...' from cache")
@@ -248,7 +262,7 @@ def _(self, text: str):

return result

@get.register
@get.register(list)
def _(self, texts: list):
cached = {}

@@ -266,13 +280,13 @@ def _(self, texts: list):
def set(self, texts):
raise NotImplementedError

@set.register
@set.register(str)
def _(self, text: str, value: List[float]):
key = self._key_generator.generate_key(text)
log.info(f"Cache miss for text '{text}'. Storing key {key} in cache.")
self._cache_store.set(key, value)

@set.register
@set.register(list)
def _(self, texts: list, values: List[List[float]]):
for text, value in zip(texts, values):
self.set(text, value)
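The change from bare @get.register / @set.register to explicit @get.register(str) and @get.register(list) makes the dispatch type explicit instead of relying on the annotation of the first non-self argument. A minimal standalone sketch of the pattern (hypothetical class, not the library code):

from functools import singledispatchmethod

class Dispatcher:
    @singledispatchmethod
    def get(self, texts):
        raise NotImplementedError

    @get.register(str)
    def _(self, text: str):
        return f"one:{text}"

    @get.register(list)
    def _(self, texts: list):
        return [self.get(t) for t in texts]

d = Dispatcher()
print(d.get("a"))         # one:a
print(d.get(["a", "b"]))  # ['one:a', 'one:b']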
7 changes: 3 additions & 4 deletions nemoguardrails/embeddings/providers/azureopenai.py
@@ -46,17 +46,16 @@ class AzureEmbeddingModel(EmbeddingModel):

def __init__(self, embedding_model: str):
try:
from openai import AzureOpenAI
from openai import AzureOpenAI # type: ignore
except ImportError:
raise ImportError(
"Could not import openai, please install it with "
"`pip install openai`."
"Could not import openai, please install it with `pip install openai`."
)
# Set Azure OpenAI API credentials
self.client = AzureOpenAI(
api_key=os.getenv("AZURE_OPENAI_API_KEY"),
api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), # type: ignore
)

self.embedding_model = embedding_model
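For reference, the client above reads its credentials from environment variables; an assumed setup sketch with placeholder values (only the variable names come from the code above):

import os

os.environ["AZURE_OPENAI_API_KEY"] = "<your-api-key>"
os.environ["AZURE_OPENAI_API_VERSION"] = "2024-02-01"  # example version, assumed
os.environ["AZURE_OPENAI_ENDPOINT"] = "https://<your-resource>.openai.azure.com/"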