
Commit d2d41f4

tgasser-nv and Pouyanpi authored
chore(types): Type-clean server/ (20 errors) (#1397)
* Initial checkin

* Add nemoguardrails/server to pyright type-checking

* chore(types): Type-clean embeddings/ (25 errors) (#1383)

* test: restore test that was skipped due to Colang 2.0 serialization issue (#1449)

* fix(llm): add fallback extraction for reasoning traces from <think> tags (#1474)

  Adds a compatibility layer for LLM providers that don't properly populate reasoning_content in additional_kwargs. When reasoning_content is missing, the system now falls back to extracting reasoning traces from <think>...</think> tags in the response content and removes the tags from the final output.

  This fixes compatibility with certain NVIDIA models (e.g., nvidia/llama-3.3-nemotron-super-49b-v1.5) in langchain-nvidia-ai-endpoints that include reasoning traces in <think> tags but fail to populate the reasoning_content field. All reasoning models using ChatNVIDIA should expose reasoning content consistently through the same interface.

* Clean up the config_id logic based on Traian and Greptile feedback

---------

Co-authored-by: Pouyan <[email protected]>
1 parent 735ddac commit d2d41f4
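
For context on the <think>-tag fallback described in #1474 above, the idea is: prefer the provider-supplied reasoning_content, and only when it is missing, pull the trace out of <think>...</think> tags and strip them from the visible reply. The helper below is an illustrative sketch of that behavior, not the code merged in this commit; the function name and regex are assumptions.

import re
from typing import Optional, Tuple

# Illustrative only: matches a single <think>...</think> block, including newlines.
THINK_PATTERN = re.compile(r"<think>(.*?)</think>", re.DOTALL)


def split_reasoning(content: str, reasoning_content: Optional[str]) -> Tuple[Optional[str], str]:
    """Return (reasoning_trace, visible_content), falling back to <think> tags."""
    if reasoning_content:
        # Provider populated the reasoning field correctly; nothing to strip.
        return reasoning_content, content
    match = THINK_PATTERN.search(content)
    if match:
        trace = match.group(1).strip()
        cleaned = THINK_PATTERN.sub("", content).strip()
        return trace, cleaned
    return None, content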

File tree

3 files changed: +85 -43 lines

  nemoguardrails/server/api.py
  nemoguardrails/server/datastore/redis_store.py
  pyproject.toml

nemoguardrails/server/api.py

Lines changed: 75 additions & 42 deletions
@@ -22,7 +22,7 @@
 import time
 import warnings
 from contextlib import asynccontextmanager
-from typing import Any, List, Optional
+from typing import Any, Callable, List, Optional

 from fastapi import FastAPI, Request
 from fastapi.middleware.cors import CORSMiddleware
@@ -42,14 +42,32 @@
 logging.basicConfig(level=logging.INFO)
 log = logging.getLogger(__name__)

+
+class GuardrailsApp(FastAPI):
+    """Custom FastAPI subclass with additional attributes for Guardrails server."""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # Initialize custom attributes
+        self.default_config_id: Optional[str] = None
+        self.rails_config_path: str = ""
+        self.disable_chat_ui: bool = False
+        self.auto_reload: bool = False
+        self.stop_signal: bool = False
+        self.single_config_mode: bool = False
+        self.single_config_id: Optional[str] = None
+        self.loop: Optional[asyncio.AbstractEventLoop] = None
+        self.task: Optional[asyncio.Future] = None
+
+
 # The list of registered loggers. Can be used to send logs to various
 # backends and storage engines.
-registered_loggers = []
+registered_loggers: List[Callable] = []

 api_description = """Guardrails Sever API."""

 # The headers for each request
-api_request_headers = contextvars.ContextVar("headers")
+api_request_headers: contextvars.ContextVar = contextvars.ContextVar("headers")

 # The datastore that the Server should use.
 # This is currently used only for storing threads.
@@ -59,7 +77,7 @@


 @asynccontextmanager
-async def lifespan(app: FastAPI):
+async def lifespan(app: GuardrailsApp):
     # Startup logic here
     """Register any additional challenges, if available at startup."""
     challenges_files = os.path.join(app.rails_config_path, "challenges.json")
@@ -82,8 +100,11 @@ async def lifespan(app: FastAPI):
     if os.path.exists(filepath):
         filename = os.path.basename(filepath)
         spec = importlib.util.spec_from_file_location(filename, filepath)
-        config_module = importlib.util.module_from_spec(spec)
-        spec.loader.exec_module(config_module)
+        if spec is not None and spec.loader is not None:
+            config_module = importlib.util.module_from_spec(spec)
+            spec.loader.exec_module(config_module)
+        else:
+            config_module = None

         # If there is an `init` function, we call it with the reference to the app.
         if config_module is not None and hasattr(config_module, "init"):
@@ -110,21 +131,22 @@ async def root_handler():

     if app.auto_reload:
         app.loop = asyncio.get_running_loop()
+        # Store the future directly as task
         app.task = app.loop.run_in_executor(None, start_auto_reload_monitoring)

     yield

     # Shutdown logic here
     if app.auto_reload:
         app.stop_signal = True
-        if hasattr(app, "task"):
+        if hasattr(app, "task") and app.task is not None:
             app.task.cancel()
         log.info("Shutting down file observer")
     else:
         pass


-app = FastAPI(
+app = GuardrailsApp(
     title="Guardrails Server API",
     description=api_description,
     version="0.1.0",
@@ -186,7 +208,7 @@ class RequestBody(BaseModel):
         max_length=255,
         description="The id of an existing thread to which the messages should be added.",
     )
-    messages: List[dict] = Field(
+    messages: Optional[List[dict]] = Field(
         default=None, description="The list of messages in the current conversation."
     )
     context: Optional[dict] = Field(
@@ -232,7 +254,7 @@ def ensure_config_ids(cls, v, values):


 class ResponseBody(BaseModel):
-    messages: List[dict] = Field(
+    messages: Optional[List[dict]] = Field(
         default=None, description="The new messages in the conversation"
     )
     llm_output: Optional[dict] = Field(
@@ -282,8 +304,8 @@ async def get_rails_configs():


 # One instance of LLMRails per config id
-llm_rails_instances = {}
-llm_rails_events_history_cache = {}
+llm_rails_instances: dict[str, LLMRails] = {}
+llm_rails_events_history_cache: dict[str, dict] = {}


 def _generate_cache_key(config_ids: List[str]) -> str:
@@ -310,7 +332,7 @@ def _get_rails(config_ids: List[str]) -> LLMRails:
         # get the same thing.
         config_ids = [""]

-    full_llm_rails_config = None
+    full_llm_rails_config: Optional[RailsConfig] = None

     for config_id in config_ids:
         base_path = os.path.abspath(app.rails_config_path)
@@ -330,6 +352,9 @@ def _get_rails(config_ids: List[str]) -> LLMRails:
         else:
             full_llm_rails_config += rails_config

+    if full_llm_rails_config is None:
+        raise ValueError("No valid rails configuration found.")
+
     llm_rails = LLMRails(config=full_llm_rails_config, verbose=True)
     llm_rails_instances[configs_cache_key] = llm_rails

@@ -360,30 +385,33 @@ async def chat_completion(body: RequestBody, request: Request):
     # Save the request headers in a context variable.
     api_request_headers.set(request.headers)

+    # Use Request config_ids if set, otherwise use the FastAPI default config.
+    # If neither is available we can't generate any completions as we have no config_id
     config_ids = body.config_ids
-    if not config_ids and app.default_config_id:
-        config_ids = [app.default_config_id]
-    elif not config_ids and not app.default_config_id:
-        raise GuardrailsConfigurationError(
-            "No 'config_id' provided and no default configuration is set for the server. "
-            "You must set a 'config_id' in your request or set use --default-config-id when . "
-        )
+    if not config_ids:
+        if app.default_config_id:
+            config_ids = [app.default_config_id]
+        else:
+            raise GuardrailsConfigurationError(
+                "No request config_ids provided and server has no default configuration"
+            )
+
     try:
         llm_rails = _get_rails(config_ids)
     except ValueError as ex:
         log.exception(ex)
-        return {
-            "messages": [
+        return ResponseBody(
+            messages=[
                 {
                     "role": "assistant",
                     "content": f"Could not load the {config_ids} guardrails configuration. "
                     f"An internal error has occurred.",
                 }
             ]
-        }
+        )

     try:
-        messages = body.messages
+        messages = body.messages or []
         if body.context:
             messages.insert(0, {"role": "context", "content": body.context})

@@ -396,14 +424,14 @@ async def chat_completion(body: RequestBody, request: Request):

             # We make sure the `thread_id` meets the minimum complexity requirement.
             if len(body.thread_id) < 16:
-                return {
-                    "messages": [
+                return ResponseBody(
+                    messages=[
                         {
                             "role": "assistant",
                             "content": "The `thread_id` must have a minimum length of 16 characters.",
                         }
                     ]
-                }
+                )

             # Fetch the existing thread messages. For easier management, we prepend
             # the string `thread-` to all thread keys.
@@ -440,32 +468,37 @@ async def chat_completion(body: RequestBody, request: Request):
         )

         if isinstance(res, GenerationResponse):
-            bot_message = res.response[0]
+            bot_message_content = res.response[0]
+            # Ensure bot_message is always a dict
+            if isinstance(bot_message_content, str):
+                bot_message = {"role": "assistant", "content": bot_message_content}
+            else:
+                bot_message = bot_message_content
         else:
             assert isinstance(res, dict)
             bot_message = res

         # If we're using threads, we also need to update the data before returning
         # the message.
-        if body.thread_id:
+        if body.thread_id and datastore is not None and datastore_key is not None:
             await datastore.set(datastore_key, json.dumps(messages + [bot_message]))

-        result = {"messages": [bot_message]}
+        result = ResponseBody(messages=[bot_message])

         # If we have additional GenerationResponse fields, we return as well
         if isinstance(res, GenerationResponse):
-            result["llm_output"] = res.llm_output
-            result["output_data"] = res.output_data
-            result["log"] = res.log
-            result["state"] = res.state
+            result.llm_output = res.llm_output
+            result.output_data = res.output_data
+            result.log = res.log
+            result.state = res.state

         return result

     except Exception as ex:
         log.exception(ex)
-        return {
-            "messages": [{"role": "assistant", "content": "Internal server error."}]
-        }
+        return ResponseBody(
+            messages=[{"role": "assistant", "content": "Internal server error."}]
+        )


 # By default, there are no challenges
@@ -498,7 +531,7 @@ def register_datastore(datastore_instance: DataStore):
     datastore = datastore_instance


-def register_logger(logger: callable):
+def register_logger(logger: Callable):
     """Register an additional logger"""
     registered_loggers.append(logger)

@@ -510,8 +543,7 @@ def start_auto_reload_monitoring():
     from watchdog.observers import Observer

     class Handler(FileSystemEventHandler):
-        @staticmethod
-        def on_any_event(event):
+        def on_any_event(self, event):
             if event.is_directory:
                 return None

@@ -521,7 +553,8 @@ def on_any_event(event):
             )

             # Compute the relative path
-            rel_path = os.path.relpath(event.src_path, app.rails_config_path)
+            src_path_str = str(event.src_path)
+            rel_path = os.path.relpath(src_path_str, app.rails_config_path)

             # The config_id is the first component
             parts = rel_path.split(os.path.sep)
@@ -530,7 +563,7 @@ def on_any_event(event):
             if (
                 not parts[-1].startswith(".")
                 and ".ipynb_checkpoints" not in parts
-                and os.path.isfile(event.src_path)
+                and os.path.isfile(src_path_str)
             ):
                 # We just remove the config from the cache so that a new one is used next time
                 if config_id in llm_rails_instances:
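
Beyond the mechanical fixes, the main pattern in the api.py changes above is the GuardrailsApp subclass: server-level state is declared as annotated attributes in __init__ instead of being attached to a plain FastAPI instance at runtime, which is what lets pyright see attributes such as app.rails_config_path. A minimal sketch of the same pattern (class and attribute names here are illustrative, not from the commit):

from typing import Optional

from fastapi import FastAPI


class TypedApp(FastAPI):
    """FastAPI subclass whose extra attributes are declared up front for type checkers."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.config_path: str = ""  # pyright now knows this is a str
        self.default_config_id: Optional[str] = None


app = TypedApp(title="Example")
app.config_path = "./config"  # OK; assigning a non-str here would be flagged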

nemoguardrails/server/datastore/redis_store.py

Lines changed: 9 additions & 1 deletion
@@ -16,7 +16,10 @@
 import asyncio
 from typing import Optional

-import aioredis
+try:
+    import aioredis  # type: ignore[import]
+except ImportError:
+    aioredis = None  # type: ignore[assignment]

 from nemoguardrails.server.datastore.datastore import DataStore

@@ -35,6 +38,11 @@ def __init__(
             username: [Optional] The username to use for authentication.
             password: [Optional] The password to use for authentication
         """
+        if aioredis is None:
+            raise ImportError(
+                "aioredis is required for RedisStore. Install it with: pip install aioredis"
+            )
+
         self.url = url
         self.username = username
         self.password = password
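
The redis_store.py change uses the usual optional-dependency guard: the module still imports when aioredis is not installed, and the missing package only surfaces, with an actionable message, when a RedisStore is actually constructed. A generic sketch of the same pattern, with a hypothetical package name:

# Optional-dependency guard (illustrative; "some_backend" is a placeholder package).
try:
    import some_backend  # type: ignore[import]
except ImportError:
    some_backend = None  # type: ignore[assignment]


class BackendStore:
    def __init__(self, url: str):
        if some_backend is None:
            # Fail at construction time, not at module import time.
            raise ImportError("some_backend is required: pip install some-backend")
        self.url = url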

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -162,6 +162,7 @@ include = [
     "nemoguardrails/kb/**",
     "nemoguardrails/logging/**",
     "nemoguardrails/tracing/**",
+    "nemoguardrails/server/**",
     "tests/test_callbacks.py",
 ]

