Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 2 additions & 7 deletions py/src/braintrust/dataset_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from typing_extensions import NotRequired, TypedDict

from .generated_types import ObjectReference
from .logger import Metadata
from .trace import Trace

Expand All @@ -19,7 +18,6 @@
"DatasetPipelineTransform",
"DatasetPipelineTransformArgs",
"DatasetPipelineTransformResult",
"get_registered_dataset_pipelines",
]


Expand Down Expand Up @@ -49,13 +47,13 @@ class DatasetPipelineRow(TypedDict, total=False):
expected: Any | None
tags: Sequence[str] | None
metadata: Metadata | None
origin: ObjectReference


Row = TypeVar("Row", bound=DatasetPipelineRow, covariant=True)


class DatasetPipelineTransformArgs(TypedDict, total=False):
id: str
input: Any | None
output: Any | None
metadata: Metadata | None
Expand All @@ -69,6 +67,7 @@ class DatasetPipelineTransformArgs(TypedDict, total=False):
class DatasetPipelineTransform(Protocol[Row]):
def __call__(
self,
id: str | None = None,
input: Any | None = None,
output: Any | None = None,
metadata: Metadata | None = None,
Expand All @@ -88,10 +87,6 @@ class DatasetPipelineDefinition(Generic[Row]):
_DATASET_PIPELINES: list[DatasetPipelineDefinition[Any]] = []


def get_registered_dataset_pipelines() -> list[DatasetPipelineDefinition[Any]]:
return list(_DATASET_PIPELINES)


def DatasetPipeline(
name: str | None = None,
*,
Expand Down
Loading