Skip to content

Commit 85c4a12

Browse files
committed
Initial checkin
1 parent 71d00f0 commit 85c4a12

File tree

2 files changed

+79
-36
lines changed

2 files changed

+79
-36
lines changed

nemoguardrails/server/api.py

Lines changed: 70 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import time
2323
import warnings
2424
from contextlib import asynccontextmanager
25-
from typing import Any, List, Optional
25+
from typing import Any, Callable, List, Optional
2626

2727
from fastapi import FastAPI, Request
2828
from fastapi.middleware.cors import CORSMiddleware
@@ -42,14 +42,32 @@
4242
logging.basicConfig(level=logging.INFO)
4343
log = logging.getLogger(__name__)
4444

45+
46+
class GuardrailsApp(FastAPI):
47+
"""Custom FastAPI subclass with additional attributes for Guardrails server."""
48+
49+
def __init__(self, *args, **kwargs):
50+
super().__init__(*args, **kwargs)
51+
# Initialize custom attributes
52+
self.default_config_id: Optional[str] = None
53+
self.rails_config_path: str = ""
54+
self.disable_chat_ui: bool = False
55+
self.auto_reload: bool = False
56+
self.stop_signal: bool = False
57+
self.single_config_mode: bool = False
58+
self.single_config_id: Optional[str] = None
59+
self.loop: Optional[asyncio.AbstractEventLoop] = None
60+
self.task: Optional[asyncio.Future] = None
61+
62+
4563
# The list of registered loggers. Can be used to send logs to various
4664
# backends and storage engines.
47-
registered_loggers = []
65+
registered_loggers: List[Callable] = []
4866

4967
api_description = """Guardrails Sever API."""
5068

5169
# The headers for each request
52-
api_request_headers = contextvars.ContextVar("headers")
70+
api_request_headers: contextvars.ContextVar = contextvars.ContextVar("headers")
5371

5472
# The datastore that the Server should use.
5573
# This is currently used only for storing threads.
@@ -59,7 +77,7 @@
5977

6078

6179
@asynccontextmanager
62-
async def lifespan(app: FastAPI):
80+
async def lifespan(app: GuardrailsApp):
6381
# Startup logic here
6482
"""Register any additional challenges, if available at startup."""
6583
challenges_files = os.path.join(app.rails_config_path, "challenges.json")
@@ -82,8 +100,11 @@ async def lifespan(app: FastAPI):
82100
if os.path.exists(filepath):
83101
filename = os.path.basename(filepath)
84102
spec = importlib.util.spec_from_file_location(filename, filepath)
85-
config_module = importlib.util.module_from_spec(spec)
86-
spec.loader.exec_module(config_module)
103+
if spec is not None and spec.loader is not None:
104+
config_module = importlib.util.module_from_spec(spec)
105+
spec.loader.exec_module(config_module)
106+
else:
107+
config_module = None
87108

88109
# If there is an `init` function, we call it with the reference to the app.
89110
if config_module is not None and hasattr(config_module, "init"):
@@ -110,21 +131,22 @@ async def root_handler():
110131

111132
if app.auto_reload:
112133
app.loop = asyncio.get_running_loop()
134+
# Store the future directly as task
113135
app.task = app.loop.run_in_executor(None, start_auto_reload_monitoring)
114136

115137
yield
116138

117139
# Shutdown logic here
118140
if app.auto_reload:
119141
app.stop_signal = True
120-
if hasattr(app, "task"):
142+
if hasattr(app, "task") and app.task is not None:
121143
app.task.cancel()
122144
log.info("Shutting down file observer")
123145
else:
124146
pass
125147

126148

127-
app = FastAPI(
149+
app = GuardrailsApp(
128150
title="Guardrails Server API",
129151
description=api_description,
130152
version="0.1.0",
@@ -186,7 +208,7 @@ class RequestBody(BaseModel):
186208
max_length=255,
187209
description="The id of an existing thread to which the messages should be added.",
188210
)
189-
messages: List[dict] = Field(
211+
messages: Optional[List[dict]] = Field(
190212
default=None, description="The list of messages in the current conversation."
191213
)
192214
context: Optional[dict] = Field(
@@ -232,7 +254,7 @@ def ensure_config_ids(cls, v, values):
232254

233255

234256
class ResponseBody(BaseModel):
235-
messages: List[dict] = Field(
257+
messages: Optional[List[dict]] = Field(
236258
default=None, description="The new messages in the conversation"
237259
)
238260
llm_output: Optional[dict] = Field(
@@ -282,8 +304,8 @@ async def get_rails_configs():
282304

283305

284306
# One instance of LLMRails per config id
285-
llm_rails_instances = {}
286-
llm_rails_events_history_cache = {}
307+
llm_rails_instances: dict[str, LLMRails] = {}
308+
llm_rails_events_history_cache: dict[str, dict] = {}
287309

288310

289311
def _generate_cache_key(config_ids: List[str]) -> str:
@@ -310,7 +332,7 @@ def _get_rails(config_ids: List[str]) -> LLMRails:
310332
# get the same thing.
311333
config_ids = [""]
312334

313-
full_llm_rails_config = None
335+
full_llm_rails_config: Optional[RailsConfig] = None
314336

315337
for config_id in config_ids:
316338
base_path = os.path.abspath(app.rails_config_path)
@@ -330,6 +352,9 @@ def _get_rails(config_ids: List[str]) -> LLMRails:
330352
else:
331353
full_llm_rails_config += rails_config
332354

355+
if full_llm_rails_config is None:
356+
raise ValueError("No valid rails configuration found.")
357+
333358
llm_rails = LLMRails(config=full_llm_rails_config, verbose=True)
334359
llm_rails_instances[configs_cache_key] = llm_rails
335360

@@ -368,22 +393,27 @@ async def chat_completion(body: RequestBody, request: Request):
368393
"No 'config_id' provided and no default configuration is set for the server. "
369394
"You must set a 'config_id' in your request or set use --default-config-id when . "
370395
)
396+
397+
# Ensure config_ids is not None before passing to _get_rails
398+
if config_ids is None:
399+
raise GuardrailsConfigurationError("No valid configuration IDs available.")
400+
371401
try:
372402
llm_rails = _get_rails(config_ids)
373403
except ValueError as ex:
374404
log.exception(ex)
375-
return {
376-
"messages": [
405+
return ResponseBody(
406+
messages=[
377407
{
378408
"role": "assistant",
379409
"content": f"Could not load the {config_ids} guardrails configuration. "
380410
f"An internal error has occurred.",
381411
}
382412
]
383-
}
413+
)
384414

385415
try:
386-
messages = body.messages
416+
messages = body.messages or []
387417
if body.context:
388418
messages.insert(0, {"role": "context", "content": body.context})
389419

@@ -396,14 +426,14 @@ async def chat_completion(body: RequestBody, request: Request):
396426

397427
# We make sure the `thread_id` meets the minimum complexity requirement.
398428
if len(body.thread_id) < 16:
399-
return {
400-
"messages": [
429+
return ResponseBody(
430+
messages=[
401431
{
402432
"role": "assistant",
403433
"content": "The `thread_id` must have a minimum length of 16 characters.",
404434
}
405435
]
406-
}
436+
)
407437

408438
# Fetch the existing thread messages. For easier management, we prepend
409439
# the string `thread-` to all thread keys.
@@ -440,32 +470,37 @@ async def chat_completion(body: RequestBody, request: Request):
440470
)
441471

442472
if isinstance(res, GenerationResponse):
443-
bot_message = res.response[0]
473+
bot_message_content = res.response[0]
474+
# Ensure bot_message is always a dict
475+
if isinstance(bot_message_content, str):
476+
bot_message = {"role": "assistant", "content": bot_message_content}
477+
else:
478+
bot_message = bot_message_content
444479
else:
445480
assert isinstance(res, dict)
446481
bot_message = res
447482

448483
# If we're using threads, we also need to update the data before returning
449484
# the message.
450-
if body.thread_id:
485+
if body.thread_id and datastore is not None and datastore_key is not None:
451486
await datastore.set(datastore_key, json.dumps(messages + [bot_message]))
452487

453-
result = {"messages": [bot_message]}
488+
result = ResponseBody(messages=[bot_message])
454489

455490
# If we have additional GenerationResponse fields, we return as well
456491
if isinstance(res, GenerationResponse):
457-
result["llm_output"] = res.llm_output
458-
result["output_data"] = res.output_data
459-
result["log"] = res.log
460-
result["state"] = res.state
492+
result.llm_output = res.llm_output
493+
result.output_data = res.output_data
494+
result.log = res.log
495+
result.state = res.state
461496

462497
return result
463498

464499
except Exception as ex:
465500
log.exception(ex)
466-
return {
467-
"messages": [{"role": "assistant", "content": "Internal server error."}]
468-
}
501+
return ResponseBody(
502+
messages=[{"role": "assistant", "content": "Internal server error."}]
503+
)
469504

470505

471506
# By default, there are no challenges
@@ -498,7 +533,7 @@ def register_datastore(datastore_instance: DataStore):
498533
datastore = datastore_instance
499534

500535

501-
def register_logger(logger: callable):
536+
def register_logger(logger: Callable):
502537
"""Register an additional logger"""
503538
registered_loggers.append(logger)
504539

@@ -510,8 +545,7 @@ def start_auto_reload_monitoring():
510545
from watchdog.observers import Observer
511546

512547
class Handler(FileSystemEventHandler):
513-
@staticmethod
514-
def on_any_event(event):
548+
def on_any_event(self, event):
515549
if event.is_directory:
516550
return None
517551

@@ -521,7 +555,8 @@ def on_any_event(event):
521555
)
522556

523557
# Compute the relative path
524-
rel_path = os.path.relpath(event.src_path, app.rails_config_path)
558+
src_path_str = str(event.src_path)
559+
rel_path = os.path.relpath(src_path_str, app.rails_config_path)
525560

526561
# The config_id is the first component
527562
parts = rel_path.split(os.path.sep)
@@ -530,7 +565,7 @@ def on_any_event(event):
530565
if (
531566
not parts[-1].startswith(".")
532567
and ".ipynb_checkpoints" not in parts
533-
and os.path.isfile(event.src_path)
568+
and os.path.isfile(src_path_str)
534569
):
535570
# We just remove the config from the cache so that a new one is used next time
536571
if config_id in llm_rails_instances:

nemoguardrails/server/datastore/redis_store.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,10 @@
1616
import asyncio
1717
from typing import Optional
1818

19-
import aioredis
19+
try:
20+
import aioredis # type: ignore[import]
21+
except ImportError:
22+
aioredis = None # type: ignore[assignment]
2023

2124
from nemoguardrails.server.datastore.datastore import DataStore
2225

@@ -35,6 +38,11 @@ def __init__(
3538
username: [Optional] The username to use for authentication.
3639
password: [Optional] The password to use for authentication
3740
"""
41+
if aioredis is None:
42+
raise ImportError(
43+
"aioredis is required for RedisStore. Install it with: pip install aioredis"
44+
)
45+
3846
self.url = url
3947
self.username = username
4048
self.password = password

0 commit comments

Comments
 (0)