2222import time
2323import warnings
2424from contextlib import asynccontextmanager
25- from typing import Any , List , Optional
25+ from typing import Any , Callable , List , Optional
2626
2727from fastapi import FastAPI , Request
2828from fastapi .middleware .cors import CORSMiddleware
4242logging .basicConfig (level = logging .INFO )
4343log = logging .getLogger (__name__ )
4444
45+
class GuardrailsApp(FastAPI):
    """FastAPI application that also carries the Guardrails server settings.

    Every custom attribute is created with a default value in ``__init__`` so
    that code holding a reference to the app can read it without first
    checking whether it exists.
    """

    def __init__(self, *args, **kwargs):
        """Build the underlying FastAPI app, then set the custom defaults.

        All positional and keyword arguments are forwarded unchanged to
        ``FastAPI.__init__``.
        """
        super().__init__(*args, **kwargs)

        # Where the rails configurations live and which one(s) to use.
        self.rails_config_path: str = ""
        self.default_config_id: Optional[str] = None
        self.single_config_mode: bool = False
        self.single_config_id: Optional[str] = None

        # Server feature switches.
        self.disable_chat_ui: bool = False
        self.auto_reload: bool = False

        # Bookkeeping for the auto-reload monitor: the flag raised at
        # shutdown, plus the event loop and the future kept while it runs.
        self.stop_signal: bool = False
        self.loop: Optional[asyncio.AbstractEventLoop] = None
        self.task: Optional[asyncio.Future] = None
61+
62+
4563# The list of registered loggers. Can be used to send logs to various
4664# backends and storage engines.
47- registered_loggers = []
65+ registered_loggers : List [ Callable ] = []
4866
4967api_description = """Guardrails Sever API."""
5068
5169# The headers for each request
52- api_request_headers = contextvars .ContextVar ("headers" )
70+ api_request_headers : contextvars . ContextVar = contextvars .ContextVar ("headers" )
5371
5472# The datastore that the Server should use.
5573# This is currently used only for storing threads.
5977
6078
6179@asynccontextmanager
62- async def lifespan (app : FastAPI ):
80+ async def lifespan (app : GuardrailsApp ):
6381 # Startup logic here
6482 """Register any additional challenges, if available at startup."""
6583 challenges_files = os .path .join (app .rails_config_path , "challenges.json" )
@@ -82,8 +100,11 @@ async def lifespan(app: FastAPI):
82100 if os .path .exists (filepath ):
83101 filename = os .path .basename (filepath )
84102 spec = importlib .util .spec_from_file_location (filename , filepath )
85- config_module = importlib .util .module_from_spec (spec )
86- spec .loader .exec_module (config_module )
103+ if spec is not None and spec .loader is not None :
104+ config_module = importlib .util .module_from_spec (spec )
105+ spec .loader .exec_module (config_module )
106+ else :
107+ config_module = None
87108
88109 # If there is an `init` function, we call it with the reference to the app.
89110 if config_module is not None and hasattr (config_module , "init" ):
@@ -110,21 +131,22 @@ async def root_handler():
110131
111132 if app .auto_reload :
112133 app .loop = asyncio .get_running_loop ()
134+ # Store the future directly as task
113135 app .task = app .loop .run_in_executor (None , start_auto_reload_monitoring )
114136
115137 yield
116138
117139 # Shutdown logic here
118140 if app .auto_reload :
119141 app .stop_signal = True
120- if hasattr (app , "task" ):
142+ if hasattr (app , "task" ) and app . task is not None :
121143 app .task .cancel ()
122144 log .info ("Shutting down file observer" )
123145 else :
124146 pass
125147
126148
127- app = FastAPI (
149+ app = GuardrailsApp (
128150 title = "Guardrails Server API" ,
129151 description = api_description ,
130152 version = "0.1.0" ,
@@ -186,7 +208,7 @@ class RequestBody(BaseModel):
186208 max_length = 255 ,
187209 description = "The id of an existing thread to which the messages should be added." ,
188210 )
189- messages : List [dict ] = Field (
211+ messages : Optional [ List [dict ] ] = Field (
190212 default = None , description = "The list of messages in the current conversation."
191213 )
192214 context : Optional [dict ] = Field (
@@ -232,7 +254,7 @@ def ensure_config_ids(cls, v, values):
232254
233255
234256class ResponseBody (BaseModel ):
235- messages : List [dict ] = Field (
257+ messages : Optional [ List [dict ] ] = Field (
236258 default = None , description = "The new messages in the conversation"
237259 )
238260 llm_output : Optional [dict ] = Field (
@@ -282,8 +304,8 @@ async def get_rails_configs():
282304
283305
284306# One instance of LLMRails per config id
285- llm_rails_instances = {}
286- llm_rails_events_history_cache = {}
307+ llm_rails_instances : dict [ str , LLMRails ] = {}
308+ llm_rails_events_history_cache : dict [ str , dict ] = {}
287309
288310
289311def _generate_cache_key (config_ids : List [str ]) -> str :
@@ -310,7 +332,7 @@ def _get_rails(config_ids: List[str]) -> LLMRails:
310332 # get the same thing.
311333 config_ids = ["" ]
312334
313- full_llm_rails_config = None
335+ full_llm_rails_config : Optional [ RailsConfig ] = None
314336
315337 for config_id in config_ids :
316338 base_path = os .path .abspath (app .rails_config_path )
@@ -330,6 +352,9 @@ def _get_rails(config_ids: List[str]) -> LLMRails:
330352 else :
331353 full_llm_rails_config += rails_config
332354
355+ if full_llm_rails_config is None :
356+ raise ValueError ("No valid rails configuration found." )
357+
333358 llm_rails = LLMRails (config = full_llm_rails_config , verbose = True )
334359 llm_rails_instances [configs_cache_key ] = llm_rails
335360
@@ -368,22 +393,27 @@ async def chat_completion(body: RequestBody, request: Request):
368393 "No 'config_id' provided and no default configuration is set for the server. "
369394 "You must set a 'config_id' in your request or set use --default-config-id when . "
370395 )
396+
397+ # Ensure config_ids is not None before passing to _get_rails
398+ if config_ids is None :
399+ raise GuardrailsConfigurationError ("No valid configuration IDs available." )
400+
371401 try :
372402 llm_rails = _get_rails (config_ids )
373403 except ValueError as ex :
374404 log .exception (ex )
375- return {
376- " messages" : [
405+ return ResponseBody (
406+ messages = [
377407 {
378408 "role" : "assistant" ,
379409 "content" : f"Could not load the { config_ids } guardrails configuration. "
380410 f"An internal error has occurred." ,
381411 }
382412 ]
383- }
413+ )
384414
385415 try :
386- messages = body .messages
416+ messages = body .messages or []
387417 if body .context :
388418 messages .insert (0 , {"role" : "context" , "content" : body .context })
389419
@@ -396,14 +426,14 @@ async def chat_completion(body: RequestBody, request: Request):
396426
397427 # We make sure the `thread_id` meets the minimum complexity requirement.
398428 if len (body .thread_id ) < 16 :
399- return {
400- " messages" : [
429+ return ResponseBody (
430+ messages = [
401431 {
402432 "role" : "assistant" ,
403433 "content" : "The `thread_id` must have a minimum length of 16 characters." ,
404434 }
405435 ]
406- }
436+ )
407437
408438 # Fetch the existing thread messages. For easier management, we prepend
409439 # the string `thread-` to all thread keys.
@@ -440,32 +470,37 @@ async def chat_completion(body: RequestBody, request: Request):
440470 )
441471
442472 if isinstance (res , GenerationResponse ):
443- bot_message = res .response [0 ]
473+ bot_message_content = res .response [0 ]
474+ # Ensure bot_message is always a dict
475+ if isinstance (bot_message_content , str ):
476+ bot_message = {"role" : "assistant" , "content" : bot_message_content }
477+ else :
478+ bot_message = bot_message_content
444479 else :
445480 assert isinstance (res , dict )
446481 bot_message = res
447482
448483 # If we're using threads, we also need to update the data before returning
449484 # the message.
450- if body .thread_id :
485+ if body .thread_id and datastore is not None and datastore_key is not None :
451486 await datastore .set (datastore_key , json .dumps (messages + [bot_message ]))
452487
453- result = { " messages" : [bot_message ]}
488+ result = ResponseBody ( messages = [bot_message ])
454489
455490 # If we have additional GenerationResponse fields, we return as well
456491 if isinstance (res , GenerationResponse ):
457- result [ " llm_output" ] = res .llm_output
458- result [ " output_data" ] = res .output_data
459- result [ " log" ] = res .log
460- result [ " state" ] = res .state
492+ result . llm_output = res .llm_output
493+ result . output_data = res .output_data
494+ result . log = res .log
495+ result . state = res .state
461496
462497 return result
463498
464499 except Exception as ex :
465500 log .exception (ex )
466- return {
467- " messages" : [{"role" : "assistant" , "content" : "Internal server error." }]
468- }
501+ return ResponseBody (
502+ messages = [{"role" : "assistant" , "content" : "Internal server error." }]
503+ )
469504
470505
471506# By default, there are no challenges
@@ -498,7 +533,7 @@ def register_datastore(datastore_instance: DataStore):
498533 datastore = datastore_instance
499534
500535
def register_logger(logger: Callable) -> None:
    """Register an additional logger.

    The callable is appended to the module-level ``registered_loggers`` list.
    No de-duplication is performed, so registering the same callable twice
    stores it twice.

    Args:
        logger: The callable to register. The arguments it is invoked with
            are not visible in this function; check the call sites that
            iterate over ``registered_loggers``.
    """
    # Annotated with `typing.Callable`: the builtin `callable` is a function
    # that tests callability, not a type, so type checkers reject it here.
    registered_loggers.append(logger)
504539
@@ -510,8 +545,7 @@ def start_auto_reload_monitoring():
510545 from watchdog .observers import Observer
511546
512547 class Handler (FileSystemEventHandler ):
513- @staticmethod
514- def on_any_event (event ):
548+ def on_any_event (self , event ):
515549 if event .is_directory :
516550 return None
517551
@@ -521,7 +555,8 @@ def on_any_event(event):
521555 )
522556
523557 # Compute the relative path
524- rel_path = os .path .relpath (event .src_path , app .rails_config_path )
558+ src_path_str = str (event .src_path )
559+ rel_path = os .path .relpath (src_path_str , app .rails_config_path )
525560
526561 # The config_id is the first component
527562 parts = rel_path .split (os .path .sep )
@@ -530,7 +565,7 @@ def on_any_event(event):
530565 if (
531566 not parts [- 1 ].startswith ("." )
532567 and ".ipynb_checkpoints" not in parts
533- and os .path .isfile (event . src_path )
568+ and os .path .isfile (src_path_str )
534569 ):
535570 # We just remove the config from the cache so that a new one is used next time
536571 if config_id in llm_rails_instances :
0 commit comments