diff --git a/pydantic_evals/pydantic_evals/dataset.py b/pydantic_evals/pydantic_evals/dataset.py
index 2cc051e4e..da4f0ee01 100644
--- a/pydantic_evals/pydantic_evals/dataset.py
+++ b/pydantic_evals/pydantic_evals/dataset.py
@@ -252,7 +252,7 @@ def __init__(
 
     async def evaluate(
         self, task: Callable[[InputsT], Awaitable[OutputT]], name: str | None = None, max_concurrency: int | None = None
-    ) -> EvaluationReport:
+    ) -> EvaluationReport[InputsT, OutputT, MetadataT]:
         """Evaluates the test cases in the dataset using the given task.
 
         This method runs the task on each case in the dataset, applies evaluators,
@@ -296,7 +296,7 @@ async def _handle_case(case: Case[InputsT, OutputT, MetadataT], report_case_name
 
     def evaluate_sync(
         self, task: Callable[[InputsT], Awaitable[OutputT]], name: str | None = None, max_concurrency: int | None = None
-    ) -> EvaluationReport:
+    ) -> EvaluationReport[InputsT, OutputT, MetadataT]:
         """Evaluates the test cases in the dataset using the given task.
 
         This is a synchronous wrapper around [`evaluate`][pydantic_evals.Dataset.evaluate] provided for convenience.
@@ -858,7 +858,7 @@ async def _run_task_and_evaluators(
     case: Case[InputsT, OutputT, MetadataT],
     report_case_name: str,
    dataset_evaluators: list[Evaluator[InputsT, OutputT, MetadataT]],
-) -> ReportCase:
+) -> ReportCase[InputsT, OutputT, MetadataT]:
     """Run a task on a case and evaluate the results.
 
     Args:
@@ -908,7 +908,7 @@ async def _run_task_and_evaluators(
         span_id = f'{context.span_id:016x}'
     fallback_duration = time.time() - t0
 
-    return ReportCase(
+    return ReportCase[InputsT, OutputT, MetadataT](
         name=report_case_name,
         inputs=case.inputs,
         metadata=case.metadata,
diff --git a/pydantic_evals/pydantic_evals/reporting/__init__.py b/pydantic_evals/pydantic_evals/reporting/__init__.py
index 025b50ea9..65dae8144 100644
--- a/pydantic_evals/pydantic_evals/reporting/__init__.py
+++ b/pydantic_evals/pydantic_evals/reporting/__init__.py
@@ -2,14 +2,14 @@
 
 from collections import defaultdict
 from collections.abc import Mapping
-from dataclasses import dataclass, field
+from dataclasses import dataclass
 from io import StringIO
-from typing import Any, Callable, Literal, Protocol, TypeVar
+from typing import Any, Callable, Generic, Literal, Protocol
 
-from pydantic import BaseModel
+from pydantic import BaseModel, TypeAdapter
 from rich.console import Console
 from rich.table import Table
-from typing_extensions import TypedDict
+from typing_extensions import TypedDict, TypeVar
 
 from pydantic_evals._utils import UNSET, Unset
 
@@ -24,7 +24,9 @@
 
 __all__ = (
     'EvaluationReport',
+    'EvaluationReportAdapter',
     'ReportCase',
+    'ReportCaseAdapter',
     'EvaluationRenderer',
     'RenderValueConfig',
     'RenderNumberConfig',
@@ -35,27 +37,32 @@
 EMPTY_CELL_STR = '-'
 EMPTY_AGGREGATE_CELL_STR = ''
 
+InputsT = TypeVar('InputsT', default=Any)
+OutputT = TypeVar('OutputT', default=Any)
+MetadataT = TypeVar('MetadataT', default=Any)
 
-class ReportCase(BaseModel):
+
+@dataclass
+class ReportCase(Generic[InputsT, OutputT, MetadataT]):
     """A single case in an evaluation report."""
 
     name: str
     """The name of the [case][pydantic_evals.Case]."""
-    inputs: Any
+    inputs: InputsT
     """The inputs to the task, from [`Case.inputs`][pydantic_evals.Case.inputs]."""
-    metadata: Any
+    metadata: MetadataT | None
     """Any metadata associated with the case, from [`Case.metadata`][pydantic_evals.Case.metadata]."""
-    expected_output: Any
+    expected_output: OutputT | None
     """The expected output of the task, from [`Case.expected_output`][pydantic_evals.Case.expected_output]."""
-    output: Any
+    output: OutputT
     """The output of the task execution."""
 
     metrics: dict[str, float | int]
     attributes: dict[str, Any]
 
-    scores: dict[str, EvaluationResult[int | float]] = field(init=False)
-    labels: dict[str, EvaluationResult[str]] = field(init=False)
-    assertions: dict[str, EvaluationResult[bool]] = field(init=False)
+    scores: dict[str, EvaluationResult[int | float]]
+    labels: dict[str, EvaluationResult[str]]
+    assertions: dict[str, EvaluationResult[bool]]
 
     task_duration: float
     total_duration: float  # includes evaluator execution time
@@ -65,6 +72,9 @@ class ReportCase(BaseModel):
     span_id: str
 
 
+ReportCaseAdapter = TypeAdapter(ReportCase[Any, Any, Any])
+
+
 class ReportCaseAggregate(BaseModel):
     """A synthetic case that summarizes a set of cases."""
 
@@ -142,12 +152,13 @@ def _labels_averages(labels_by_name: list[dict[str, str]]) -> dict[str, dict[str
     )
 
 
-class EvaluationReport(BaseModel):
+@dataclass
+class EvaluationReport(Generic[InputsT, OutputT, MetadataT]):
     """A report of the results of evaluating a model on a set of cases."""
 
     name: str
     """The name of the report."""
-    cases: list[ReportCase]
+    cases: list[ReportCase[InputsT, OutputT, MetadataT]]
     """The cases in the report."""
 
     def averages(self) -> ReportCaseAggregate:
@@ -156,7 +167,7 @@ def averages(self) -> ReportCaseAggregate:
     def print(
         self,
         width: int | None = None,
-        baseline: EvaluationReport | None = None,
+        baseline: EvaluationReport[InputsT, OutputT, MetadataT] | None = None,
         include_input: bool = False,
         include_metadata: bool = False,
         include_expected_output: bool = False,
@@ -199,7 +210,7 @@ def print(
 
     def console_table(
         self,
-        baseline: EvaluationReport | None = None,
+        baseline: EvaluationReport[InputsT, OutputT, MetadataT] | None = None,
         include_input: bool = False,
         include_metadata: bool = False,
         include_expected_output: bool = False,
@@ -250,6 +261,9 @@ def __str__(self) -> str:  # pragma: lax no cover
         return io_file.getvalue()
 
 
+EvaluationReportAdapter = TypeAdapter(EvaluationReport[Any, Any, Any])
+
+
 class RenderValueConfig(TypedDict, total=False):
     """A configuration for rendering a values in an Evaluation report."""
 
diff --git a/tests/evals/test_dataset.py b/tests/evals/test_dataset.py
index 2cdefe603..003f2966f 100644
--- a/tests/evals/test_dataset.py
+++ b/tests/evals/test_dataset.py
@@ -33,7 +33,7 @@ class MockEvaluator(Evaluator[object, object, object]):
         def evaluate(self, ctx: EvaluatorContext[object, object, object]) -> EvaluatorOutput:
             return self.output
 
-    from pydantic_evals.reporting import ReportCase
+    from pydantic_evals.reporting import ReportCase, ReportCaseAdapter
 
 
 pytestmark = [pytest.mark.skipif(not imports_successful(), reason='pydantic-evals not installed'), pytest.mark.anyio]
@@ -196,7 +196,7 @@ async def mock_task(inputs: TaskInput) -> TaskOutput:
 
     assert report is not None
     assert len(report.cases) == 2
-    assert report.cases[0].model_dump() == snapshot(
+    assert ReportCaseAdapter.dump_python(report.cases[0]) == snapshot(
         {
             'assertions': {
                 'correct': {
@@ -248,7 +248,7 @@ async def mock_task(inputs: TaskInput) -> TaskOutput:
 
     assert report is not None
     assert len(report.cases) == 2
-    assert report.cases[0].model_dump() == snapshot(
+    assert ReportCaseAdapter.dump_python(report.cases[0]) == snapshot(
         {
             'assertions': {
                 'correct': {
diff --git a/tests/evals/test_reports.py b/tests/evals/test_reports.py
index 6ce97e427..9f7ba8a9e 100644
--- a/tests/evals/test_reports.py
+++ b/tests/evals/test_reports.py
@@ -12,9 +12,11 @@
     from pydantic_evals.evaluators import EvaluationResult, Evaluator, EvaluatorContext
     from pydantic_evals.reporting import (
         EvaluationReport,
+        EvaluationReportAdapter,
         RenderNumberConfig,
         RenderValueConfig,
         ReportCase,
+        ReportCaseAdapter,
         ReportCaseAggregate,
     )
 
@@ -157,7 +159,7 @@ async def test_report_case_aggregate():
 async def test_report_serialization(sample_report: EvaluationReport):
     """Test serializing a report to dict."""
     # Serialize the report
-    serialized = sample_report.model_dump()
+    serialized = EvaluationReportAdapter.dump_python(sample_report)
 
     # Check the serialized structure
     assert 'cases' in serialized
@@ -202,7 +204,7 @@ async def test_report_with_error(mock_evaluator: Evaluator[TaskInput, TaskOutput
         name='error_report',
     )
 
-    assert report.cases[0].model_dump() == snapshot(
+    assert ReportCaseAdapter.dump_python(report.cases[0]) == snapshot(
         {
             'assertions': {
                 'error_evaluator': {
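For reference, a minimal usage sketch of the change above (illustrative only, not part of the diff; it assumes the public pydantic_evals Case/Dataset API and an installed pydantic_evals): with ReportCase and EvaluationReport now generic dataclasses rather than BaseModel subclasses, serialization goes through the new TypeAdapter objects instead of model_dump().

    from pydantic_evals import Case, Dataset
    from pydantic_evals.reporting import EvaluationReportAdapter, ReportCaseAdapter


    async def double(x: int) -> int:
        # Hypothetical task used only for illustration.
        return x * 2


    # A tiny illustrative dataset; the case name and values are arbitrary.
    dataset = Dataset(cases=[Case(name='one', inputs=1, expected_output=2)])
    report = dataset.evaluate_sync(double)

    report_dict = EvaluationReportAdapter.dump_python(report)   # was report.model_dump()
    case_dict = ReportCaseAdapter.dump_python(report.cases[0])  # was report.cases[0].model_dump()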