1 | 1 | from __future__ import annotations |
2 | 2 |
3 | | -from typing import TYPE_CHECKING, Any, Optional, Sequence, cast |
4 | 3 |
5 | | -from cleanlab_tlm.utils.rag import Eval, TrustworthyRAGScore, get_default_evals |
6 | | - |
7 | | -from cleanlab_codex.types.validator import ThresholdedTrustworthyRAGScore |
8 | | - |
9 | | -if TYPE_CHECKING: |
10 | | - from cleanlab_codex.validator import BadResponseThresholds |
11 | | - |
12 | | - |
13 | | -"""Evaluation metrics (excluding trustworthiness) that are used to determine if a response is bad.""" |
14 | | -DEFAULT_EVAL_METRICS = ["response_helpfulness"] |
15 | | - |
16 | | -# Simple mappings for is_bad keys |
17 | | -_SCORE_TO_IS_BAD_KEY = { |
18 | | - "trustworthiness": "is_not_trustworthy", |
19 | | - "query_ease": "is_not_query_easy", |
20 | | - "response_helpfulness": "is_not_response_helpful", |
21 | | - "context_sufficiency": "is_not_context_sufficient", |
22 | | -} |
23 | | - |
24 | | - |
25 | | -def get_default_evaluations() -> list[Eval]: |
26 | | - """Get the default evaluations for the TrustworthyRAG. |
27 | | -
28 | | - Note: |
29 | | - This excludes trustworthiness, which is automatically computed by TrustworthyRAG. |
30 | | - """ |
31 | | - return [evaluation for evaluation in get_default_evals() if evaluation.name in DEFAULT_EVAL_METRICS] |
32 | | - |
33 | | - |
34 | | -def get_default_trustworthyrag_config() -> dict[str, Any]: |
35 | | - """Get the default configuration for the TrustworthyRAG.""" |
36 | | - return { |
37 | | - "options": { |
38 | | - "log": ["explanation"], |
39 | | - }, |
40 | | - } |
41 | | - |
42 | | - |
43 | | -def update_scores_based_on_thresholds( |
44 | | - scores: TrustworthyRAGScore | Sequence[TrustworthyRAGScore], thresholds: BadResponseThresholds |
45 | | -) -> ThresholdedTrustworthyRAGScore: |
46 | | - """Adds a `is_bad` flag to the scores dictionaries based on the thresholds.""" |
47 | | - |
48 | | - # Helper function to check if a score is bad |
49 | | - def is_bad(score: Optional[float], threshold: float) -> bool: |
50 | | - return score is not None and score < threshold |
51 | | - |
52 | | - if isinstance(scores, Sequence): |
53 | | - raise NotImplementedError("Batching is not supported yet.") |
54 | | - |
55 | | - thresholded_scores = {} |
56 | | - for eval_name, score_dict in scores.items(): |
57 | | - thresholded_scores[eval_name] = { |
58 | | - **score_dict, |
59 | | - "is_bad": is_bad(score_dict["score"], thresholds.get_threshold(eval_name)), |
60 | | - } |
61 | | - return cast(ThresholdedTrustworthyRAGScore, thresholded_scores) |
62 | | - |
63 | | - |
64 | | -def process_score_metadata(scores: ThresholdedTrustworthyRAGScore, thresholds: BadResponseThresholds) -> dict[str, Any]: |
65 | | - """Process scores into metadata format with standardized keys. |
| 4 | +def validate_thresholds(thresholds: dict[str, float]) -> None: |
| 5 | + """Validate that all threshold values are between 0 and 1. |
66 | 6 |
67 | 7 | Args: |
68 | | - scores: The ThresholdedTrustworthyRAGScore containing evaluation results |
69 | | - thresholds: The BadResponseThresholds configuration |
| 8 | + thresholds: Dictionary mapping eval names to their threshold values. |
70 | 9 |
71 | | - Returns: |
72 | | - dict: A dictionary containing evaluation scores and their corresponding metadata |
| 10 | + Raises: |
| 11 | + TypeError: If any threshold value is not a number. |
| 12 | + ValueError: If any threshold value is not between 0 and 1. |
73 | 13 | """ |
74 | | - metadata: dict[str, Any] = {} |
75 | | - |
76 | | - # Process scores and add to metadata |
77 | | - for metric, score_data in scores.items(): |
78 | | - metadata[metric] = score_data["score"] |
79 | | - |
80 | | - # Add is_bad flags with standardized naming |
81 | | - is_bad_key = _SCORE_TO_IS_BAD_KEY.get(metric, f"is_not_{metric}") |
82 | | - metadata[is_bad_key] = score_data["is_bad"] |
83 | | - |
84 | | - # Special case for trustworthiness explanation |
85 | | - if metric == "trustworthiness" and "log" in score_data and "explanation" in score_data["log"]: |
86 | | - metadata["explanation"] = score_data["log"]["explanation"] |
87 | | - |
88 | | - # Add thresholds to metadata |
89 | | - thresholds_dict = thresholds.model_dump() |
90 | | - for metric in {k for k in scores if k not in thresholds_dict}: |
91 | | - thresholds_dict[metric] = thresholds.get_threshold(metric) |
92 | | - metadata["thresholds"] = thresholds_dict |
93 | | - |
94 | | - # TODO: Remove this as the backend can infer this from the is_bad flags |
95 | | - metadata["label"] = _get_label(metadata) |
96 | | - |
97 | | - return metadata |
98 | | - |
99 | | - |
100 | | -def _get_label(metadata: dict[str, Any]) -> str: |
101 | | - def is_bad(metric: str) -> bool: |
102 | | - return bool(metadata.get(_SCORE_TO_IS_BAD_KEY[metric], False)) |
103 | | - |
104 | | - if is_bad("context_sufficiency"): |
105 | | - return "search_failure" |
106 | | - if is_bad("response_helpfulness") or is_bad("query_ease"): |
107 | | - return "unhelpful" |
108 | | - if is_bad("trustworthiness"): |
109 | | - return "hallucination" |
110 | | - return "other_issues" |
| 14 | + for eval_name, threshold in thresholds.items(): |
| 15 | + if not isinstance(threshold, (int, float)): |
| 16 | + error_msg = f"Threshold for {eval_name} must be a number, got {type(threshold)}" |
| 17 | + raise TypeError(error_msg) |
| 18 | + if not 0 <= float(threshold) <= 1: |
| 19 | + error_msg = f"Threshold for {eval_name} must be between 0 and 1, got {threshold}" |
| 20 | + raise ValueError(error_msg) |
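
For reviewers, here is a minimal usage sketch of the new `validate_thresholds` helper added above. The import path and the example threshold values are assumptions for illustration only; they are not taken from this PR.

```python
# Minimal sketch of calling validate_thresholds; the import path below is an
# assumption for illustration and may differ from where the helper actually lives.
from cleanlab_codex.internal.validator import validate_thresholds

# Well-formed thresholds: every value is a number in [0, 1], so this returns None.
validate_thresholds({"trustworthiness": 0.7, "response_helpfulness": 0.23})

# A non-numeric value triggers the TypeError branch.
try:
    validate_thresholds({"trustworthiness": "high"})
except TypeError as err:
    print(err)  # Threshold for trustworthiness must be a number, got <class 'str'>

# An out-of-range value triggers the ValueError branch.
try:
    validate_thresholds({"trustworthiness": 1.5})
except ValueError as err:
    print(err)  # Threshold for trustworthiness must be between 0 and 1, got 1.5
```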