
Commit 8afb99a

jcpagadora737 authored and copybara-github committed
feat: Support per-eval case and per-invocation rubrics in rubric-based evaluators
Co-authored-by: Joseph Pagadora <[email protected]>
PiperOrigin-RevId: 853820099
1 parent 6887913 commit 8afb99a

File tree

8 files changed: +402 −89 lines changed

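At a high level, rubrics could previously be supplied only on the metric's criterion; this commit also accepts them on an eval case as a whole and on individual expected invocations. The sketch below is illustrative only: `Rubric`, `Invocation`, and `EvalCase` here are simplified stand-ins for the real ADK Pydantic models (the real `Rubric` nests its text under `rubric_content.text_property`), shown just to indicate where the new rubric fields sit.

```python
# Hypothetical, simplified stand-ins for the ADK models, not the real classes.
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class Rubric:
  rubric_id: str
  text: str  # the real model keeps this under rubric_content.text_property


@dataclass
class Invocation:
  user_content: str
  rubrics: Optional[list[Rubric]] = None  # per-invocation rubrics (new scope)


@dataclass
class EvalCase:
  eval_id: str
  conversation: list[Invocation] = field(default_factory=list)
  rubrics: Optional[list[Rubric]] = None  # per-eval-case rubrics (new scope)


# Criterion-level rubrics apply to every case; the new scopes let one case or
# one expected invocation carry extra rubrics of its own.
case = EvalCase(
    eval_id="case_1",
    conversation=[
        Invocation(
            user_content="What is the refund policy?",
            rubrics=[Rubric("inv_r1", "Mentions the 30-day window.")],
        ),
    ],
    rubrics=[Rubric("case_r1", "Response stays on topic.")],
)
```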

src/google/adk/cli/cli_eval.py

Lines changed: 6 additions & 2 deletions
@@ -201,9 +201,11 @@ def pretty_print_eval_result(eval_result: EvalCaseResult):
           for r in metric_result.criterion.rubrics
       }
       for rubric_score in metric_result.details.rubric_scores:
-        rubric = rubrics_by_id.get(rubric_score.rubric_id)
+        rubric_text = rubrics_by_id.get(rubric_score.rubric_id)
+        if not rubric_text:
+          rubric_text = rubric_score.rubric_id
         click.echo(
-            f"Rubric: {rubric}, "
+            f"Rubric: {rubric_text}, "
             f"Score: {rubric_score.score}, "
             f"Reasoning: {rubric_score.rationale}"
         )
@@ -243,6 +245,8 @@ def pretty_print_eval_result(eval_result: EvalCaseResult):
       }
       for rubric_score in metric_result.details.rubric_scores:
         rubric = rubrics_by_id.get(rubric_score.rubric_id)
+        if not rubric:
+          rubric = rubric_score.rubric_id
         row_data[f"Rubric: {rubric}"] = (
             f"Reasoning: {rubric_score.rationale}, "
             f"Score: {rubric_score.score}"

src/google/adk/evaluation/local_eval_service.py

Lines changed: 137 additions & 67 deletions
@@ -46,6 +46,7 @@
 from .eval_metrics import EvalMetricResult
 from .eval_metrics import EvalMetricResultDetails
 from .eval_metrics import EvalMetricResultPerInvocation
+from .eval_metrics import Rubric
 from .eval_result import EvalCaseResult
 from .eval_set import EvalCase
 from .eval_set_results_manager import EvalSetResultsManager
@@ -67,6 +68,46 @@ def _get_session_id() -> str:
   return f'{EVAL_SESSION_ID_PREFIX}{str(uuid.uuid4())}'


+def _add_rubrics_to_invocation(
+    invocation: Invocation, rubrics_to_add: list[Rubric]
+):
+  """Adds rubrics to invocation, throwing ValueError on duplicate rubric_id."""
+  if not invocation.rubrics:
+    invocation.rubrics = []
+  existing_ids = {r.rubric_id for r in invocation.rubrics}
+  for rubric in rubrics_to_add:
+    if rubric.rubric_id in existing_ids:
+      raise ValueError(
+          f"Rubric with rubric_id '{rubric.rubric_id}' already exists."
+      )
+    invocation.rubrics.append(rubric)
+    existing_ids.add(rubric.rubric_id)
+
+
+def _copy_eval_case_rubrics_to_actual_invocations(
+    eval_case: EvalCase, actual_invocations: list[Invocation]
+):
+  """Copies EvalCase level rubrics to all actual invocations."""
+  if hasattr(eval_case, 'rubrics') and eval_case.rubrics:
+    for invocation in actual_invocations:
+      _add_rubrics_to_invocation(invocation, eval_case.rubrics)
+
+
+def _copy_invocation_rubrics_to_actual_invocations(
+    expected_invocations: Optional[list[Invocation]],
+    actual_invocations: list[Invocation],
+):
+  """Copies invocation level rubrics to corresponding actual invocations."""
+  if expected_invocations:
+    for actual_invocation, expected_invocation in zip(
+        actual_invocations, expected_invocations
+    ):
+      if expected_invocation.rubrics:
+        _add_rubrics_to_invocation(
+            actual_invocation, expected_invocation.rubrics
+        )
+
+
 @experimental
 class LocalEvalService(BaseEvalService):
   """An implementation of BaseEvalService, that runs the evals locally."""
@@ -249,76 +290,27 @@ async def _evaluate_single_inference_result(
         )
     )

-    for eval_metric in evaluate_config.eval_metrics:
-      # Perform evaluation of the metric.
-      try:
-        with client_label_context(EVAL_CLIENT_LABEL):
-          evaluation_result = await self._evaluate_metric(
-              eval_metric=eval_metric,
-              actual_invocations=inference_result.inferences,
-              expected_invocations=eval_case.conversation,
-              conversation_scenario=eval_case.conversation_scenario,
-          )
-      except Exception as e:
-        # We intentionally catch the Exception as we don't want failures to
-        # affect other metric evaluation.
-        logger.error(
-            "Metric evaluation failed for metric `%s` for eval case id '%s'"
-            ' with following error `%s`',
-            eval_metric.metric_name,
-            eval_case.eval_id,
-            e,
-            exc_info=True,
-        )
-        # We use an empty result.
-        evaluation_result = EvaluationResult(
-            overall_eval_status=EvalStatus.NOT_EVALUATED
-        )
+    actual_invocations = inference_result.inferences
+    expected_invocations = eval_case.conversation

-      # Track overall score across all invocations.
-      eval_metric_result_details = EvalMetricResultDetails(
-          rubric_scores=evaluation_result.overall_rubric_scores
-      )
-      overall_eval_metric_results.append(
-          EvalMetricResult(
-              score=evaluation_result.overall_score,
-              eval_status=evaluation_result.overall_eval_status,
-              details=eval_metric_result_details,
-              **eval_metric.model_dump(),
-          )
-      )
+    # 1. Copy EvalCase level rubrics to all actual invocations.
+    _copy_eval_case_rubrics_to_actual_invocations(eval_case, actual_invocations)

-      if (
-          evaluation_result.overall_eval_status != EvalStatus.NOT_EVALUATED
-          and len(evaluation_result.per_invocation_results)
-          != len(eval_metric_result_per_invocation)
-      ):
-        raise ValueError(
-            'Eval metric should return results for each invocation. Found '
-            f'{len(evaluation_result.per_invocation_results)} results for '
-            f'{len(eval_metric_result_per_invocation)} invocations.'
-        )
+    # 2. If expected invocations are present, copy invocation level
+    # rubrics to corresponding actual invocations.
+    _copy_invocation_rubrics_to_actual_invocations(
+        expected_invocations, actual_invocations
+    )

-      # Track score across individual invocations.
-      for idx, invocation in enumerate(eval_metric_result_per_invocation):
-        invocation_result = (
-            evaluation_result.per_invocation_results[idx]
-            if evaluation_result.overall_eval_status != EvalStatus.NOT_EVALUATED
-            else PerInvocationResult(
-                actual_invocation=invocation.actual_invocation
-            )
-        )
-        eval_metric_result_details = EvalMetricResultDetails(
-            rubric_scores=invocation_result.rubric_scores
-        )
-        invocation.eval_metric_results.append(
-            EvalMetricResult(
-                score=invocation_result.score,
-                eval_status=invocation_result.eval_status,
-                details=eval_metric_result_details,
-                **eval_metric.model_dump(),
-            )
-        )
+    for eval_metric in evaluate_config.eval_metrics:
+      # Perform evaluation of the metric.
+      await self._evaluate_metric_for_eval_case(
+          eval_metric,
+          eval_case,
+          inference_result,
+          eval_metric_result_per_invocation,
+          overall_eval_metric_results,
+      )

     final_eval_status = self._generate_final_eval_status(
         overall_eval_metric_results
@@ -342,6 +334,84 @@ async def _evaluate_single_inference_result(

     return (inference_result, eval_case_result)

+  async def _evaluate_metric_for_eval_case(
+      self,
+      eval_metric: EvalMetric,
+      eval_case: EvalCase,
+      inference_result: InferenceResult,
+      eval_metric_result_per_invocation: list[EvalMetricResultPerInvocation],
+      overall_eval_metric_results: list[EvalMetricResult],
+  ):
+    """Performs evaluation of a metric for a given eval case and inference result."""
+    try:
+      with client_label_context(EVAL_CLIENT_LABEL):
+        evaluation_result = await self._evaluate_metric(
+            eval_metric=eval_metric,
+            actual_invocations=inference_result.inferences,
+            expected_invocations=eval_case.conversation,
+            conversation_scenario=eval_case.conversation_scenario,
+        )
+    except Exception as e:
+      # We intentionally catch the Exception as we don't want failures to
+      # affect other metric evaluation.
+      logger.error(
+          "Metric evaluation failed for metric `%s` for eval case id '%s'"
+          ' with following error `%s`',
+          eval_metric.metric_name,
+          eval_case.eval_id,
+          e,
+          exc_info=True,
+      )
+      # We use an empty result.
+      evaluation_result = EvaluationResult(
+          overall_eval_status=EvalStatus.NOT_EVALUATED
+      )
+
+    # Track overall score across all invocations.
+    eval_metric_result_details = EvalMetricResultDetails(
+        rubric_scores=evaluation_result.overall_rubric_scores
+    )
+    overall_eval_metric_results.append(
+        EvalMetricResult(
+            score=evaluation_result.overall_score,
+            eval_status=evaluation_result.overall_eval_status,
+            details=eval_metric_result_details,
+            **eval_metric.model_dump(),
+        )
+    )
+
+    if (
+        evaluation_result.overall_eval_status != EvalStatus.NOT_EVALUATED
+        and len(evaluation_result.per_invocation_results)
+        != len(eval_metric_result_per_invocation)
+    ):
+      raise ValueError(
+          'Eval metric should return results for each invocation. Found '
+          f'{len(evaluation_result.per_invocation_results)} results for '
+          f'{len(eval_metric_result_per_invocation)} invocations.'
+      )
+
+    # Track score across individual invocations.
+    for idx, invocation in enumerate(eval_metric_result_per_invocation):
+      invocation_result = (
+          evaluation_result.per_invocation_results[idx]
+          if evaluation_result.overall_eval_status != EvalStatus.NOT_EVALUATED
+          else PerInvocationResult(
+              actual_invocation=invocation.actual_invocation
+          )
+      )
+      eval_metric_result_details = EvalMetricResultDetails(
+          rubric_scores=invocation_result.rubric_scores
+      )
+      invocation.eval_metric_results.append(
+          EvalMetricResult(
+              score=invocation_result.score,
+              eval_status=invocation_result.eval_status,
+              details=eval_metric_result_details,
+              **eval_metric.model_dump(),
+          )
+      )
+
   async def _evaluate_metric(
       self,
       eval_metric: EvalMetric,
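The new module-level helpers above attach rubrics to the actual (inferred) invocations before any metric runs: eval-case rubrics are copied to every actual invocation first, then rubrics from the corresponding expected invocation are added, and a repeated rubric_id raises ValueError. A self-contained sketch of that merge order, using hypothetical minimal types rather than the real ADK models:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Rubric:
  rubric_id: str
  text: str


@dataclass
class Invocation:
  rubrics: Optional[list[Rubric]] = None


def add_rubrics(invocation: Invocation, rubrics_to_add: list[Rubric]) -> None:
  """Appends rubrics, rejecting duplicate rubric_ids (mirrors the helper above)."""
  if not invocation.rubrics:
    invocation.rubrics = []
  existing_ids = {r.rubric_id for r in invocation.rubrics}
  for rubric in rubrics_to_add:
    if rubric.rubric_id in existing_ids:
      raise ValueError(f"Rubric with rubric_id '{rubric.rubric_id}' already exists.")
    invocation.rubrics.append(rubric)
    existing_ids.add(rubric.rubric_id)


actual = Invocation()
case_rubrics = [Rubric("case_r1", "Stays on topic.")]
invocation_rubrics = [Rubric("inv_r1", "Mentions the 30-day window.")]

add_rubrics(actual, case_rubrics)        # 1. eval-case scope first
add_rubrics(actual, invocation_rubrics)  # 2. then the matching expected invocation
print([r.rubric_id for r in actual.rubrics])  # ['case_r1', 'inv_r1']

try:
  add_rubrics(actual, [Rubric("case_r1", "Duplicate id.")])
except ValueError as e:
  print(e)  # a duplicate rubric_id is rejected
```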

src/google/adk/evaluation/rubric_based_evaluator.py

Lines changed: 44 additions & 5 deletions
@@ -328,28 +328,67 @@ def __init__(
     assert self._criterion.rubrics, "Rubrics are required."

     self._rubrics: list[Rubric] = self._criterion.rubrics
+    self._effective_rubrics_list: Optional[list[Rubric]] = None

     self._normalized_rubric_to_id_map = {
         _normalize_text(r.rubric_content.text_property): r.rubric_id
         for r in self._rubrics
     }

+  def create_effective_rubrics_list(
+      self,
+      invocation_rubrics: Optional[list[Rubric]],
+  ) -> None:
+    rubrics_by_id = {}
+
+    def _add_rubrics(rubrics_to_add: list[Rubric], scope_name: str):
+      for r in rubrics_to_add:
+        if r.rubric_id in rubrics_by_id:
+          raise ValueError(
+              f"Rubric with rubric_id '{r.rubric_id}' already exists. Rubric"
+              f" defined in {scope_name} conflicts with an existing rubric."
+          )
+        rubrics_by_id[r.rubric_id] = r
+
+    _add_rubrics(self._rubrics, "criterion")
+
+    if invocation_rubrics:
+      _add_rubrics(invocation_rubrics, "invocation")
+
+    self._effective_rubrics_list = list(rubrics_by_id.values())
+
+  def get_effective_rubrics_list(self) -> list[Rubric]:
+    """Returns the effective rubrics list."""
+    if self._effective_rubrics_list is None:
+      raise ValueError(
+          "Effective rubrics list not initialized. Call"
+          " create_effective_rubrics_list() first."
+      )
+    return self._effective_rubrics_list
+
   @override
   def convert_auto_rater_response_to_score(
-      self, auto_rater_response: LlmResponse
+      self,
+      auto_rater_response: LlmResponse,
   ) -> AutoRaterScore:
     """Returns an AutoRaterScore generated from AutoRater's response."""
     response_text = get_text_from_content(auto_rater_response.content)
     rubric_responses = self._auto_rater_response_parser.parse(response_text)
     rubric_scores = []

+    normalized_rubric_to_rubric_map = {}
+    for r in self.get_effective_rubrics_list():
+      normalized_rubric_to_rubric_map[
+          _normalize_text(r.rubric_content.text_property)
+      ] = r
+
     for rubric_response in rubric_responses:
-      normalized_rubric = _normalize_text(rubric_response.property_text)
-      rubric_id = self._normalized_rubric_to_id_map.get(normalized_rubric, None)
-      if rubric_id:
+      normalized_rubric_text = _normalize_text(rubric_response.property_text)
+      rubric = normalized_rubric_to_rubric_map.get(normalized_rubric_text, None)
+      if rubric:
         rubric_scores.append(
             RubricScore(
-                rubric_id=rubric_id,
+                rubric_id=rubric.rubric_id,
                 rationale=rubric_response.rationale,
                 score=rubric_response.score,
             )
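rubric_based_evaluator.py now builds an "effective" rubric list per invocation (criterion rubrics plus any invocation rubrics, with conflicting rubric_ids rejected) and maps the rubric text echoed by the auto-rater back to a rubric from that list after normalization. A rough standalone sketch of that flow; `normalize_text` here is a simplified stand-in for the module's `_normalize_text`, and plain `(rubric_id, text)` tuples stand in for the real Rubric model:

```python
def normalize_text(text: str) -> str:
  """Simplified stand-in for _normalize_text: lowercase, collapse whitespace."""
  return " ".join(text.lower().split())


def build_effective_rubrics(criterion_rubrics, invocation_rubrics):
  """Merges criterion and invocation rubrics, rejecting duplicate ids."""
  rubrics_by_id = {}
  for scope_name, rubrics in (("criterion", criterion_rubrics),
                              ("invocation", invocation_rubrics or [])):
    for rubric_id, text in rubrics:
      if rubric_id in rubrics_by_id:
        raise ValueError(
            f"Rubric '{rubric_id}' defined in {scope_name} conflicts with an"
            " existing rubric."
        )
      rubrics_by_id[rubric_id] = (rubric_id, text)
  return list(rubrics_by_id.values())


criterion = [("r1", "Final response is polite.")]
per_invocation = [("inv_r1", "Mentions the 30-day window.")]
effective = build_effective_rubrics(criterion, per_invocation)

# The auto-rater echoes rubric text; normalize it to recover the rubric_id.
by_normalized_text = {normalize_text(text): rid for rid, text in effective}
print(by_normalized_text[normalize_text("  Mentions the 30-DAY window. ")])  # inv_r1
```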

src/google/adk/evaluation/rubric_based_final_response_quality_v1.py

Lines changed: 11 additions & 6 deletions
@@ -25,6 +25,7 @@
 from .eval_case import InvocationEvents
 from .eval_metrics import EvalMetric
 from .eval_metrics import RubricsBasedCriterion
+from .eval_rubrics import Rubric
 from .llm_as_judge_utils import get_text_from_content
 from .llm_as_judge_utils import get_tool_calls_and_responses_as_json_str
 from .llm_as_judge_utils import get_tool_declarations_as_json_str
@@ -264,15 +265,19 @@ def __init__(self, eval_metric: EvalMetric):

   @override
   def format_auto_rater_prompt(
-      self, actual_invocation: Invocation, _: Optional[Invocation]
+      self,
+      actual_invocation: Invocation,
+      _: Optional[Invocation],
   ) -> str:
     """Returns the autorater prompt."""
-
+    self.create_effective_rubrics_list(actual_invocation.rubrics)
     user_input = get_text_from_content(actual_invocation.user_content)
     final_response = get_text_from_content(actual_invocation.final_response)
-    rubrics = "\n* ".join(
-        [r.rubric_content.text_property for r in self._rubrics]
-    )
+
+    rubrics_text = "\n".join([
+        f"* {r.rubric_content.text_property}"
+        for r in self._effective_rubrics_list
+    ])

     developer_instructions = ""
     tool_declarations = "Agent has no tools."
@@ -299,7 +304,7 @@ def format_auto_rater_prompt(
         user_input=user_input,
         response_steps=response_steps,
         final_response=final_response,
-        rubrics=rubrics,
+        rubrics=rubrics_text,
     )

     return auto_rater_prompt
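Beyond switching to the effective rubric list, the prompt-formatting change also alters the bullet layout: every rubric line now carries its own "* " prefix, where the old "\n* ".join(...) produced a bullet only between items. A quick comparison of the two joins on plain strings (not the real Rubric objects):

```python
rubric_texts = ["Stays on topic.", "Mentions the 30-day window."]

# Old join: a bullet appears only between items, so the first has none.
old = "\n* ".join(rubric_texts)
# 'Stays on topic.\n* Mentions the 30-day window.'

# New join: every rubric line gets its own bullet.
new = "\n".join([f"* {text}" for text in rubric_texts])
# '* Stays on topic.\n* Mentions the 30-day window.'

print(old)
print()
print(new)
```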
