Skip to content

Commit c37f7a3

Browse files
committed
feat(cache): add caching support for jailbreak detection (#1458)
* feat(cache): add LLM metadata caching for model and provider information Extends the cache system to store and restore LLM metadata (model name and provider name) alongside cache entries. This allows cached results to maintain provenance information about which model and provider generated the original response. - Added LLMMetadataDict and LLMCacheData TypedDict definitions for type safety - Extended CacheEntry to include optional llm_metadata field - Implemented extract_llm_metadata_for_cache() to capture model and provider info from context - Implemented restore_llm_metadata_from_cache() to restore metadata when retrieving cached results - Updated get_from_cache_and_restore_stats() to handle metadata extraction and restoration - Added comprehensive test coverage for metadata caching functionality * feat(cache): add caching support for jailbreak detection
1 parent f724e19 commit c37f7a3

File tree

4 files changed

+387
-23
lines changed

4 files changed

+387
-23
lines changed

nemoguardrails/library/jailbreak_detection/actions.py

Lines changed: 89 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -30,15 +30,25 @@
3030

3131
import logging
3232
import os
33-
from typing import Optional
33+
from time import time
34+
from typing import Dict, Optional
3435

3536
from nemoguardrails.actions import action
37+
from nemoguardrails.context import llm_call_info_var
3638
from nemoguardrails.library.jailbreak_detection.request import (
3739
jailbreak_detection_heuristics_request,
3840
jailbreak_detection_model_request,
3941
jailbreak_nim_request,
4042
)
43+
from nemoguardrails.llm.cache import CacheInterface
44+
from nemoguardrails.llm.cache.utils import (
45+
CacheEntry,
46+
create_normalized_cache_key,
47+
get_from_cache_and_restore_stats,
48+
)
4149
from nemoguardrails.llm.taskmanager import LLMTaskManager
50+
from nemoguardrails.logging.explain import LLMCallInfo
51+
from nemoguardrails.logging.processing_log import processing_log_var
4252

4353
log = logging.getLogger(__name__)
4454

@@ -89,6 +99,7 @@ async def jailbreak_detection_heuristics(
8999
async def jailbreak_detection_model(
90100
llm_task_manager: LLMTaskManager,
91101
context: Optional[dict] = None,
102+
model_caches: Optional[Dict[str, CacheInterface]] = None,
92103
) -> bool:
93104
"""Uses a trained classifier to determine if a user input is a jailbreak attempt"""
94105
prompt: str = ""
@@ -102,6 +113,30 @@ async def jailbreak_detection_model(
102113
if context is not None:
103114
prompt = context.get("user_message", "")
104115

116+
# we do this as a hack to treat this action as an LLM call for tracing
117+
llm_call_info_var.set(LLMCallInfo(task="jailbreak_detection_model"))
118+
119+
cache = model_caches.get("jailbreak_detection") if model_caches else None
120+
121+
if cache:
122+
cache_key = create_normalized_cache_key(prompt)
123+
cache_read_start = time()
124+
cached_result = get_from_cache_and_restore_stats(cache, cache_key)
125+
if cached_result is not None:
126+
cache_read_duration = time() - cache_read_start
127+
llm_call_info = llm_call_info_var.get()
128+
if llm_call_info:
129+
llm_call_info.from_cache = True
130+
llm_call_info.duration = cache_read_duration
131+
llm_call_info.started_at = time() - cache_read_duration
132+
llm_call_info.finished_at = time()
133+
134+
log.debug("Jailbreak detection cache hit")
135+
return cached_result["jailbreak"]
136+
137+
jailbreak_result = None
138+
api_start_time = time()
139+
105140
if not jailbreak_api_url and not nim_base_url:
106141
from nemoguardrails.library.jailbreak_detection.model_based.checks import (
107142
check_jailbreak,
@@ -114,32 +149,64 @@ async def jailbreak_detection_model(
114149
try:
115150
jailbreak = check_jailbreak(prompt=prompt)
116151
log.info(f"Local model jailbreak detection result: {jailbreak}")
117-
return jailbreak["jailbreak"]
152+
jailbreak_result = jailbreak["jailbreak"]
118153
except RuntimeError as e:
119154
log.error(f"Jailbreak detection model not available: {e}")
120-
return False
155+
jailbreak_result = False
121156
except ImportError as e:
122157
log.error(
123158
f"Failed to import required dependencies for local model. Install scikit-learn and torch, or use NIM-based approach",
124159
exc_info=e,
125160
)
126-
return False
127-
128-
if nim_base_url:
129-
jailbreak = await jailbreak_nim_request(
130-
prompt=prompt,
131-
nim_url=nim_base_url,
132-
nim_auth_token=nim_auth_token,
133-
nim_classification_path=nim_classification_path,
134-
)
135-
elif jailbreak_api_url:
136-
jailbreak = await jailbreak_detection_model_request(
137-
prompt=prompt, api_url=jailbreak_api_url
138-
)
139-
140-
if jailbreak is None:
141-
log.warning("Jailbreak endpoint not set up properly.")
142-
# If no result, assume not a jailbreak
143-
return False
161+
jailbreak_result = False
144162
else:
145-
return jailbreak
163+
if nim_base_url:
164+
jailbreak = await jailbreak_nim_request(
165+
prompt=prompt,
166+
nim_url=nim_base_url,
167+
nim_auth_token=nim_auth_token,
168+
nim_classification_path=nim_classification_path,
169+
)
170+
elif jailbreak_api_url:
171+
jailbreak = await jailbreak_detection_model_request(
172+
prompt=prompt, api_url=jailbreak_api_url
173+
)
174+
175+
if jailbreak is None:
176+
log.warning("Jailbreak endpoint not set up properly.")
177+
jailbreak_result = False
178+
else:
179+
jailbreak_result = jailbreak
180+
181+
api_duration = time() - api_start_time
182+
183+
llm_call_info = llm_call_info_var.get()
184+
if llm_call_info:
185+
llm_call_info.from_cache = False
186+
llm_call_info.duration = api_duration
187+
llm_call_info.started_at = api_start_time
188+
llm_call_info.finished_at = time()
189+
190+
processing_log = processing_log_var.get()
191+
if processing_log is not None:
192+
processing_log.append(
193+
{
194+
"type": "llm_call_info",
195+
"timestamp": time(),
196+
"data": llm_call_info,
197+
}
198+
)
199+
200+
if cache:
201+
from nemoguardrails.llm.cache.utils import extract_llm_metadata_for_cache
202+
203+
cache_key = create_normalized_cache_key(prompt)
204+
cache_entry: CacheEntry = {
205+
"result": {"jailbreak": jailbreak_result},
206+
"llm_stats": None,
207+
"llm_metadata": extract_llm_metadata_for_cache(),
208+
}
209+
cache.put(cache_key, cache_entry)
210+
log.debug("Jailbreak detection result cached")
211+
212+
return jailbreak_result

nemoguardrails/llm/cache/utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,9 @@ def get_from_cache_and_restore_stats(
182182
if cached_metadata:
183183
restore_llm_metadata_from_cache(cached_metadata)
184184

185+
if cached_metadata:
186+
restore_llm_metadata_from_cache(cached_metadata)
187+
185188
processing_log = processing_log_var.get()
186189
if processing_log is not None:
187190
llm_call_info = llm_call_info_var.get()

nemoguardrails/rails/llm/llmrails.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -481,7 +481,7 @@ def _init_llms(self):
481481
llms = dict()
482482

483483
for llm_config in self.config.models:
484-
if llm_config.type == "embeddings":
484+
if llm_config.type in ["embeddings", "jailbreak_detection"]:
485485
continue
486486

487487
# If a constructor LLM is provided, skip initializing any 'main' model from config

0 commit comments

Comments
 (0)