3030
3131import  logging 
3232import  os 
33- from  typing  import  Optional 
33+ from  time  import  time 
34+ from  typing  import  Dict , Optional 
3435
3536from  nemoguardrails .actions  import  action 
37+ from  nemoguardrails .context  import  llm_call_info_var 
3638from  nemoguardrails .library .jailbreak_detection .request  import  (
3739    jailbreak_detection_heuristics_request ,
3840    jailbreak_detection_model_request ,
3941    jailbreak_nim_request ,
4042)
43+ from  nemoguardrails .llm .cache  import  CacheInterface 
44+ from  nemoguardrails .llm .cache .utils  import  (
45+     CacheEntry ,
46+     create_normalized_cache_key ,
47+     get_from_cache_and_restore_stats ,
48+ )
4149from  nemoguardrails .llm .taskmanager  import  LLMTaskManager 
50+ from  nemoguardrails .logging .explain  import  LLMCallInfo 
51+ from  nemoguardrails .logging .processing_log  import  processing_log_var 
4252
4353log  =  logging .getLogger (__name__ )
4454
@@ -89,6 +99,7 @@ async def jailbreak_detection_heuristics(
8999async  def  jailbreak_detection_model (
90100    llm_task_manager : LLMTaskManager ,
91101    context : Optional [dict ] =  None ,
102+     model_caches : Optional [Dict [str , CacheInterface ]] =  None ,
92103) ->  bool :
93104    """Uses a trained classifier to determine if a user input is a jailbreak attempt""" 
94105    prompt : str  =  "" 
@@ -102,6 +113,30 @@ async def jailbreak_detection_model(
102113    if  context  is  not None :
103114        prompt  =  context .get ("user_message" , "" )
104115
116+     # we do this as a hack to treat this action as an LLM call for tracing 
117+     llm_call_info_var .set (LLMCallInfo (task = "jailbreak_detection_model" ))
118+ 
119+     cache  =  model_caches .get ("jailbreak_detection" ) if  model_caches  else  None 
120+ 
121+     if  cache :
122+         cache_key  =  create_normalized_cache_key (prompt )
123+         cache_read_start  =  time ()
124+         cached_result  =  get_from_cache_and_restore_stats (cache , cache_key )
125+         if  cached_result  is  not None :
126+             cache_read_duration  =  time () -  cache_read_start 
127+             llm_call_info  =  llm_call_info_var .get ()
128+             if  llm_call_info :
129+                 llm_call_info .from_cache  =  True 
130+                 llm_call_info .duration  =  cache_read_duration 
131+                 llm_call_info .started_at  =  time () -  cache_read_duration 
132+                 llm_call_info .finished_at  =  time ()
133+ 
134+             log .debug ("Jailbreak detection cache hit" )
135+             return  cached_result ["jailbreak" ]
136+ 
137+     jailbreak_result  =  None 
138+     api_start_time  =  time ()
139+ 
105140    if  not  jailbreak_api_url  and  not  nim_base_url :
106141        from  nemoguardrails .library .jailbreak_detection .model_based .checks  import  (
107142            check_jailbreak ,
@@ -114,32 +149,64 @@ async def jailbreak_detection_model(
114149        try :
115150            jailbreak  =  check_jailbreak (prompt = prompt )
116151            log .info (f"Local model jailbreak detection result: { jailbreak }  )
117-             return  jailbreak ["jailbreak" ]
152+             jailbreak_result   =  jailbreak ["jailbreak" ]
118153        except  RuntimeError  as  e :
119154            log .error (f"Jailbreak detection model not available: { e }  )
120-             return  False 
155+             jailbreak_result   =  False 
121156        except  ImportError  as  e :
122157            log .error (
123158                f"Failed to import required dependencies for local model. Install scikit-learn and torch, or use NIM-based approach" ,
124159                exc_info = e ,
125160            )
126-             return  False 
127- 
128-     if  nim_base_url :
129-         jailbreak  =  await  jailbreak_nim_request (
130-             prompt = prompt ,
131-             nim_url = nim_base_url ,
132-             nim_auth_token = nim_auth_token ,
133-             nim_classification_path = nim_classification_path ,
134-         )
135-     elif  jailbreak_api_url :
136-         jailbreak  =  await  jailbreak_detection_model_request (
137-             prompt = prompt , api_url = jailbreak_api_url 
138-         )
139- 
140-     if  jailbreak  is  None :
141-         log .warning ("Jailbreak endpoint not set up properly." )
142-         # If no result, assume not a jailbreak 
143-         return  False 
161+             jailbreak_result  =  False 
144162    else :
145-         return  jailbreak 
163+         if  nim_base_url :
164+             jailbreak  =  await  jailbreak_nim_request (
165+                 prompt = prompt ,
166+                 nim_url = nim_base_url ,
167+                 nim_auth_token = nim_auth_token ,
168+                 nim_classification_path = nim_classification_path ,
169+             )
170+         elif  jailbreak_api_url :
171+             jailbreak  =  await  jailbreak_detection_model_request (
172+                 prompt = prompt , api_url = jailbreak_api_url 
173+             )
174+ 
175+         if  jailbreak  is  None :
176+             log .warning ("Jailbreak endpoint not set up properly." )
177+             jailbreak_result  =  False 
178+         else :
179+             jailbreak_result  =  jailbreak 
180+ 
181+     api_duration  =  time () -  api_start_time 
182+ 
183+     llm_call_info  =  llm_call_info_var .get ()
184+     if  llm_call_info :
185+         llm_call_info .from_cache  =  False 
186+         llm_call_info .duration  =  api_duration 
187+         llm_call_info .started_at  =  api_start_time 
188+         llm_call_info .finished_at  =  time ()
189+ 
190+         processing_log  =  processing_log_var .get ()
191+         if  processing_log  is  not None :
192+             processing_log .append (
193+                 {
194+                     "type" : "llm_call_info" ,
195+                     "timestamp" : time (),
196+                     "data" : llm_call_info ,
197+                 }
198+             )
199+ 
200+     if  cache :
201+         from  nemoguardrails .llm .cache .utils  import  extract_llm_metadata_for_cache 
202+ 
203+         cache_key  =  create_normalized_cache_key (prompt )
204+         cache_entry : CacheEntry  =  {
205+             "result" : {"jailbreak" : jailbreak_result },
206+             "llm_stats" : None ,
207+             "llm_metadata" : extract_llm_metadata_for_cache (),
208+         }
209+         cache .put (cache_key , cache_entry )
210+         log .debug ("Jailbreak detection result cached" )
211+ 
212+     return  jailbreak_result 
0 commit comments