2222import  time 
2323import  warnings 
2424from  contextlib  import  asynccontextmanager 
25- from  typing  import  Any , List , Optional 
25+ from  typing  import  Any , Callable ,  List , Optional 
2626
2727from  fastapi  import  FastAPI , Request 
2828from  fastapi .middleware .cors  import  CORSMiddleware 
4242logging .basicConfig (level = logging .INFO )
4343log  =  logging .getLogger (__name__ )
4444
45+ 
46+ class  GuardrailsApp (FastAPI ):
47+     """Custom FastAPI subclass with additional attributes for Guardrails server.""" 
48+ 
49+     def  __init__ (self , * args , ** kwargs ):
50+         super ().__init__ (* args , ** kwargs )
51+         # Initialize custom attributes 
52+         self .default_config_id : Optional [str ] =  None 
53+         self .rails_config_path : str  =  "" 
54+         self .disable_chat_ui : bool  =  False 
55+         self .auto_reload : bool  =  False 
56+         self .stop_signal : bool  =  False 
57+         self .single_config_mode : bool  =  False 
58+         self .single_config_id : Optional [str ] =  None 
59+         self .loop : Optional [asyncio .AbstractEventLoop ] =  None 
60+         self .task : Optional [asyncio .Future ] =  None 
61+ 
62+ 
4563# The list of registered loggers. Can be used to send logs to various 
4664# backends and storage engines. 
47- registered_loggers  =  []
65+ registered_loggers :  List [ Callable ]  =  []
4866
4967api_description  =  """Guardrails Sever API.""" 
5068
5169# The headers for each request 
52- api_request_headers  =  contextvars .ContextVar ("headers" )
70+ api_request_headers :  contextvars . ContextVar  =  contextvars .ContextVar ("headers" )
5371
5472# The datastore that the Server should use. 
5573# This is currently used only for storing threads. 
5977
6078
6179@asynccontextmanager  
62- async  def  lifespan (app : FastAPI ):
80+ async  def  lifespan (app : GuardrailsApp ):
6381    # Startup logic here 
6482    """Register any additional challenges, if available at startup.""" 
6583    challenges_files  =  os .path .join (app .rails_config_path , "challenges.json" )
@@ -82,8 +100,11 @@ async def lifespan(app: FastAPI):
82100        if  os .path .exists (filepath ):
83101            filename  =  os .path .basename (filepath )
84102            spec  =  importlib .util .spec_from_file_location (filename , filepath )
85-             config_module  =  importlib .util .module_from_spec (spec )
86-             spec .loader .exec_module (config_module )
103+             if  spec  is  not None  and  spec .loader  is  not None :
104+                 config_module  =  importlib .util .module_from_spec (spec )
105+                 spec .loader .exec_module (config_module )
106+             else :
107+                 config_module  =  None 
87108
88109            # If there is an `init` function, we call it with the reference to the app. 
89110            if  config_module  is  not None  and  hasattr (config_module , "init" ):
@@ -110,21 +131,22 @@ async def root_handler():
110131
111132    if  app .auto_reload :
112133        app .loop  =  asyncio .get_running_loop ()
134+         # Store the future directly as task 
113135        app .task  =  app .loop .run_in_executor (None , start_auto_reload_monitoring )
114136
115137    yield 
116138
117139    # Shutdown logic here 
118140    if  app .auto_reload :
119141        app .stop_signal  =  True 
120-         if  hasattr (app , "task" ):
142+         if  hasattr (app , "task" )  and   app . task   is   not   None :
121143            app .task .cancel ()
122144        log .info ("Shutting down file observer" )
123145    else :
124146        pass 
125147
126148
127- app  =  FastAPI (
149+ app  =  GuardrailsApp (
128150    title = "Guardrails Server API" ,
129151    description = api_description ,
130152    version = "0.1.0" ,
@@ -186,7 +208,7 @@ class RequestBody(BaseModel):
186208        max_length = 255 ,
187209        description = "The id of an existing thread to which the messages should be added." ,
188210    )
189-     messages : List [dict ] =  Field (
211+     messages : Optional [ List [dict ] ] =  Field (
190212        default = None , description = "The list of messages in the current conversation." 
191213    )
192214    context : Optional [dict ] =  Field (
@@ -232,7 +254,7 @@ def ensure_config_ids(cls, v, values):
232254
233255
234256class  ResponseBody (BaseModel ):
235-     messages : List [dict ] =  Field (
257+     messages : Optional [ List [dict ] ] =  Field (
236258        default = None , description = "The new messages in the conversation" 
237259    )
238260    llm_output : Optional [dict ] =  Field (
@@ -282,8 +304,8 @@ async def get_rails_configs():
282304
283305
284306# One instance of LLMRails per config id 
285- llm_rails_instances  =  {}
286- llm_rails_events_history_cache  =  {}
307+ llm_rails_instances :  dict [ str ,  LLMRails ]  =  {}
308+ llm_rails_events_history_cache :  dict [ str ,  dict ]  =  {}
287309
288310
289311def  _generate_cache_key (config_ids : List [str ]) ->  str :
@@ -310,7 +332,7 @@ def _get_rails(config_ids: List[str]) -> LLMRails:
310332        # get the same thing. 
311333        config_ids  =  ["" ]
312334
313-     full_llm_rails_config  =  None 
335+     full_llm_rails_config :  Optional [ RailsConfig ]  =  None 
314336
315337    for  config_id  in  config_ids :
316338        base_path  =  os .path .abspath (app .rails_config_path )
@@ -330,6 +352,9 @@ def _get_rails(config_ids: List[str]) -> LLMRails:
330352        else :
331353            full_llm_rails_config  +=  rails_config 
332354
355+     if  full_llm_rails_config  is  None :
356+         raise  ValueError ("No valid rails configuration found." )
357+ 
333358    llm_rails  =  LLMRails (config = full_llm_rails_config , verbose = True )
334359    llm_rails_instances [configs_cache_key ] =  llm_rails 
335360
@@ -360,30 +385,33 @@ async def chat_completion(body: RequestBody, request: Request):
360385    # Save the request headers in a context variable. 
361386    api_request_headers .set (request .headers )
362387
388+     # Use Request config_ids if set, otherwise use the FastAPI default config. 
389+     # If neither is available we can't generate any completions as we have no config_id 
363390    config_ids  =  body .config_ids 
364-     if  not  config_ids  and  app .default_config_id :
365-         config_ids  =  [app .default_config_id ]
366-     elif  not  config_ids  and  not  app .default_config_id :
367-         raise  GuardrailsConfigurationError (
368-             "No 'config_id' provided and no default configuration is set for the server. " 
369-             "You must set a 'config_id' in your request or set use --default-config-id when . " 
370-         )
391+     if  not  config_ids :
392+         if  app .default_config_id :
393+             config_ids  =  [app .default_config_id ]
394+         else :
395+             raise  GuardrailsConfigurationError (
396+                 "No request config_ids provided and server has no default configuration" 
397+             )
398+ 
371399    try :
372400        llm_rails  =  _get_rails (config_ids )
373401    except  ValueError  as  ex :
374402        log .exception (ex )
375-         return  { 
376-             " messages" :  [
403+         return  ResponseBody ( 
404+             messages = [
377405                {
378406                    "role" : "assistant" ,
379407                    "content" : f"Could not load the { config_ids }  
380408                    f"An internal error has occurred." ,
381409                }
382410            ]
383-         } 
411+         ) 
384412
385413    try :
386-         messages  =  body .messages 
414+         messages  =  body .messages   or  [] 
387415        if  body .context :
388416            messages .insert (0 , {"role" : "context" , "content" : body .context })
389417
@@ -396,14 +424,14 @@ async def chat_completion(body: RequestBody, request: Request):
396424
397425            # We make sure the `thread_id` meets the minimum complexity requirement. 
398426            if  len (body .thread_id ) <  16 :
399-                 return  { 
400-                     " messages" :  [
427+                 return  ResponseBody ( 
428+                     messages = [
401429                        {
402430                            "role" : "assistant" ,
403431                            "content" : "The `thread_id` must have a minimum length of 16 characters." ,
404432                        }
405433                    ]
406-                 } 
434+                 ) 
407435
408436            # Fetch the existing thread messages. For easier management, we prepend 
409437            # the string `thread-` to all thread keys. 
@@ -440,32 +468,37 @@ async def chat_completion(body: RequestBody, request: Request):
440468            )
441469
442470            if  isinstance (res , GenerationResponse ):
443-                 bot_message  =  res .response [0 ]
471+                 bot_message_content  =  res .response [0 ]
472+                 # Ensure bot_message is always a dict 
473+                 if  isinstance (bot_message_content , str ):
474+                     bot_message  =  {"role" : "assistant" , "content" : bot_message_content }
475+                 else :
476+                     bot_message  =  bot_message_content 
444477            else :
445478                assert  isinstance (res , dict )
446479                bot_message  =  res 
447480
448481            # If we're using threads, we also need to update the data before returning 
449482            # the message. 
450-             if  body .thread_id :
483+             if  body .thread_id   and   datastore   is   not   None   and   datastore_key   is   not   None :
451484                await  datastore .set (datastore_key , json .dumps (messages  +  [bot_message ]))
452485
453-             result  =  { " messages" :  [bot_message ]} 
486+             result  =  ResponseBody ( messages = [bot_message ]) 
454487
455488            # If we have additional GenerationResponse fields, we return as well 
456489            if  isinstance (res , GenerationResponse ):
457-                 result [ " llm_output" ]  =  res .llm_output 
458-                 result [ " output_data" ]  =  res .output_data 
459-                 result [ " log" ]  =  res .log 
460-                 result [ " state" ]  =  res .state 
490+                 result . llm_output  =  res .llm_output 
491+                 result . output_data  =  res .output_data 
492+                 result . log  =  res .log 
493+                 result . state  =  res .state 
461494
462495            return  result 
463496
464497    except  Exception  as  ex :
465498        log .exception (ex )
466-         return  { 
467-             " messages" :  [{"role" : "assistant" , "content" : "Internal server error." }]
468-         } 
499+         return  ResponseBody ( 
500+             messages = [{"role" : "assistant" , "content" : "Internal server error." }]
501+         ) 
469502
470503
471504# By default, there are no challenges 
@@ -498,7 +531,7 @@ def register_datastore(datastore_instance: DataStore):
498531    datastore  =  datastore_instance 
499532
500533
501- def  register_logger (logger : callable ):
534+ def  register_logger (logger : Callable ):
502535    """Register an additional logger""" 
503536    registered_loggers .append (logger )
504537
@@ -510,8 +543,7 @@ def start_auto_reload_monitoring():
510543        from  watchdog .observers  import  Observer 
511544
512545        class  Handler (FileSystemEventHandler ):
513-             @staticmethod  
514-             def  on_any_event (event ):
546+             def  on_any_event (self , event ):
515547                if  event .is_directory :
516548                    return  None 
517549
@@ -521,7 +553,8 @@ def on_any_event(event):
521553                    )
522554
523555                    # Compute the relative path 
524-                     rel_path  =  os .path .relpath (event .src_path , app .rails_config_path )
556+                     src_path_str  =  str (event .src_path )
557+                     rel_path  =  os .path .relpath (src_path_str , app .rails_config_path )
525558
526559                    # The config_id is the first component 
527560                    parts  =  rel_path .split (os .path .sep )
@@ -530,7 +563,7 @@ def on_any_event(event):
530563                    if  (
531564                        not  parts [- 1 ].startswith ("." )
532565                        and  ".ipynb_checkpoints"  not  in parts 
533-                         and  os .path .isfile (event . src_path )
566+                         and  os .path .isfile (src_path_str )
534567                    ):
535568                        # We just remove the config from the cache so that a new one is used next time 
536569                        if  config_id  in  llm_rails_instances :
0 commit comments