Commit b471371

clean up imports, type hints and docs
1 parent 873f552 commit b471371

File tree: 3 files changed, +38 -81 lines changed

src/cleanlab_codex/internal/validator.py

Lines changed: 18 additions & 30 deletions
@@ -1,19 +1,12 @@
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Optional, Sequence, cast
 
-from cleanlab_codex.utils.errors import MissingDependencyError
+from cleanlab_tlm.utils.rag import Eval, TrustworthyRAGScore, get_default_evals
 
-try:
-    from cleanlab_tlm.utils.rag import Eval, TrustworthyRAGScore, get_default_evals
-except ImportError as e:
-    raise MissingDependencyError(
-        import_name=e.name or "cleanlab-tlm",
-        package_url="https://github.com/cleanlab/cleanlab-tlm",
-    ) from e
+from cleanlab_codex.types.validator import ThresholdedTrustworthyRAGScore
 
 if TYPE_CHECKING:
-    from cleanlab_codex.types.validator import ThresholdedTrustworthyRAGScore
     from cleanlab_codex.validator import BadResponseThresholds
 
 
@@ -40,26 +33,21 @@ def get_default_trustworthyrag_config() -> dict[str, Any]:
 
 
 def update_scores_based_on_thresholds(
-    scores: ThresholdedTrustworthyRAGScore, thresholds: BadResponseThresholds
-) -> None:
+    scores: TrustworthyRAGScore | Sequence[TrustworthyRAGScore], thresholds: BadResponseThresholds
+) -> ThresholdedTrustworthyRAGScore:
     """Adds a `is_bad` flag to the scores dictionaries based on the thresholds."""
-    for eval_name, score_dict in scores.items():
-        score_dict.setdefault("is_bad", False)
-        if (score := score_dict["score"]) is not None:
-            score_dict["is_bad"] = score < thresholds.get_threshold(eval_name)
 
+    # Helper function to check if a score is bad
+    def is_bad(score: Optional[float], threshold: float) -> bool:
+        return score is not None and score < threshold
 
-def is_bad_response(
-    scores: TrustworthyRAGScore | ThresholdedTrustworthyRAGScore,
-    thresholds: BadResponseThresholds,
-) -> bool:
-    """
-    Check if the response is bad based on the scores computed by TrustworthyRAG and the config containing thresholds.
-    """
-    for eval_metric, score_dict in scores.items():
-        score = score_dict["score"]
-        if score is None:
-            continue
-        if score < thresholds.get_threshold(eval_metric):
-            return True
-    return False
+    if isinstance(scores, Sequence):
+        raise NotImplementedError("Batching is not supported yet.")
+
+    thresholded_scores = {}
+    for eval_name, score_dict in scores.items():
+        thresholded_scores[eval_name] = {
+            **score_dict,
+            "is_bad": is_bad(score_dict["score"], thresholds.get_threshold(eval_name)),
+        }
+    return cast(ThresholdedTrustworthyRAGScore, thresholded_scores)
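
Not part of the diff, but as orientation for reviewers: a minimal sketch of how the refactored helper is now meant to be called. The plain-dict score shape and the BadResponseThresholds field names (trustworthiness, response_helpfulness) are assumptions inferred from the tests further below, not taken from this commit.

    from cleanlab_codex.internal.validator import update_scores_based_on_thresholds
    from cleanlab_codex.validator import BadResponseThresholds

    # Illustrative scores standing in for a TrustworthyRAG result (assumed shape).
    scores = {
        "trustworthiness": {"score": 0.65},
        "response_helpfulness": {"score": 0.90},
    }
    # Assumed field names on the pydantic model; see BadResponseThresholds below.
    thresholds = BadResponseThresholds(trustworthiness=0.7, response_helpfulness=0.23)

    thresholded = update_scores_based_on_thresholds(scores=scores, thresholds=thresholds)
    # Each entry keeps its original keys and gains an "is_bad" flag, e.g.:
    # {"trustworthiness": {"score": 0.65, "is_bad": True},
    #  "response_helpfulness": {"score": 0.90, "is_bad": False}}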

src/cleanlab_codex/validator.py

Lines changed: 19 additions & 26 deletions
@@ -4,27 +4,20 @@
 
 from __future__ import annotations
 
-from typing import Any, Callable, Optional, cast
+from typing import TYPE_CHECKING, Any, Callable, Optional, cast
 
+from cleanlab_tlm import TrustworthyRAG
 from pydantic import BaseModel, Field, field_validator
 
 from cleanlab_codex.internal.validator import (
     get_default_evaluations,
     get_default_trustworthyrag_config,
 )
-from cleanlab_codex.internal.validator import is_bad_response as _is_bad_response
 from cleanlab_codex.internal.validator import update_scores_based_on_thresholds as _update_scores_based_on_thresholds
 from cleanlab_codex.project import Project
-from cleanlab_codex.types.validator import ThresholdedTrustworthyRAGScore
-from cleanlab_codex.utils.errors import MissingDependencyError
 
-try:
-    from cleanlab_tlm import TrustworthyRAG
-except ImportError as e:
-    raise MissingDependencyError(
-        import_name=e.name or "cleanlab-tlm",
-        package_url="https://github.com/cleanlab/cleanlab-tlm",
-    ) from e
+if TYPE_CHECKING:
+    from cleanlab_codex.types.validator import ThresholdedTrustworthyRAGScore
 
 
 class BadResponseThresholds(BaseModel):
@@ -141,8 +134,8 @@ def validate(
 
         Returns:
             dict[str, Any]: A dictionary containing:
-                - 'is_bad_response': True if the response is flagged as potentially bad, False otherwise.
-                - 'expert_answer': Alternate SME-provided answer from Codex, or None if no answer could be found in the Codex Project.
+                - 'expert_answer': Alternate SME-provided answer from Codex if the response was flagged as bad and an answer was found, or None otherwise.
+                - 'is_bad_response': True if the response is flagged as potentially bad (when True, a lookup in Codex is performed), False otherwise.
                 - Additional keys: Various keys from a [`ThresholdedTrustworthyRAGScore`](/cleanlab_codex/types/validator/#class-thresholdedtrustworthyragscore) dictionary, with raw scores from [TrustworthyRAG](/tlm/api/python/utils.rag/#class-trustworthyrag) for each evaluation metric. `is_bad` indicating whether the score is below the threshold.
         """
         scores, is_bad_response = self.detect(query, context, response, prompt, form_prompt)
@@ -164,7 +157,7 @@ def detect(
         prompt: Optional[str] = None,
         form_prompt: Optional[Callable[[str, str], str]] = None,
     ) -> tuple[ThresholdedTrustworthyRAGScore, bool]:
-        """Evaluate the response quality using TrustworthyRAG and determine if it is a bad response.
+        """Evaluate the response quality using TrustworthyRAG and determine if it is a bad response via thresholding.
 
         Args:
             query (str): The user query that was used to generate the response.
178171
- bool: True if the response is determined to be bad based on the evaluation scores
179172
and configured thresholds, False otherwise.
180173
"""
181-
scores = cast(
182-
ThresholdedTrustworthyRAGScore,
183-
self._tlm_rag.score(
184-
response=response,
185-
query=query,
186-
context=context,
187-
prompt=prompt,
188-
form_prompt=form_prompt,
189-
),
174+
scores = self._tlm_rag.score(
175+
response=response,
176+
query=query,
177+
context=context,
178+
prompt=prompt,
179+
form_prompt=form_prompt,
190180
)
191181

192-
_update_scores_based_on_thresholds(scores, thresholds=self._bad_response_thresholds)
182+
thresholded_scores = _update_scores_based_on_thresholds(
183+
scores=scores,
184+
thresholds=self._bad_response_thresholds,
185+
)
193186

194-
is_bad_response = _is_bad_response(scores, self._bad_response_thresholds)
195-
return scores, is_bad_response
187+
is_bad_response = any(score_dict["is_bad"] for score_dict in thresholded_scores.values())
188+
return thresholded_scores, is_bad_response
196189

197190
def remediate(self, query: str) -> str | None:
198191
"""Request a SME-provided answer for this query, if one is available in Codex.

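For context (not part of this commit): a minimal sketch of the public call path these changes sit behind. The constructor argument name codex_access_key is an assumption, and the query/context/response strings are made up; only the validate()/detect() signatures come from this file.

    from cleanlab_codex.validator import Validator

    validator = Validator(codex_access_key="...")  # assumed constructor argument

    result = validator.validate(
        query="What is the annual fee?",
        context="The Basic card has a $50 annual fee, waived in the first year.",
        response="The Basic card has no annual fee.",
    )

    # Per the updated docstring: 'expert_answer' (if the response was flagged and
    # Codex had an answer), 'is_bad_response', and one thresholded entry per
    # evaluation metric, each carrying its 'is_bad' flag.
    if result["is_bad_response"]:
        print(result["expert_answer"])
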
tests/internal/test_validator.py

Lines changed: 1 addition & 25 deletions
@@ -1,9 +1,8 @@
 from typing import cast
 
-import pytest
 from cleanlab_tlm.utils.rag import TrustworthyRAGScore
 
-from cleanlab_codex.internal.validator import get_default_evaluations, is_bad_response
+from cleanlab_codex.internal.validator import get_default_evaluations
 from cleanlab_codex.validator import BadResponseThresholds
 
 
@@ -28,26 +27,3 @@ def make_is_bad_response_config(trustworthiness: float, response_helpfulness: fl
 
 def test_get_default_evaluations() -> None:
     assert {evaluation.name for evaluation in get_default_evaluations()} == {"response_helpfulness"}
-
-
-class TestIsBadResponse:
-    @pytest.fixture
-    def scores(self) -> TrustworthyRAGScore:
-        return make_scores(0.92, 0.75)
-
-    @pytest.fixture
-    def custom_is_bad_response_config(self) -> BadResponseThresholds:
-        return make_is_bad_response_config(0.6, 0.7)
-
-    def test_thresholds(self, scores: TrustworthyRAGScore) -> None:
-        # High trustworthiness_threshold
-        is_bad_response_config = make_is_bad_response_config(0.921, 0.5)
-        assert is_bad_response(scores, is_bad_response_config)
-
-        # High response_helpfulness_threshold
-        is_bad_response_config = make_is_bad_response_config(0.5, 0.751)
-        assert is_bad_response(scores, is_bad_response_config)
-
-    def test_scores(self, custom_is_bad_response_config: BadResponseThresholds) -> None:
-        scores = make_scores(0.59, 0.7)
-        assert is_bad_response(scores, custom_is_bad_response_config)
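
The removed TestIsBadResponse cases have no replacement in this commit. A rough sketch of equivalent coverage against the refactored helper, assuming the module's existing make_scores and make_is_bad_response_config helpers keep their current signatures and that make_scores keys its result by "trustworthiness" and "response_helpfulness" (both assumptions, outside this hunk):

    from cleanlab_codex.internal.validator import update_scores_based_on_thresholds


    def test_update_scores_based_on_thresholds() -> None:
        # trustworthiness 0.92 < 0.921 -> flagged; response_helpfulness 0.75 >= 0.5 -> not flagged
        scores = make_scores(0.92, 0.75)
        thresholds = make_is_bad_response_config(0.921, 0.5)

        thresholded = update_scores_based_on_thresholds(scores=scores, thresholds=thresholds)

        assert thresholded["trustworthiness"]["is_bad"]
        assert not thresholded["response_helpfulness"]["is_bad"]
        assert any(score_dict["is_bad"] for score_dict in thresholded.values())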
