@@ -15,12 +15,13 @@
 
 from pydantic import BaseModel, ConfigDict, Field
 
+from cleanlab_codex.internal.tlm import TLM
 from cleanlab_codex.internal.utils import generate_pydantic_model_docstring
 from cleanlab_codex.types.response_validation import (
     AggregatedResponseValidationResult,
     SingleResponseValidationResult,
 )
-from cleanlab_codex.types.tlm import TLM
+from cleanlab_codex.types.tlm import TLMConfig
 from cleanlab_codex.utils.errors import MissingDependencyError
 from cleanlab_codex.utils.prompt import default_format_prompt
 
@@ -30,6 +31,7 @@
 _DEFAULT_FALLBACK_SIMILARITY_THRESHOLD: float = 0.7
 _DEFAULT_TRUSTWORTHINESS_THRESHOLD: float = 0.5
 _DEFAULT_UNHELPFULNESS_CONFIDENCE_THRESHOLD: float = 0.5
+_DEFAULT_TLM_CONFIG: TLMConfig = TLMConfig()
 
 Query = str
 Context = str
@@ -77,13 +79,12 @@ class BadResponseDetectionConfig(BaseModel):
     )
 
     # Shared config (for untrustworthiness and unhelpfulness checks)
-    tlm: Optional[TLM] = Field(
-        default=None,
-        description="TLM model to use for evaluation (required for untrustworthiness and unhelpfulness checks).",
+    tlm_config: TLMConfig = Field(
+        default=_DEFAULT_TLM_CONFIG,
+        description="TLM model configuration to use for untrustworthiness and unhelpfulness checks.",
     )
 
 
-# hack to generate better documentation for help.cleanlab.ai
 BadResponseDetectionConfig.__doc__ = f"""
 {BadResponseDetectionConfig.__doc__}
 
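For context, a minimal sketch of how callers might populate the new `tlm_config` field (the import path for `BadResponseDetectionConfig` and the threshold value are assumptions for illustration, not part of this diff):

```python
# Hedged sketch: building the validation config with the new tlm_config field.
# Assumes BadResponseDetectionConfig is importable from cleanlab_codex.response_validation.
from cleanlab_codex.response_validation import BadResponseDetectionConfig
from cleanlab_codex.types.tlm import TLMConfig

config = BadResponseDetectionConfig(
    tlm_config=TLMConfig(),          # replaces the removed `tlm` field; defaults to TLMConfig()
    trustworthiness_threshold=0.6,   # illustrative override of the 0.5 default
)
```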
@@ -99,10 +100,11 @@ def is_bad_response(
     context: Optional[str] = None,
     query: Optional[str] = None,
     config: Union[BadResponseDetectionConfig, Dict[str, Any]] = _DEFAULT_CONFIG,
+    codex_access_key: Optional[str] = None,
 ) -> AggregatedResponseValidationResult:
     """Run a series of checks to determine if a response is bad.
 
-    The function returns a `AggregatedResponseValidationResult` object containing results from multiple validation checks.
+    The function returns an `AggregatedResponseValidationResult` object containing results from multiple validation checks.
     If any check fails (detects an issue), the AggregatedResponseValidationResult will evaluate to `True` when used in a boolean context.
     This means code like `if is_bad_response(...)` will enter the if-block when problems are detected.
 
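A hedged usage sketch of the updated call signature (the access key is a placeholder and the import path is assumed; the truthiness behavior follows the docstring above):

```python
from cleanlab_codex.response_validation import is_bad_response

result = is_bad_response(
    response="Paris is the capital of France.",
    context="France is a country in Europe. Its capital is Paris.",
    query="What is the capital of France?",
    codex_access_key="<project-access-key>",  # placeholder, new in this change
)
if result:  # truthy when any validation check detects an issue
    print("Response flagged as bad:", result)
```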
@@ -146,28 +148,30 @@ def is_bad_response(
             )
         )
 
-    can_run_untrustworthy_check = query is not None and context is not None and config.tlm is not None
+    can_run_untrustworthy_check = query is not None and context is not None and config.tlm_config is not None
     if can_run_untrustworthy_check:
         # The if condition guarantees these are not None
         validation_checks.append(
             lambda: is_untrustworthy_response(
                 response=response,
                 context=cast(str, context),
                 query=cast(str, query),
-                tlm=cast(TLM, config.tlm),
+                tlm_config=config.tlm_config,
                 trustworthiness_threshold=config.trustworthiness_threshold,
                 format_prompt=config.format_prompt,
+                codex_access_key=codex_access_key,
             )
         )
 
-    can_run_unhelpful_check = query is not None and config.tlm is not None
+    can_run_unhelpful_check = query is not None and config.tlm_config is not None
     if can_run_unhelpful_check:
         validation_checks.append(
             lambda: is_unhelpful_response(
                 response=response,
                 query=cast(str, query),
-                tlm=cast(TLM, config.tlm),
+                tlm_config=config.tlm_config,
                 confidence_score_threshold=config.unhelpfulness_confidence_threshold,
+                codex_access_key=codex_access_key,
             )
         )
 
@@ -238,9 +242,11 @@ def is_untrustworthy_response(
     response: str,
     context: str,
     query: str,
-    tlm: TLM,
+    tlm_config: TLMConfig = _DEFAULT_TLM_CONFIG,
     trustworthiness_threshold: float = _DEFAULT_TRUSTWORTHINESS_THRESHOLD,
     format_prompt: Callable[[str, str], str] = default_format_prompt,
+    *,
+    codex_access_key: Optional[str] = None,
 ) -> SingleResponseValidationResult:
     """Check if a response is untrustworthy.
 
@@ -252,7 +258,7 @@ def is_untrustworthy_response(
         response (str): The response to check from the assistant.
         context (str): The context information available for answering the query.
         query (str): The user's question or request.
-        tlm (TLM): The TLM model to use for evaluation.
+        tlm_config (TLMConfig): The TLM configuration to use for evaluation.
         trustworthiness_threshold (float): Score threshold (0.0-1.0) under which a response is considered untrustworthy.
             Lower values allow less trustworthy responses. Default 0.5 means responses with scores less than 0.5 are considered untrustworthy.
         format_prompt (Callable[[str, str], str]): Function that takes (query, context) and returns a formatted prompt string.
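For illustration, a sketch of calling this check directly with the new keyword-only access key (values are placeholders; import path assumed):

```python
from cleanlab_codex.response_validation import is_untrustworthy_response
from cleanlab_codex.types.tlm import TLMConfig

check = is_untrustworthy_response(
    response="The capital of France is Berlin.",
    context="France is a country in Europe. Its capital is Paris.",
    query="What is the capital of France?",
    tlm_config=TLMConfig(),                    # now optional; defaults to _DEFAULT_TLM_CONFIG
    trustworthiness_threshold=0.5,
    codex_access_key="<project-access-key>",   # keyword-only, placeholder
)
print(check)  # SingleResponseValidationResult for the "untrustworthy" check
```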
@@ -266,8 +272,9 @@ def is_untrustworthy_response(
         response=response,
         context=context,
         query=query,
-        tlm=tlm,
+        tlm_config=tlm_config,
         format_prompt=format_prompt,
+        codex_access_key=codex_access_key,
     )
     return SingleResponseValidationResult(
         name="untrustworthy",
@@ -281,8 +288,10 @@ def score_untrustworthy_response(
     response: str,
     context: str,
     query: str,
-    tlm: TLM,
+    tlm_config: TLMConfig = _DEFAULT_TLM_CONFIG,
     format_prompt: Callable[[str, str], str] = default_format_prompt,
+    *,
+    codex_access_key: Optional[str] = None,
 ) -> float:
     """Scores a response's trustworthiness using [TLM](/tlm), given a context and query.
 
@@ -298,24 +307,20 @@ def score_untrustworthy_response(
     Returns:
         float: The score of the response, between 0.0 and 1.0. A lower score indicates the response is less trustworthy.
     """
-    try:
-        from cleanlab_tlm import TLM  # noqa: F401
-    except ImportError as e:
-        raise MissingDependencyError(
-            import_name=e.name or "cleanlab_tlm",
-            package_name="cleanlab-tlm",
-            package_url="https://github.com/cleanlab/cleanlab-tlm",
-        ) from e
     prompt = format_prompt(query, context)
-    result = tlm.get_trustworthiness_score(prompt, response)
-    return float(result["trustworthiness_score"])
+    result = TLM.from_config(tlm_config, codex_access_key=codex_access_key).get_trustworthiness_score(
+        prompt, response=response
+    )
+    return float(result.trustworthiness_score)
 
 
 def is_unhelpful_response(
     response: str,
     query: str,
-    tlm: TLM,
+    tlm_config: TLMConfig = _DEFAULT_TLM_CONFIG,
     confidence_score_threshold: float = _DEFAULT_UNHELPFULNESS_CONFIDENCE_THRESHOLD,
+    *,
+    codex_access_key: Optional[str] = None,
 ) -> SingleResponseValidationResult:
     """Check if a response is unhelpful by asking [TLM](/tlm) to evaluate it.
 
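The lower-level `score_untrustworthy_response` helper changed above (and wrapped by `is_untrustworthy_response`) can also be called on its own; a hedged sketch with placeholder values and an assumed import path:

```python
from cleanlab_codex.response_validation import score_untrustworthy_response

score = score_untrustworthy_response(
    response="The capital of France is Berlin.",
    context="France is a country in Europe. Its capital is Paris.",
    query="What is the capital of France?",
    codex_access_key="<project-access-key>",  # placeholder
)
print(f"trustworthiness score: {score:.2f}")  # lower means less trustworthy
```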
@@ -327,14 +332,14 @@ def is_unhelpful_response(
     Args:
         response (str): The response to check.
         query (str): User query that will be used to evaluate if the response is helpful.
-        tlm (TLM): The TLM model to use for evaluation.
+        tlm_config (TLMConfig): The TLM configuration to use for evaluation.
         confidence_score_threshold (float): Confidence threshold (0.0-1.0) above which a response is considered unhelpful.
             E.g. if confidence_score_threshold is 0.5, then responses with scores higher than 0.5 are considered unhelpful.
 
     Returns:
         SingleResponseValidationResult: The results of the validation check.
     """
-    score: float = score_unhelpful_response(response, query, tlm)
+    score: float = score_unhelpful_response(response, query, tlm_config, codex_access_key=codex_access_key)
 
     # Current implementation of `score_unhelpful_response` produces a score where a higher value means the response is more likely to be unhelpful.
     # Changing the TLM prompt used in `score_unhelpful_response` may require restructuring the logic for `fails_check` and potentially adjusting
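A hedged sketch of the unhelpfulness check with the new configuration-based signature (placeholder values, assumed import path):

```python
from cleanlab_codex.response_validation import is_unhelpful_response

check = is_unhelpful_response(
    response="Sorry, I cannot help with that.",
    query="How do I reset my password?",
    confidence_score_threshold=0.5,            # scores above this flag the response
    codex_access_key="<project-access-key>",   # keyword-only, placeholder
)
print(check)
```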
@@ -350,27 +355,20 @@ def is_unhelpful_response(
 def score_unhelpful_response(
     response: str,
     query: str,
-    tlm: TLM,
+    tlm_config: TLMConfig = _DEFAULT_TLM_CONFIG,
+    *,
+    codex_access_key: Optional[str] = None,
 ) -> float:
     """Scores a response's unhelpfulness using [TLM](/tlm), given a query.
 
     Args:
         response (str): The response to check.
         query (str): User query that will be used to evaluate if the response is helpful.
-        tlm (TLM): The TLM model to use for evaluation.
+        tlm_config (TLMConfig): The TLM configuration to use for evaluation.
 
     Returns:
         float: The score of the response, between 0.0 and 1.0. A higher score corresponds to a less helpful response.
     """
-    try:
-        from cleanlab_tlm import TLM  # noqa: F401
-    except ImportError as e:
-        raise MissingDependencyError(
-            import_name=e.name or "cleanlab_tlm",
-            package_name="cleanlab-tlm",
-            package_url="https://github.com/cleanlab/cleanlab-tlm",
-        ) from e
-
     # IMPORTANT: The current implementation couples three things that must stay in sync:
     # 1. The question phrasing ("is unhelpful?")
     # 2. The expected_unhelpful_response ("Yes")
@@ -405,5 +403,7 @@ def score_unhelpful_response(
         f"AI Assistant Response: {response}\n\n"
         f"{question}"
     )
-    result = tlm.get_trustworthiness_score(prompt, response=expected_unhelpful_response)
-    return float(result["trustworthiness_score"])
+    result = TLM.from_config(tlm_config, codex_access_key=codex_access_key).get_trustworthiness_score(
+        prompt, response=expected_unhelpful_response
+    )
+    return float(result.trustworthiness_score)
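And a matching sketch for the raw unhelpfulness score (placeholder values, assumed import path):

```python
from cleanlab_codex.response_validation import score_unhelpful_response

score = score_unhelpful_response(
    response="Sorry, I cannot help with that.",
    query="How do I reset my password?",
    codex_access_key="<project-access-key>",  # placeholder
)
print(f"unhelpfulness score: {score:.2f}")  # higher means more likely unhelpful
```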