Skip to content

Commit f0ea520

Browse files
authored
Dataset pipelines (#384)
1 parent c3b841d commit f0ea520

4 files changed

Lines changed: 238 additions & 35 deletions

File tree

py/src/braintrust/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ def is_equal(expected, output):
6262

6363
from .audit import *
6464
from .auto import auto_instrument as auto_instrument
65+
from .dataset_pipeline import *
6566
from .framework import *
6667
from .framework2 import *
6768
from .functions.invoke import *
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
from collections.abc import Awaitable, Sequence
2+
from dataclasses import dataclass
3+
from typing import Any, Generic, Literal, Protocol, TypeAlias, TypeVar
4+
5+
from typing_extensions import NotRequired, TypedDict
6+
7+
from .generated_types import ObjectReference
8+
from .logger import Metadata
9+
from .trace import Trace
10+
11+
12+
# Public names exported by ``from <this module> import *``; the package
# __init__ star-imports this module, so this list controls what is re-exported.
__all__ = [
    "DatasetPipeline",
    "DatasetPipelineDefinition",
    "DatasetPipelineRow",
    "DatasetPipelineScope",
    "DatasetPipelineSource",
    "DatasetPipelineTarget",
    "DatasetPipelineTransform",
    "DatasetPipelineTransformArgs",
    "DatasetPipelineTransformResult",
    "get_registered_dataset_pipelines",
]
24+
25+
26+
# Allowed values for the ``scope`` key of DatasetPipelineSource: "span" or "trace".
DatasetPipelineScope: TypeAlias = Literal["span", "trace"]
27+
28+
29+
class DatasetPipelineSource(TypedDict, total=False):
    """Typed dict describing the source side of a dataset pipeline.

    Declared with ``total=False``, so every key is optional.
    """

    # Project / organization identification (NOTE(review): presumably either
    # id or name is enough to locate the project -- confirm with the consumer).
    project_id: str
    project_name: str
    org_name: str
    # Filter expression as a string; its syntax is not defined in this module.
    filter: str
    # One of the DatasetPipelineScope literals ("span" or "trace").
    scope: DatasetPipelineScope
35+
36+
37+
class DatasetPipelineTarget(TypedDict):
    """Typed dict describing the target side of a dataset pipeline.

    ``dataset_name`` is the only required key; all other keys are marked
    ``NotRequired``.
    """

    # Required.
    dataset_name: str
    # Optional keys.
    project_id: NotRequired[str]
    project_name: NotRequired[str]
    org_name: NotRequired[str]
    description: NotRequired[str]
    metadata: NotRequired[Metadata]
44+
45+
46+
class DatasetPipelineRow(TypedDict, total=False):
    """Typed dict for a single row produced by a pipeline transform.

    Declared with ``total=False``, so every key is optional. Most values may
    also be ``None`` explicitly.
    """

    id: str
    input: Any | None
    expected: Any | None
    tags: Sequence[str] | None
    metadata: Metadata | None
    # NOTE(review): presumably a reference back to the object the row was
    # derived from -- confirm against ObjectReference's definition.
    origin: ObjectReference
53+
54+
55+
# Covariant type variable for row types; any substitution must be a
# DatasetPipelineRow (or a subtype of it).
Row = TypeVar("Row", bound=DatasetPipelineRow, covariant=True)
56+
57+
58+
class DatasetPipelineTransformArgs(TypedDict, total=False):
    """Typed dict whose keys match the parameters of ``DatasetPipelineTransform.__call__``.

    Declared with ``total=False``, so every key is optional.
    """

    input: Any | None
    output: Any | None
    metadata: Metadata | None
    expected: Any | None
    # Unlike the other keys, ``trace`` is not annotated as nullable here.
    trace: Trace
64+
65+
66+
# What a transform may return: a single row, a sequence of rows, or None.
# Generic over Row, so it can be subscripted (e.g. DatasetPipelineTransformResult[Row]).
DatasetPipelineTransformResult: TypeAlias = Row | Sequence[Row] | None
67+
68+
69+
class DatasetPipelineTransform(Protocol[Row]):
    """Structural type for a pipeline transform callable.

    A conforming callable accepts ``input``, ``output``, ``metadata``,
    ``expected`` and ``trace`` (each defaulting to ``None``) and returns
    either a ``DatasetPipelineTransformResult[Row]`` directly or an awaitable
    that resolves to one, so both sync and async callables match.
    """

    def __call__(
        self,
        input: Any | None = None,
        output: Any | None = None,
        metadata: Metadata | None = None,
        expected: Any | None = None,
        trace: Trace | None = None,
    ) -> DatasetPipelineTransformResult[Row] | Awaitable[DatasetPipelineTransformResult[Row]]:
        ...
78+
79+
80+
@dataclass(frozen=True)
class DatasetPipelineDefinition(Generic[Row]):
    """Immutable record bundling a pipeline's source, transform and target.

    The dataclass is frozen, so its attributes cannot be reassigned after
    construction. The ``source`` and ``target`` dicts themselves remain
    ordinary mutable mappings.
    """

    source: DatasetPipelineSource
    transform: DatasetPipelineTransform[Row]
    target: DatasetPipelineTarget
    # Optional pipeline name; defaults to None when not supplied.
    name: str | None = None
86+
87+
88+
# Module-level registry of pipeline definitions; DatasetPipeline() appends to
# it and get_registered_dataset_pipelines() returns a copy of it.
_DATASET_PIPELINES: list[DatasetPipelineDefinition[Any]] = []
89+
90+
91+
def get_registered_dataset_pipelines() -> list[DatasetPipelineDefinition[Any]]:
    """Return a new list holding every pipeline definition registered so far.

    The returned list is a shallow copy of the module registry, so callers
    can mutate it without affecting the registry itself.
    """
    return _DATASET_PIPELINES.copy()
93+
94+
95+
def DatasetPipeline(
    name: str | None = None,
    *,
    source: DatasetPipelineSource,
    transform: DatasetPipelineTransform[DatasetPipelineRow],
    target: DatasetPipelineTarget,
) -> DatasetPipelineDefinition[DatasetPipelineRow]:
    """Build a pipeline definition, register it, and return it.

    Args:
        name: Optional pipeline name.
        source: Source description; a shallow copy is stored.
        transform: Callable stored on the definition as-is.
        target: Target description; a shallow copy is stored.

    Returns:
        The newly created definition, which has also been appended to the
        module-level registry.
    """
    # Copy the dicts so later mutation of the caller's top-level keys does
    # not leak into the stored definition (nested values are still shared).
    source_snapshot = source.copy()
    target_snapshot = target.copy()

    pipeline = DatasetPipelineDefinition(
        source=source_snapshot,
        transform=transform,
        target=target_snapshot,
        name=name,
    )
    _DATASET_PIPELINES.append(pipeline)
    return pipeline

py/src/braintrust/test_trace.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,35 @@ async def fetch_fn(span_type):
4141
assert len(result) == 3
4242
assert {s.span_id for s in result} == {"span-1", "span-2", "span-3"}
4343

44+
@pytest.mark.asyncio
async def test_fetch_preserves_span_result_fields(self):
    """Fetched spans keep expected/error/metrics/scores/tags intact."""
    tool_span = make_span(
        "span-1",
        "tool",
        expected={"answer": "ok"},
        error={"message": "boom"},
        metrics={"start": 1, "end": 2},
        scores={"quality": 0},
        tags=["debug"],
    )
    mock_spans = [tool_span]

    async def fetch_fn(span_type):
        # The span-type filter is irrelevant here; always return the fixture.
        del span_type
        return mock_spans

    fetched = await CachedSpanFetcher(fetch_fn=fetch_fn).get_spans()
    first = fetched[0]

    # Each field must come back exactly as it was supplied to make_span.
    assert first.expected == {"answer": "ok"}
    assert first.error == {"message": "boom"}
    assert first.metrics == {"start": 1, "end": 2}
    assert first.scores == {"quality": 0}
    assert first.tags == ["debug"]
    # The dict form must carry the error through as well.
    assert first.to_dict()["error"] == {"message": "boom"}
72+
4473
@pytest.mark.asyncio
4574
async def test_fetch_specific_span_types(self):
4675
"""Test fetching specific span types when filter specified."""

0 commit comments

Comments
 (0)