Skip to content

Commit 36bd355

Browse files
refactor Validator to use /validate API in Codex BE (#83)
* refactor validator to wrap Codex API
* latest changes (wip?)
* initial refactor changes
* fix test
* deprecate project.query and project.add_entries
* include metadata in validate args
* remove whitespace
---------
Co-authored-by: Kelsey Wong <[email protected]>
1 parent 4561279 commit 36bd355

File tree

6 files changed

+172
-652
lines changed

6 files changed

+172
-652
lines changed
Lines changed: 13 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -1,110 +1,20 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING, Any, Optional, Sequence, cast
43

5-
from cleanlab_tlm.utils.rag import Eval, TrustworthyRAGScore, get_default_evals
6-
7-
from cleanlab_codex.types.validator import ThresholdedTrustworthyRAGScore
8-
9-
if TYPE_CHECKING:
10-
from cleanlab_codex.validator import BadResponseThresholds
11-
12-
13-
"""Evaluation metrics (excluding trustworthiness) that are used to determine if a response is bad."""
14-
DEFAULT_EVAL_METRICS = ["response_helpfulness"]
15-
16-
# Simple mappings for is_bad keys
17-
_SCORE_TO_IS_BAD_KEY = {
18-
"trustworthiness": "is_not_trustworthy",
19-
"query_ease": "is_not_query_easy",
20-
"response_helpfulness": "is_not_response_helpful",
21-
"context_sufficiency": "is_not_context_sufficient",
22-
}
23-
24-
25-
def get_default_evaluations() -> list[Eval]:
    """Get the default evaluations for the TrustworthyRAG.

    Note:
        This excludes trustworthiness, which is automatically computed by TrustworthyRAG.

    Returns:
        list[Eval]: The subset of `get_default_evals()` whose names appear in
            ``DEFAULT_EVAL_METRICS``.
    """
    # Hoist the membership set once; equivalent to checking the list directly.
    selected = set(DEFAULT_EVAL_METRICS)
    return [evaluation for evaluation in get_default_evals() if evaluation.name in selected]
32-
33-
34-
def get_default_trustworthyrag_config() -> dict[str, Any]:
    """Get the default configuration for the TrustworthyRAG.

    Returns:
        dict: Default TrustworthyRAG options; currently only requests that
            per-score explanations be included in the log.
    """
    return {
        "options": {
            "log": ["explanation"],
        },
    }
41-
42-
43-
def update_scores_based_on_thresholds(
    scores: TrustworthyRAGScore | Sequence[TrustworthyRAGScore], thresholds: BadResponseThresholds
) -> ThresholdedTrustworthyRAGScore:
    """Adds a `is_bad` flag to the scores dictionaries based on the thresholds.

    Args:
        scores: Mapping of eval name to a score dict containing a "score" key.
            Sequences (batches) are not supported yet.
        thresholds: Provides `get_threshold(eval_name)` for each eval.

    Returns:
        ThresholdedTrustworthyRAGScore: The input score dicts, each copied with
            an added boolean "is_bad" key.

    Raises:
        NotImplementedError: If `scores` is a Sequence (batch input).
    """
    if isinstance(scores, Sequence):
        raise NotImplementedError("Batching is not supported yet.")

    # A score is "bad" only when present and strictly below its threshold;
    # a missing (None) score is never flagged.
    def is_bad(score: Optional[float], threshold: float) -> bool:
        return score is not None and score < threshold

    thresholded_scores = {
        eval_name: {
            **score_dict,
            "is_bad": is_bad(score_dict["score"], thresholds.get_threshold(eval_name)),
        }
        for eval_name, score_dict in scores.items()
    }
    return cast(ThresholdedTrustworthyRAGScore, thresholded_scores)
62-
63-
64-
def process_score_metadata(scores: ThresholdedTrustworthyRAGScore, thresholds: BadResponseThresholds) -> dict[str, Any]:
65-
"""Process scores into metadata format with standardized keys.
4+
def validate_thresholds(thresholds: dict[str, float]) -> None:
    """Validate that all threshold values are between 0 and 1.

    Args:
        thresholds: Dictionary mapping eval names to their threshold values.

    Raises:
        TypeError: If any threshold value is not a number.
        ValueError: If any threshold value is not between 0 and 1.
    """
    for eval_name, threshold in thresholds.items():
        if not isinstance(threshold, (int, float)):
            error_msg = f"Threshold for {eval_name} must be a number, got {type(threshold)}"
            raise TypeError(error_msg)
        if not 0 <= float(threshold) <= 1:
            error_msg = f"Threshold for {eval_name} must be between 0 and 1, got {threshold}"
            raise ValueError(error_msg)

src/cleanlab_codex/project.py

Lines changed: 47 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22

33
from __future__ import annotations
44

5+
import warnings
56
from datetime import datetime
67
from typing import TYPE_CHECKING as _TYPE_CHECKING
7-
from typing import Any, Optional
8+
from typing import Any, Dict, List, Literal, Optional
89

910
from codex import AuthenticationError
1011

@@ -17,6 +18,8 @@
1718
from datetime import datetime
1819

1920
from codex import Codex as _Codex
21+
from codex.types.project_validate_params import Options as ProjectValidateOptions
22+
from codex.types.project_validate_response import ProjectValidateResponse
2023

2124
from cleanlab_codex.types.entry import EntryCreate
2225

@@ -152,7 +155,7 @@ def create_access_key(
152155
raise AuthenticationError(_ERROR_CREATE_ACCESS_KEY, response=e.response, body=e.body) from e
153156

154157
def add_entries(self, entries: list[EntryCreate]) -> None:
155-
"""Add a list of entries to this Codex project. Must be authenticated with a user-level API key to use this method.
158+
"""[DEPRECATED] Add a list of entries to this Codex project. Must be authenticated with a user-level API key to use this method.
156159
See [`Client.create_project()`](/codex/api/python/client#method-create_project) or [`Client.get_project()`](/codex/api/python/client#method-get_project).
157160
158161
Args:
@@ -161,6 +164,11 @@ def add_entries(self, entries: list[EntryCreate]) -> None:
161164
Raises:
162165
AuthenticationError: If the Project was created from a project-level access key instead of a [Client instance](/codex/api/python/client#class-client).
163166
"""
167+
warnings.warn(
168+
"Project.add_entries() is deprecated and will be removed in a future release. ",
169+
FutureWarning,
170+
stacklevel=2,
171+
)
164172
try:
165173
# TODO: implement batch creation of entries in backend and update this function
166174
for entry in entries:
@@ -181,7 +189,7 @@ def query(
181189
metadata: Optional[dict[str, Any]] = None,
182190
_analytics_metadata: Optional[_AnalyticsMetadata] = None,
183191
) -> tuple[Optional[str], Entry]:
184-
"""Query Codex to check if this project contains an answer to the question. If the question is not yet in the project, it will be added for SME review.
192+
"""[DEPRECATED] Query Codex to check if this project contains an answer to the question. If the question is not yet in the project, it will be added for SME review.
185193
186194
Args:
187195
question (str): The question to ask the Codex API.
@@ -193,6 +201,11 @@ def query(
193201
If Codex is able to answer the question, the first element will be the answer returned by Codex and the second element will be the existing [`Entry`](/codex/api/python/types.entry#class-entry) in the Codex project.
194202
If Codex is unable to answer the question, the first element will be `fallback_answer` if provided, otherwise None. The second element will be a new [`Entry`](/codex/api/python/types.entry#class-entry) in the Codex project.
195203
"""
204+
warnings.warn(
205+
"Project.query() is deprecated and will be removed in a future release. Use the Project.validate() function instead.",
206+
FutureWarning,
207+
stacklevel=2,
208+
)
196209
if not _analytics_metadata:
197210
_analytics_metadata = _AnalyticsMetadata(integration_type=IntegrationType.BACKUP)
198211

@@ -213,7 +226,10 @@ def _query_project(
213226
) -> tuple[Optional[str], Entry]:
214227
extra_headers = analytics_metadata.to_headers() if analytics_metadata else None
215228
query_res = self._sdk_client.projects.entries.query(
216-
self._id, question=question, client_metadata=client_metadata, extra_headers=extra_headers
229+
self._id,
230+
question=question,
231+
client_metadata=client_metadata,
232+
extra_headers=extra_headers,
217233
)
218234

219235
entry = Entry.model_validate(query_res.entry.model_dump())
@@ -222,5 +238,30 @@ def _query_project(
222238

223239
return fallback_answer, entry
224240

225-
def increment_queries(self) -> None:
226-
self._sdk_client.projects.increment_queries(self._id)
241+
def validate(
    self,
    context: str,
    prompt: str,
    query: str,
    response: str,
    *,
    constrain_outputs: Optional[List[str]] = None,
    custom_metadata: Optional[object] = None,
    eval_scores: Optional[Dict[str, float]] = None,
    custom_eval_thresholds: Optional[Dict[str, float]] = None,
    options: Optional[ProjectValidateOptions] = None,
    quality_preset: Literal["best", "high", "medium", "low", "base"] = "medium",
) -> ProjectValidateResponse:
    """Validate a response via this project's Codex backend /validate API.

    This is a thin pass-through: every argument is forwarded unchanged to the
    SDK client's `projects.validate` call for this project's ID.

    Args:
        context: Forwarded to the validate endpoint (presumably the retrieved
            RAG context — confirm against backend docs).
        prompt: Forwarded to the validate endpoint.
        query: Forwarded to the validate endpoint.
        response: The response string to validate.
        constrain_outputs: Optional list forwarded as-is.
        custom_metadata: Arbitrary metadata object forwarded as-is.
        eval_scores: Optional precomputed eval scores forwarded as-is.
        custom_eval_thresholds: Optional per-eval thresholds forwarded as-is.
        options: Optional backend validate options forwarded as-is.
        quality_preset: Quality preset forwarded as-is; defaults to "medium".

    Returns:
        ProjectValidateResponse: The backend's validation result, unmodified.
    """
    return self._sdk_client.projects.validate(
        self._id,
        context=context,
        prompt=prompt,
        query=query,
        response=response,
        constrain_outputs=constrain_outputs,
        custom_eval_thresholds=custom_eval_thresholds,
        custom_metadata=custom_metadata,
        eval_scores=eval_scores,
        options=options,
        quality_preset=quality_preset,
    )

0 commit comments

Comments
 (0)