diff --git a/.gitignore b/.gitignore index a87404b8..984c9820 100644 --- a/.gitignore +++ b/.gitignore @@ -19,6 +19,9 @@ Thumbs.db node_modules/ tests/evals/js/eval-bun/test-data.txt +# Python +.venv/ + .bt/ # Agents diff --git a/README.md b/README.md index b912f5d6..932c72bf 100644 --- a/README.md +++ b/README.md @@ -114,6 +114,7 @@ Remove-Item -Recurse -Force (Join-Path $env:APPDATA "bt") -ErrorAction SilentlyC | `bt auth` | Authenticate with Braintrust | | `bt switch` | Switch org and project context | | `bt status` | Show current org and project context | +| `bt datasets` | Manage datasets and dataset pipelines | | `bt eval` | Run eval files (Unix only) | | `bt sql` | Run SQL queries against Braintrust | | `bt view` | View logs, traces, and spans | @@ -177,6 +178,39 @@ bt eval foo.eval.ts -- --description "Prod" --shard=1/4 - Accepted top-level record fields are `id`, `input`, `expected`, `metadata`, `tags`, and `origin` (plus the root field referenced by `--id-field`, if different). - Inputs may also be a JSON object with a top-level `rows` array, matching `bt datasets view --json`; sibling wrapper fields are ignored, and each row inside `rows` is still validated strictly. +### `bt datasets pipeline` + +Run full dataset pipelines declared with `DatasetPipeline(...)`, or run the pull, transform, and push stages individually. + +```bash +# One-shot execution: discover refs, transform, and insert up to 100 new rows. +bt datasets pipeline run ./pipeline.ts --limit 100 + +# Staged execution for inspection or agent editing. +bt datasets pipeline pull ./pipeline.ts --limit 500 +bt datasets pipeline transform ./pipeline.ts +# Inspect or edit the transformed JSONL, then push to the pipeline target. +bt datasets pipeline push ./pipeline.ts + +# Python pipelines are supported too. +bt datasets pipeline run ./pipeline.py --project "<name>" --limit 100 +``` + +Useful flags: + +- `--limit <n>` controls how many source refs to discover. 
+- `--window <duration>` constrains source ref discovery by `created` time; defaults to `1d`. +- `--root-span-id <id>` restricts pulling to one or more specific root spans. +- `--root <dir>` controls where staged artifacts are written; it defaults to `bt-sync`. A staged run writes `pulled.jsonl` and `transformed.jsonl` in the same managed directory. +- `--out` can override the managed output path for `pull` and `transform`. +- `--in` can override the latest pull artifact for `transform`, or the latest transform artifact for `push`. +- `push` reads the target from the pipeline and delegates to `bt sync push`; pass `--fresh` to restart an already completed push spec. +- `--project <name>` supplies the active source project when the pipeline source omits a project. +- `--source-project`, `--source-project-id`, `--source-org`, and `--source-filter` explicitly override source fields on `pull`, `transform`, and `run`. +- `--target-project`, `--target-project-id`, `--target-org`, and `--target-dataset` override target fields on `run` and `push`. +- `--max-concurrency <n>` controls transform concurrency. +- `--name <name>` selects a pipeline when the file defines more than one. + ## `bt sql` - Runs interactively on TTY by default. diff --git a/scripts/dataset-pipeline-runner.py b/scripts/dataset-pipeline-runner.py new file mode 100644 index 00000000..4d157d52 --- /dev/null +++ b/scripts/dataset-pipeline-runner.py @@ -0,0 +1,556 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import asyncio +import importlib.util +import json +import os +import socket +import sys +import traceback +import uuid +from pathlib import Path +from typing import Any + +try: + import braintrust + from braintrust.framework import call_user_fn + from braintrust.logger import _internal_get_global_state, login_to_state + from braintrust.trace import LocalTrace +except Exception as exc: # pragma: no cover - runtime guard + print( + "Unable to import the braintrust package. 
Please install it in your Python environment.", + file=sys.stderr, + ) + print(str(exc), file=sys.stderr) + sys.exit(1) + + +SOURCE_KEY_MAP = { + "project_id": "projectId", + "project_name": "projectName", + "org_name": "orgName", +} +TARGET_KEY_MAP = { + "project_id": "projectId", + "project_name": "projectName", + "org_name": "orgName", + "dataset_name": "datasetName", +} + +_DEFERRED_ATTACHMENT_DIR: Path | None = None + + +class DeferredJSONAttachment: + def __init__( + self, + data: Any, + *, + filename: str = "data.json", + pretty: bool = False, + ) -> None: + self._reference = deferred_json_attachment_reference(data, filename, pretty) + + @property + def reference(self) -> dict[str, Any]: + return self._reference + + def upload(self) -> dict[str, Any]: + return {"upload_status": "done", "deferred": True} + + def debug_info(self) -> dict[str, Any]: + return {"reference": self._reference} + + +def set_deferred_attachment_dir(path: str | None) -> None: + global _DEFERRED_ATTACHMENT_DIR + _DEFERRED_ATTACHMENT_DIR = Path(path).resolve() if path else None + if _DEFERRED_ATTACHMENT_DIR is not None: + _DEFERRED_ATTACHMENT_DIR.mkdir(parents=True, exist_ok=True) + + +def deferred_json_attachment_reference( + data: Any, + filename: str, + pretty: bool, +) -> dict[str, Any]: + serialized = json.dumps(data, indent=2 if pretty else None, ensure_ascii=False) + marker: dict[str, Any] = { + "type": "braintrust_deferred_attachment", + "kind": "json", + "filename": filename, + "content_type": "application/json", + } + if _DEFERRED_ATTACHMENT_DIR is None: + marker["data"] = data + marker["pretty"] = pretty + return marker + + path = _DEFERRED_ATTACHMENT_DIR / f"{uuid.uuid4()}.json" + path.write_text(serialized, encoding="utf-8") + marker["path"] = str(path) + return marker + + +def install_deferred_attachment_shims() -> None: + braintrust.JSONAttachment = DeferredJSONAttachment + import braintrust.logger as logger + + logger.JSONAttachment = DeferredJSONAttachment + + +def 
normalize_deferred_attachments(value: Any) -> Any: + if isinstance(value, DeferredJSONAttachment): + return value.reference + if isinstance(value, dict): + return { + key: normalize_deferred_attachments(item) + for key, item in value.items() + } + if isinstance(value, (list, tuple)): + return [normalize_deferred_attachments(item) for item in value] + return value + + +class SseWriter: + def __init__(self, sock: socket.socket): + self._socket = sock + + def send(self, event: str, payload: Any) -> None: + data = payload if isinstance(payload, str) else json.dumps(payload, separators=(",", ":")) + frame = f"event: {event}\ndata: {data}\n\n".encode("utf-8") + self._socket.sendall(frame) + + def close(self) -> None: + self._socket.close() + + +def create_sse_writer() -> SseWriter | None: + sock_path = os.getenv("BT_DATASET_PIPELINE_SSE_SOCK") + if not sock_path: + addr = os.getenv("BT_DATASET_PIPELINE_SSE_ADDR") + if not addr: + return None + if ":" not in addr: + raise ValueError("BT_DATASET_PIPELINE_SSE_ADDR must be in host:port format") + host, port_str = addr.rsplit(":", 1) + sock = socket.create_connection((host, int(port_str))) + return SseWriter(sock) + try: + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + sock.connect(sock_path) + return SseWriter(sock) + except Exception as exc: + print(f"Failed to connect to dataset pipeline socket: {exc}", file=sys.stderr) + return None + + +def camelize_mapping(value: Any, key_map: dict[str, str]) -> Any: + if not isinstance(value, dict): + return value + return { + key_map.get(key, key): camelize_mapping(item, key_map) + for key, item in value.items() + } + + +def object_get(value: Any, name: str) -> Any: + if isinstance(value, dict): + return value.get(name) + return getattr(value, name, None) + + +def pipeline_source(pipeline: Any) -> dict[str, Any]: + source = object_get(pipeline, "source") + if not isinstance(source, dict): + raise RuntimeError("Dataset pipeline source is required.") + return source + + +def 
pipeline_transform(pipeline: Any) -> Any: + transform = object_get(pipeline, "transform") + if not callable(transform): + raise RuntimeError("Dataset pipeline transform must be callable.") + return transform + + +def load_pipeline_file(file: str) -> Any: + absolute = os.path.abspath(file) + cwd = os.getcwd() + file_dir = os.path.dirname(absolute) + for path in (file_dir, cwd): + if path and path not in sys.path: + sys.path.insert(0, path) + + module_name = f"_bt_dataset_pipeline_{abs(hash(absolute))}" + spec = importlib.util.spec_from_file_location(module_name, absolute) + if spec is None or spec.loader is None: + raise RuntimeError(f"Unable to load {file}.") + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + return module + + +def is_pipeline(value: Any) -> bool: + checker = getattr(braintrust, "is_dataset_pipeline_definition", None) + if callable(checker) and checker(value): + return True + return ( + object_get(value, "source") is not None + and object_get(value, "target") is not None + and callable(object_get(value, "transform")) + ) + + +def collect_pipelines(module: Any) -> list[Any]: + pipelines: list[Any] = [] + seen: set[int] = set() + + registered = getattr(braintrust, "get_registered_dataset_pipelines", None) + if callable(registered): + for pipeline in registered(): + if id(pipeline) not in seen: + seen.add(id(pipeline)) + pipelines.append(pipeline) + + for value in vars(module).values(): + if is_pipeline(value) and id(value) not in seen: + seen.add(id(value)) + pipelines.append(value) + + if is_pipeline(module) and id(module) not in seen: + pipelines.append(module) + + return pipelines + + +def select_pipeline(pipelines: list[Any], name: str | None) -> Any: + if name: + matches = [ + pipeline + for pipeline in pipelines + if object_get(pipeline, "name") == name + ] + if not matches: + raise RuntimeError(f"No dataset pipeline named {json.dumps(name)} found.") + if len(matches) > 1: + 
raise RuntimeError( + f"Multiple dataset pipelines named {json.dumps(name)} found." + ) + return matches[0] + + if not pipelines: + raise RuntimeError("No dataset pipelines found. Did you call DatasetPipeline()?") + if len(pipelines) > 1: + names = ", ".join( + object_get(pipeline, "name") or "" for pipeline in pipelines + ) + raise RuntimeError(f"Multiple dataset pipelines found ({names}). Pass --name.") + return pipelines[0] + + +def parse_stage() -> str: + stage = os.getenv("BT_DATASET_PIPELINE_STAGE") + if stage in {"inspect", "transform"}: + return stage + raise RuntimeError("BT_DATASET_PIPELINE_STAGE must be inspect or transform.") + + +def read_request() -> dict[str, Any]: + text = sys.stdin.read().strip() + if not text: + return {} + value = json.loads(text) + if not isinstance(value, dict): + raise RuntimeError("Dataset pipeline runner request must be an object.") + return value + + +def write_response(value: Any, sse: SseWriter | None) -> None: + if sse is not None: + sse.send("response", value) + sse.close() + else: + print(json.dumps(value, separators=(",", ":"))) + + +def write_progress(sse: SseWriter | None, rows: int) -> None: + if sse is None: + return + sse.send( + "progress", + { + "type": "dataset_pipeline_progress", + "kind": "candidate", + "rows": rows, + }, + ) + + +def require_array_field(request: dict[str, Any], field: str) -> list[Any]: + value = request.get(field) + if not isinstance(value, list): + raise RuntimeError(f"Request field {field} must be an array.") + return value + + +def require_string_field(request: dict[str, Any], field: str) -> str: + value = request.get(field) + if not isinstance(value, str): + raise RuntimeError(f"Request field {field} must be a string.") + return value + + +def optional_positive_integer_field(request: dict[str, Any], field: str) -> int | None: + value = request.get(field) + if value is None: + return None + if not isinstance(value, int) or value <= 0: + raise RuntimeError(f"Request field {field} must be 
a positive integer.") + return value + + +def set_optional_env(name: str, value: Any) -> None: + if isinstance(value, str) and value: + os.environ[name] = value + else: + os.environ.pop(name, None) + + +def merged_source(pipeline: Any, source_override: Any) -> dict[str, Any]: + source = camelize_mapping(pipeline_source(pipeline), SOURCE_KEY_MAP) + if isinstance(source_override, dict): + return {**source, **source_override} + return source + + +def state_for_org(org_name: str | None) -> Any: + state = _internal_get_global_state() + if not org_name: + state.login() + return state + if not getattr(state, "logged_in", False): + state.login(org_name=org_name) + return state + if getattr(state, "org_name", None) == org_name: + return state + return login_to_state(org_name=org_name) + + +def ref_root_span_id(ref: Any) -> str: + if not isinstance(ref, dict) or not isinstance(ref.get("root_span_id"), str): + raise RuntimeError("Discovery ref is missing root_span_id.") + return ref["root_span_id"] + + +def ref_span_row_id(ref: Any) -> str | None: + if isinstance(ref, dict) and isinstance(ref.get("id"), str): + return ref["id"] + return None + + +def hydrate_discovery_refs( + pipeline: Any, + source_override: Any, + source_project_id: str, + refs: list[Any], +) -> list[dict[str, Any]]: + source = merged_source(pipeline, source_override) + state = state_for_org(source.get("orgName")) + candidates: list[dict[str, Any]] = [] + traces_by_root_span_id: dict[str, LocalTrace] = {} + for ref in refs: + root_span_id = ref_root_span_id(ref) + row_id = ref_span_row_id(ref) + trace = traces_by_root_span_id.get(root_span_id) + if trace is None: + trace = LocalTrace( + object_type="project_logs", + object_id=source_project_id, + root_span_id=root_span_id, + ensure_spans_flushed=None, + state=state, + ) + traces_by_root_span_id[root_span_id] = trace + candidate: dict[str, Any] = { + "trace": trace, + } + origin = ref.get("origin") if isinstance(ref, dict) else None + if isinstance(origin, 
dict): + candidate["origin"] = origin + if row_id: + candidate["id"] = row_id + candidates.append(candidate) + return candidates + + +def span_attr(span: Any, name: str) -> Any: + if isinstance(span, dict): + return span.get(name) + return getattr(span, name, None) + + +async def source_row_for_candidate(candidate: dict[str, Any]) -> Any | None: + row_id = candidate.get("id") + if not isinstance(row_id, str): + return None + + trace = candidate["trace"] + spans = await trace.get_spans(include_scorers=True) + for span in spans: + if row_id in {span_attr(span, "id"), span_attr(span, "span_id")}: + return span + raise RuntimeError(f"Source span row {row_id!r} was not found in hydrated trace.") + + +async def transform_args_for_candidate(candidate: dict[str, Any]) -> dict[str, Any]: + row = await source_row_for_candidate(candidate) + return { + "input": span_attr(row, "input"), + "output": span_attr(row, "output"), + "expected": span_attr(row, "expected"), + "metadata": span_attr(row, "metadata"), + "trace": candidate["trace"], + } + + +def normalize_transform_result(result: Any) -> list[Any]: + if result is None: + return [] + if isinstance(result, list): + return result + return [result] + + +def candidate_fallback_id(candidate: dict[str, Any]) -> str | None: + row_id = candidate.get("id") + if isinstance(row_id, str): + return row_id + trace = candidate.get("trace") + config = trace.get_configuration() if hasattr(trace, "get_configuration") else None + if isinstance(config, dict) and isinstance(config.get("root_span_id"), str): + return config["root_span_id"] + return None + + +def with_pipeline_defaults( + row: Any, + candidate: dict[str, Any], + row_index: int | None, +) -> dict[str, Any]: + row = normalize_deferred_attachments(row) + if not isinstance(row, dict): + raise RuntimeError("Dataset pipeline transform must return an object row.") + output = dict(row) + fallback_id = candidate_fallback_id(candidate) + if "id" not in output and fallback_id: + output["id"] 
= fallback_id if row_index is None else f"{fallback_id}:{row_index}" + if "origin" not in output and "origin" in candidate: + output["origin"] = candidate["origin"] + return output + + +async def transform_refs( + pipeline: Any, + source_override: Any, + source_project_id: str, + refs: list[Any], + max_concurrency: int = 16, + sse: SseWriter | None = None, +) -> list[dict[str, Any]]: + if max_concurrency <= 0: + raise RuntimeError("maxConcurrency must be a positive integer.") + transform = pipeline_transform(pipeline) + candidates = hydrate_discovery_refs(pipeline, source_override, source_project_id, refs) + transformed_rows: list[list[dict[str, Any]]] = [[] for _ in candidates] + semaphore = asyncio.Semaphore(max_concurrency) + + async def run_one(index: int, candidate: dict[str, Any]) -> None: + async with semaphore: + transform_args = await transform_args_for_candidate(candidate) + result = await call_user_fn( + asyncio.get_running_loop(), + transform, + **transform_args, + ) + rows = normalize_transform_result(result) + transformed_rows[index] = [ + with_pipeline_defaults( + row, + candidate, + row_index if len(rows) > 1 else None, + ) + for row_index, row in enumerate(rows) + ] + write_progress(sse, len(transformed_rows[index])) + + await asyncio.gather( + *(run_one(index, candidate) for index, candidate in enumerate(candidates)) + ) + return [row for rows in transformed_rows for row in rows] + + +async def main() -> None: + if len(sys.argv) < 2: + raise RuntimeError("Pipeline file is required.") + + stage = parse_stage() + if stage == "transform": + install_deferred_attachment_shims() + + module = load_pipeline_file(sys.argv[1]) + pipeline = select_pipeline( + collect_pipelines(module), + os.getenv("BT_DATASET_PIPELINE_NAME") or None, + ) + sse = create_sse_writer() + + if stage == "inspect": + write_response( + { + "name": object_get(pipeline, "name"), + "source": camelize_mapping(object_get(pipeline, "source"), SOURCE_KEY_MAP), + "target": 
camelize_mapping(object_get(pipeline, "target"), TARGET_KEY_MAP), + }, + sse, + ) + elif stage == "transform": + request = read_request() + attachment_dir = request.get("attachmentDir") + if attachment_dir is not None and not isinstance(attachment_dir, str): + raise RuntimeError("Request field attachmentDir must be a string.") + set_deferred_attachment_dir(attachment_dir) + refs = require_array_field(request, "refs") + source_project_id = require_string_field(request, "sourceProjectId") + source_override = ( + request.get("source") if isinstance(request.get("source"), dict) else None + ) + source_for_env = ( + source_override + if isinstance(source_override, dict) + else camelize_mapping(object_get(pipeline, "source"), SOURCE_KEY_MAP) + ) + set_optional_env( + "BT_DATASET_PIPELINE_SOURCE_ORG_NAME", + source_for_env.get("orgName") if isinstance(source_for_env, dict) else None, + ) + rows = await transform_refs( + pipeline, + source_override, + source_project_id, + refs, + optional_positive_integer_field(request, "maxConcurrency") or 16, + sse, + ) + write_response({"candidates": len(refs), "rowCount": len(rows), "rows": rows}, sse) + else: + raise RuntimeError(f"Unsupported dataset pipeline stage: {stage}") + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except Exception: + traceback.print_exc(file=sys.stderr) + sys.exit(1) diff --git a/scripts/dataset-pipeline-runner.ts b/scripts/dataset-pipeline-runner.ts new file mode 100644 index 00000000..ce779554 --- /dev/null +++ b/scripts/dataset-pipeline-runner.ts @@ -0,0 +1,803 @@ +import { createRequire } from "node:module"; +import { randomUUID } from "node:crypto"; +import fs from "node:fs"; +import net from "node:net"; +import { pathToFileURL } from "node:url"; +import path from "node:path"; + +type PipelineSource = { + projectName?: string; + projectId?: string; + orgName?: string; + filter?: string; + scope?: "span" | "trace"; +}; + +type PipelineTarget = { + projectName?: string; + projectId?: string; 
+ orgName?: string; + datasetName?: string; + description?: string; + metadata?: Record; +}; + +type DatasetPipelineDefinition = { + name?: string; + source?: PipelineSource; + target?: PipelineTarget; + transform?: ( + args: DatasetPipelineTransformArgs, + ) => unknown | Promise; +}; + +type DatasetPipelineTransformArgs = { + input?: unknown; + output?: unknown; + expected?: unknown; + metadata?: unknown; + trace: unknown; +}; + +type BraintrustModule = { + DatasetPipeline?: ( + definition: DatasetPipelineDefinition, + ) => DatasetPipelineDefinition; + getRegisteredDatasetPipelines?: () => DatasetPipelineDefinition[]; + isDatasetPipelineDefinition?: ( + value: unknown, + ) => value is DatasetPipelineDefinition; + LocalTrace?: new (options: { + objectType: "project_logs"; + objectId: string; + rootSpanId: string; + state: unknown; + }) => unknown; + _internalGetGlobalState?: () => BraintrustState; + loginToState?: (options: { orgName: string }) => Promise; + JSONAttachment?: new ( + data: unknown, + options?: { filename?: string; pretty?: boolean }, + ) => unknown; + default?: BraintrustModule; +}; + +type BraintrustState = { + loggedIn?: boolean; + orgName?: string; + login: (options: Record) => Promise; +}; + +type DiscoveryRef = { + root_span_id?: unknown; + id?: unknown; +}; + +type HydratedCandidate = { + trace: unknown; + id?: string; + origin?: { + object_type: "project_logs"; + object_id: string; + id: string; + created?: string; + _xact_id?: string; + }; +}; + +type Stage = "inspect" | "transform"; + +type SseWriter = { + send: (event: string, payload: unknown) => void; + close: () => void; +}; + +type DeferredAttachmentReference = { + type: "braintrust_deferred_attachment"; + kind: "json"; + filename: string; + content_type: "application/json"; + path?: string; + data?: unknown; + pretty?: boolean; +}; + +type DeferredJsonAttachmentHook = ( + data: unknown, + options?: { filename?: string; pretty?: boolean }, +) => DeferredJSONAttachment; + +declare 
global { + // Used by ESM imports of hook-aware Braintrust SDKs where named exports cannot + // be monkey-patched by the runner. + var __BT_DATASET_PIPELINE_DEFER_JSON_ATTACHMENT__: + | DeferredJsonAttachmentHook + | undefined; +} + +let deferredAttachmentDir: string | null = null; + +function isObject(value: unknown): value is Record { + return typeof value === "object" && value !== null; +} + +function normalizeBraintrustModule(value: unknown): BraintrustModule { + if (isObject(value) && "default" in value && isObject(value.default)) { + return value.default as BraintrustModule; + } + if (isObject(value)) { + return value as BraintrustModule; + } + throw new Error("Unable to load braintrust module."); +} + +function setDeferredAttachmentDir(value: unknown): void { + if (value === undefined || value === null) { + deferredAttachmentDir = null; + return; + } + if (typeof value !== "string") { + throw new Error("Request field attachmentDir must be a string."); + } + deferredAttachmentDir = path.resolve(value); +} + +function deferredJsonAttachmentReference( + data: unknown, + options?: { filename?: string; pretty?: boolean }, +): DeferredAttachmentReference { + const filename = options?.filename ?? "data.json"; + const pretty = options?.pretty === true; + const reference: DeferredAttachmentReference = { + type: "braintrust_deferred_attachment", + kind: "json", + filename, + content_type: "application/json", + }; + + if (deferredAttachmentDir) { + fs.mkdirSync(deferredAttachmentDir, { recursive: true }); + const attachmentPath = path.join( + deferredAttachmentDir, + `${randomUUID()}.json`, + ); + const serialized = JSON.stringify(data, null, pretty ? 2 : undefined); + fs.writeFileSync( + attachmentPath, + serialized === undefined ? 
"null" : serialized, + "utf8", + ); + reference.path = attachmentPath; + } else { + reference.data = data; + if (pretty) { + reference.pretty = true; + } + } + + return reference; +} + +class DeferredJSONAttachment { + readonly reference: DeferredAttachmentReference; + + constructor( + data: unknown, + options?: { filename?: string; pretty?: boolean }, + ) { + this.reference = deferredJsonAttachmentReference(data, options); + } + + async upload(): Promise> { + return { upload_status: "done", deferred: true }; + } + + async data(): Promise { + const serialized = + this.reference.path !== undefined + ? fs.readFileSync(this.reference.path, "utf8") + : (JSON.stringify( + this.reference.data, + null, + this.reference.pretty === true ? 2 : undefined, + ) ?? "null"); + return new Blob([serialized], { type: this.reference.content_type }); + } + + debugInfo(): Record { + return { reference: this.reference }; + } +} + +function setModuleExport(target: unknown, name: string, value: unknown): void { + if (!isObject(target)) { + return; + } + try { + Object.defineProperty(target, name, { + value, + configurable: true, + enumerable: true, + writable: true, + }); + } catch { + try { + target[name] = value; + } catch {} + } +} + +function installDeferredAttachmentShims(braintrust: BraintrustModule): void { + globalThis.__BT_DATASET_PIPELINE_DEFER_JSON_ATTACHMENT__ = (data, options) => + new DeferredJSONAttachment(data, options); + setModuleExport(braintrust, "JSONAttachment", DeferredJSONAttachment); + setModuleExport(braintrust.default, "JSONAttachment", DeferredJSONAttachment); +} + +function normalizeDeferredAttachments(value: unknown): unknown { + if (value instanceof DeferredJSONAttachment) { + return value.reference; + } + if (Array.isArray(value)) { + return value.map((item) => normalizeDeferredAttachments(item)); + } + if (!isObject(value)) { + return value; + } + + const prototype = Object.getPrototypeOf(value); + if (prototype !== Object.prototype && prototype !== null) 
{ + return value; + } + + return Object.fromEntries( + Object.entries(value).map(([key, item]) => [ + key, + normalizeDeferredAttachments(item), + ]), + ); +} + +function resolveBraintrustPath(pipelineFile: string): string { + const file = path.resolve(process.cwd(), pipelineFile); + try { + const require = createRequire(pathToFileURL(file).href); + return require.resolve("braintrust"); + } catch {} + + try { + const require = createRequire(process.cwd() + "/"); + return require.resolve("braintrust"); + } catch { + throw new Error( + "Unable to resolve the `braintrust` package. Please install it in your project.", + ); + } +} + +async function loadBraintrust(pipelineFile: string): Promise { + const cjsPath = resolveBraintrustPath(pipelineFile); + const cjsUrl = pathToFileURL(cjsPath).href; + + try { + return normalizeBraintrustModule(await import(cjsUrl)); + } catch {} + + const esmPath = cjsPath.replace(/\.js$/, ".mjs"); + if (esmPath !== cjsPath && fs.existsSync(esmPath)) { + try { + return normalizeBraintrustModule( + await import(pathToFileURL(esmPath).href), + ); + } catch {} + } + + const require = createRequire(cjsUrl); + return normalizeBraintrustModule(require(cjsPath)); +} + +async function loadPipelineFile(file: string): Promise { + const absolute = path.resolve(process.cwd(), file); + const fileUrl = pathToFileURL(absolute).href; + try { + return await import(fileUrl); + } catch (importErr) { + try { + const require = createRequire(fileUrl); + return require(absolute); + } catch (requireErr) { + throw new Error( + `Failed to load ${file}: import failed with ${formatError(importErr)}; require failed with ${formatError(requireErr)}`, + ); + } + } +} + +function formatError(err: unknown): string { + return err instanceof Error ? 
err.message : String(err); +} + +function collectPipelines( + braintrust: BraintrustModule, + loadedModule: unknown, +): DatasetPipelineDefinition[] { + const pipelines = new Set(); + const isPipeline = (value: unknown): value is DatasetPipelineDefinition => + (braintrust.isDatasetPipelineDefinition?.(value) ?? false) || + (isObject(value) && + isObject(value.source) && + isObject(value.target) && + typeof value.transform === "function"); + + for (const pipeline of braintrust.getRegisteredDatasetPipelines?.() ?? []) { + pipelines.add(pipeline); + } + + if (isObject(loadedModule)) { + for (const value of Object.values(loadedModule)) { + if (isPipeline(value)) { + pipelines.add(value); + } + } + } + + if (isPipeline(loadedModule)) { + pipelines.add(loadedModule); + } + + return [...pipelines]; +} + +function selectPipeline( + pipelines: DatasetPipelineDefinition[], + name: string | undefined, +): DatasetPipelineDefinition { + if (name) { + const matches = pipelines.filter((pipeline) => pipeline.name === name); + if (matches.length === 0) { + throw new Error( + `No dataset pipeline named ${JSON.stringify(name)} found.`, + ); + } + if (matches.length > 1) { + throw new Error( + `Multiple dataset pipelines named ${JSON.stringify(name)} found.`, + ); + } + return matches[0]; + } + + if (pipelines.length === 0) { + throw new Error( + "No dataset pipelines found. Did you call DatasetPipeline()?", + ); + } + if (pipelines.length > 1) { + const names = pipelines + .map((pipeline) => pipeline.name ?? "") + .join(", "); + throw new Error( + `Multiple dataset pipelines found (${names}). 
Pass --name.`, + ); + } + return pipelines[0]; +} + +function parseStage(): Stage { + const value = process.env.BT_DATASET_PIPELINE_STAGE; + if (value === "inspect" || value === "transform") { + return value; + } + throw new Error("BT_DATASET_PIPELINE_STAGE must be inspect or transform."); +} + +async function readRequest(): Promise { + const chunks: Buffer[] = []; + for await (const chunk of process.stdin) { + chunks.push(Buffer.isBuffer(chunk) ? chunk : Buffer.from(String(chunk))); + } + const text = Buffer.concat(chunks).toString("utf8").trim(); + return text.length > 0 ? JSON.parse(text) : {}; +} + +function writeResponse(value: unknown, sse: SseWriter | null): void { + if (sse) { + sse.send("response", value); + sse.close(); + } else { + process.stdout.write(`${JSON.stringify(value)}\n`); + } +} + +function serializeSseEvent(event: { event?: string; data: string }): string { + return ( + Object.entries(event) + .filter(([_key, value]) => value !== undefined) + .map(([key, value]) => `${key}: ${value}`) + .join("\n") + "\n\n" + ); +} + +function createSseWriter(): SseWriter | null { + const sock = process.env.BT_DATASET_PIPELINE_SSE_SOCK; + const addr = process.env.BT_DATASET_PIPELINE_SSE_ADDR; + if (!sock && !addr) { + return null; + } + let socket: net.Socket; + if (sock) { + socket = net.createConnection({ path: sock }); + } else if (addr) { + const [host, portStr] = addr.split(":"); + const port = Number(portStr); + if (!host || !Number.isFinite(port)) { + throw new Error(`Invalid BT_DATASET_PIPELINE_SSE_ADDR: ${addr}`); + } + socket = net.createConnection({ host, port }); + socket.setNoDelay(true); + } else { + return null; + } + socket.on("error", (err) => { + console.error( + `Failed to connect to dataset pipeline SSE endpoint: ${ + err instanceof Error ? err.message : String(err) + }`, + ); + }); + return { + send: (event: string, payload: unknown) => { + if (!socket.writable) { + return; + } + const data = + typeof payload === "string" ? 
payload : JSON.stringify(payload); + socket.write(serializeSseEvent({ event, data })); + }, + close: () => { + socket.end(); + }, + }; +} + +function writeProgress(sse: SseWriter | null, rows: number): void { + if (!sse) { + return; + } + sse.send("progress", { + type: "dataset_pipeline_progress", + kind: "candidate", + rows, + }); +} + +function requireArrayField(request: unknown, field: string): unknown[] { + if (!isObject(request) || !Array.isArray(request[field])) { + throw new Error(`Request field ${field} must be an array.`); + } + return request[field] as unknown[]; +} + +function requireStringField(request: unknown, field: string): string { + if (!isObject(request) || typeof request[field] !== "string") { + throw new Error(`Request field ${field} must be a string.`); + } + return request[field] as string; +} + +function optionalPositiveIntegerField( + request: unknown, + field: string, +): number | undefined { + if (!isObject(request) || request[field] === undefined) { + return undefined; + } + const value = request[field]; + if (!Number.isInteger(value) || (value as number) <= 0) { + throw new Error(`Request field ${field} must be a positive integer.`); + } + return value as number; +} + +function setOptionalEnv(name: string, value: unknown): void { + if (typeof value === "string" && value.length > 0) { + process.env[name] = value; + } else { + delete process.env[name]; + } +} + +function requirePipelineSource( + pipeline: DatasetPipelineDefinition, + sourceOverride?: PipelineSource, +): PipelineSource { + if (!isObject(pipeline.source)) { + throw new Error("Dataset pipeline source is required."); + } + return { ...pipeline.source, ...(sourceOverride ?? 
{}) }; +} + +function requireBraintrustRuntime(braintrust: BraintrustModule) { + if ( + !braintrust.LocalTrace || + !braintrust._internalGetGlobalState || + !braintrust.loginToState + ) { + throw new Error( + "The installed braintrust package does not include dataset pipeline runtime support.", + ); + } +} + +async function stateForOrg( + braintrust: BraintrustModule, + orgName: string | undefined, +): Promise { + if (!braintrust._internalGetGlobalState || !braintrust.loginToState) { + throw new Error("The installed braintrust package cannot authenticate."); + } + const state = braintrust._internalGetGlobalState(); + if (!orgName) { + await state.login({}); + return state; + } + if (!state.loggedIn) { + await state.login({ orgName }); + return state; + } + if (state.orgName === orgName) { + return state; + } + return braintrust.loginToState({ orgName }); +} + +function refRootSpanId(ref: unknown): string { + if (!isObject(ref) || typeof ref.root_span_id !== "string") { + throw new Error("Discovery ref is missing root_span_id."); + } + return ref.root_span_id; +} + +function refSpanRowId(ref: DiscoveryRef): string | undefined { + return typeof ref.id === "string" ? 
ref.id : undefined; +} + +async function hydrateDiscoveryRefs( + braintrust: BraintrustModule, + pipeline: DatasetPipelineDefinition, + sourceOverride: PipelineSource | undefined, + sourceProjectId: string, + refs: unknown[], +): Promise { + requireBraintrustRuntime(braintrust); + const source = requirePipelineSource(pipeline, sourceOverride); + const state = await stateForOrg(braintrust, source.orgName); + const tracesByRootSpanId = new Map(); + return refs.map((ref) => { + const rootSpanId = refRootSpanId(ref); + const id = refSpanRowId(ref as DiscoveryRef); + let trace = tracesByRootSpanId.get(rootSpanId); + if (!trace) { + trace = new braintrust.LocalTrace!({ + objectType: "project_logs", + objectId: sourceProjectId, + rootSpanId, + state, + }); + tracesByRootSpanId.set(rootSpanId, trace); + } + const origin = + isObject(ref) && isObject(ref.origin) + ? (ref.origin as HydratedCandidate["origin"]) + : undefined; + return { + trace, + ...(id ? { id } : {}), + ...(origin ? { origin } : {}), + }; + }); +} + +function spanAttr(row: unknown, name: string): unknown { + return isObject(row) ? 
row[name] : undefined; +} + +async function sourceRowForCandidate( + candidate: HydratedCandidate, +): Promise { + if (!candidate.id) { + return undefined; + } + const trace = candidate.trace; + if (!isObject(trace) || typeof trace.getSpans !== "function") { + throw new Error("Hydrated trace does not support getSpans()."); + } + const spans = await trace.getSpans({ includeScorers: true }); + if (!Array.isArray(spans)) { + throw new Error("Hydrated trace getSpans() did not return an array."); + } + const row = spans.find( + (span) => + spanAttr(span, "id") === candidate.id || + spanAttr(span, "span_id") === candidate.id, + ); + if (!row) { + throw new Error( + `Source span row ${JSON.stringify(candidate.id)} was not found in hydrated trace.`, + ); + } + return row; +} + +async function transformArgsForCandidate( + candidate: HydratedCandidate, +): Promise { + const row = await sourceRowForCandidate(candidate); + return { + input: spanAttr(row, "input"), + output: spanAttr(row, "output"), + expected: spanAttr(row, "expected"), + metadata: spanAttr(row, "metadata"), + trace: candidate.trace, + }; +} + +function normalizeTransformResult(result: unknown): unknown[] { + if (result == null) { + return []; + } + return Array.isArray(result) ? result : [result]; +} + +function candidateFallbackId(candidate: HydratedCandidate): string | undefined { + if (candidate.id) { + return candidate.id; + } + const trace = candidate.trace; + if ( + isObject(trace) && + typeof trace.getConfiguration === "function" && + isObject(trace.getConfiguration()) + ) { + const config = trace.getConfiguration() as Record; + return typeof config.root_span_id === "string" + ? 
config.root_span_id + : undefined; + } + return undefined; +} + +function withPipelineDefaults( + row: unknown, + candidate: HydratedCandidate, + rowIndex: number | undefined, +): unknown { + const normalizedRow = normalizeDeferredAttachments(row); + if (!isObject(normalizedRow)) { + throw new Error("Dataset pipeline transform must return an object row."); + } + const fallbackId = candidateFallbackId(candidate); + return { + ...normalizedRow, + ...(normalizedRow.id === undefined && fallbackId + ? { + id: rowIndex === undefined ? fallbackId : `${fallbackId}:${rowIndex}`, + } + : {}), + ...(normalizedRow.origin === undefined && candidate.origin + ? { origin: candidate.origin } + : {}), + }; +} + +async function transformRefs( + braintrust: BraintrustModule, + pipeline: DatasetPipelineDefinition, + sourceOverride: PipelineSource | undefined, + sourceProjectId: string, + refs: unknown[], + maxConcurrency = 16, + sse: SseWriter | null = null, +): Promise { + if (!Number.isInteger(maxConcurrency) || maxConcurrency <= 0) { + throw new Error("maxConcurrency must be a positive integer."); + } + if (typeof pipeline.transform !== "function") { + throw new Error("Dataset pipeline transform must be a function."); + } + const candidates = await hydrateDiscoveryRefs( + braintrust, + pipeline, + sourceOverride, + sourceProjectId, + refs, + ); + const transformedRows: unknown[][] = new Array(candidates.length); + let nextIndex = 0; + + async function worker() { + while (nextIndex < candidates.length) { + const index = nextIndex++; + const candidate = candidates[index]; + const args = await transformArgsForCandidate(candidate); + const result = await pipeline.transform!(args); + const rows = normalizeTransformResult(result); + transformedRows[index] = rows.map((row, rowIndex) => + withPipelineDefaults( + row, + candidate, + rows.length > 1 ? 
rowIndex : undefined, + ), + ); + writeProgress(sse, transformedRows[index].length); + } + } + + const workerCount = Math.min(maxConcurrency, Math.max(candidates.length, 1)); + await Promise.all(Array.from({ length: workerCount }, () => worker())); + return transformedRows.flat(); +} + +async function main() { + const pipelineFile = process.argv[2]; + if (!pipelineFile) { + throw new Error("Pipeline file is required."); + } + + const stage = parseStage(); + const braintrust = await loadBraintrust(pipelineFile); + if (stage === "transform") { + installDeferredAttachmentShims(braintrust); + } + const loadedModule = await loadPipelineFile(pipelineFile); + const pipeline = selectPipeline( + collectPipelines(braintrust, loadedModule), + process.env.BT_DATASET_PIPELINE_NAME || undefined, + ); + const sse = createSseWriter(); + + if (stage === "inspect") { + writeResponse( + { + name: pipeline.name, + source: pipeline.source, + target: pipeline.target, + }, + sse, + ); + } else if (stage === "transform") { + const request = await readRequest(); + setDeferredAttachmentDir(isObject(request) ? request.attachmentDir : null); + const refs = requireArrayField(request, "refs"); + const sourceProjectId = requireStringField(request, "sourceProjectId"); + const sourceOverride = + isObject(request) && isObject(request.source) + ? (request.source as PipelineSource) + : undefined; + const sourceForEnv = sourceOverride ?? pipeline.source; + setOptionalEnv( + "BT_DATASET_PIPELINE_SOURCE_ORG_NAME", + isObject(sourceForEnv) ? sourceForEnv.orgName : undefined, + ); + const rows = await transformRefs( + braintrust, + pipeline, + sourceOverride, + sourceProjectId, + refs, + optionalPositiveIntegerField(request, "maxConcurrency"), + sse, + ); + writeResponse( + { candidates: refs.length, rowCount: rows.length, rows }, + sse, + ); + } else { + throw new Error(`Unsupported dataset pipeline stage: ${stage}`); + } +} + +main().catch((err) => { + console.error(err instanceof Error ? 
err.stack || err.message : String(err)); + process.exit(1); +}); diff --git a/src/auth.rs b/src/auth.rs index 6bddbad4..4cfe9c47 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -941,6 +941,18 @@ pub async fn resolved_auth_env(base: &BaseArgs) -> Result> Ok(envs) } +pub async fn resolved_runner_env(base: &BaseArgs) -> Result> { + let mut envs = resolved_auth_env(base).await?; + let project = base + .project + .clone() + .or_else(|| crate::config::load().ok().and_then(|c| c.project)); + if let Some(project) = project { + envs.push(("BRAINTRUST_DEFAULT_PROJECT".to_string(), project)); + } + Ok(envs) +} + fn resolve_profile_for_org<'a>(org: &str, store: &'a AuthStore) -> Option<&'a str> { if store.profiles.contains_key(org) { return Some( diff --git a/src/datasets/api.rs b/src/datasets/api.rs index f98a99ff..1364cb12 100644 --- a/src/datasets/api.rs +++ b/src/datasets/api.rs @@ -186,6 +186,16 @@ pub async fn create_dataset( project_id: &str, name: &str, description: Option<&str>, +) -> Result { + create_dataset_with_metadata(client, project_id, name, description, None).await +} + +pub async fn create_dataset_with_metadata( + client: &ApiClient, + project_id: &str, + name: &str, + description: Option<&str>, + metadata: Option<&Value>, ) -> Result { let mut body = serde_json::json!({ "name": name, @@ -195,6 +205,9 @@ pub async fn create_dataset( if let Some(description) = description.filter(|description| !description.is_empty()) { body["description"] = serde_json::Value::String(description.to_string()); } + if let Some(metadata) = metadata { + body["metadata"] = metadata.clone(); + } client.post("/v1/dataset", &body).await } diff --git a/src/datasets/mod.rs b/src/datasets/mod.rs index eb59226d..0993ed53 100644 --- a/src/datasets/mod.rs +++ b/src/datasets/mod.rs @@ -15,6 +15,7 @@ pub(crate) mod api; mod create; mod delete; mod list; +mod pipeline; mod records; mod update; mod utils; @@ -107,6 +108,8 @@ enum DatasetsCommands { View(ViewArgs), /// Delete a dataset 
Delete(DeleteArgs), + /// Run full dataset pipelines, or stage pull/transform/push + Pipeline(pipeline::PipelineArgs), } #[derive(Debug, Clone, Args)] @@ -165,7 +168,7 @@ struct ViewArgs { )] verbose: bool, - /// Fetch full row values instead of BTQL previews. + /// Load full row values instead of BTQL previews. #[arg( long, env = "BT_DATASETS_VIEW_FULL", @@ -253,10 +256,17 @@ pub(crate) async fn select_dataset_interactive( } pub async fn run(base: BaseArgs, args: DatasetsArgs) -> Result<()> { - let read_only = datasets_command_is_read_only(args.command.as_ref()); + let command = match args.command { + Some(DatasetsCommands::Pipeline(pipeline_args)) => { + return pipeline::run(base, pipeline_args).await; + } + command => command, + }; + + let read_only = datasets_command_is_read_only(command.as_ref()); let ctx = resolve_project_command_context_with_auth_mode(&base, read_only).await?; - match args.command { + match command { None | Some(DatasetsCommands::List) => list::run(&ctx, base.json).await, Some(DatasetsCommands::Create(create_args)) => { create::run( @@ -296,6 +306,7 @@ pub async fn run(base: BaseArgs, args: DatasetsArgs) -> Result<()> { Some(DatasetsCommands::Delete(delete_args)) => { delete::run(&ctx, delete_args.name(), delete_args.force).await } + Some(DatasetsCommands::Pipeline(_)) => unreachable!("pipeline handled before context"), } } diff --git a/src/datasets/pipeline.rs b/src/datasets/pipeline.rs new file mode 100644 index 00000000..d32c7914 --- /dev/null +++ b/src/datasets/pipeline.rs @@ -0,0 +1,2667 @@ +use std::collections::HashMap; +use std::fs; +use std::io::{self, BufRead, BufReader, IsTerminal, Read, Write}; +use std::path::{Path, PathBuf}; +use std::process::{Command, Stdio}; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; +use std::time::Duration; + +use anyhow::{bail, Context, Result}; +use clap::{Args, Subcommand}; +use indicatif::{ProgressBar, ProgressStyle}; +use serde::de::DeserializeOwned; +use serde::{Deserialize, 
Serialize}; +use serde_json::{json, Value}; + +use crate::args::BaseArgs; +use crate::auth::{login, resolved_runner_env, LoginContext}; +use crate::http::ApiClient; +use crate::js_runner::{build_js_runner_command, materialize_runner_script}; +use crate::projects::api::{create_project, get_project_by_name, Project}; +use crate::python_runner; +use crate::runner_sse; +use crate::source_language::{classify_runtime_extension, SourceLanguage}; +use crate::sync::discovery::{ + discover_project_log_refs, ProjectLogRefDiscoveryResult, ProjectLogRefScope, +}; +use crate::sync::{ + artifact_base_dir, artifact_spec_dir, create_jsonl_file_writer, epoch_seconds, read_json_file, + read_jsonl_values, stable_spec_hash, write_json_atomic, write_jsonl_value, SyncPushFileArgs, +}; +use crate::utils::parse_duration_to_seconds; +use tokio::sync::mpsc; + +use super::{api as datasets_api, records, utils, ResolvedContext}; + +const RUNNER_FILE: &str = "dataset-pipeline-runner.ts"; +const RUNNER_SOURCE: &str = include_str!("../../scripts/dataset-pipeline-runner.ts"); +const PY_RUNNER_FILE: &str = "dataset-pipeline-runner.py"; +const PY_RUNNER_SOURCE: &str = include_str!("../../scripts/dataset-pipeline-runner.py"); +const PIPELINE_ARTIFACT_OBJECT_TYPE: &str = "dataset_pipeline"; +const PIPELINE_ARTIFACT_SCHEMA_VERSION: u32 = 1; + +#[derive(Debug, Clone, Args)] +#[command(after_help = "\ +Use `run` to run the whole pipeline. + +For staged workflows, run `pull`, then `transform`, inspect or edit the transformed JSONL, then upload it with: + bt datasets pipeline push ./pipeline.ts + +`push` reads the pipeline target and delegates to `bt sync push`. 
+")] +pub struct PipelineArgs { + #[command(subcommand)] + command: PipelineCommands, +} + +#[derive(Debug, Clone, Subcommand)] +enum PipelineCommands { + /// Pull, transform, and insert dataset rows + Run(PipelineRunArgs), + /// Pull source trace/span refs to JSONL + Pull(PipelinePullArgs), + /// Transform candidate JSONL into proposed dataset row JSONL + Transform(PipelineTransformArgs), + /// Push transformed dataset rows to the pipeline target + Push(PipelinePushArgs), +} + +#[derive(Debug, Clone, Args)] +struct PipelineRunnerArgs { + /// Dataset pipeline file to execute + #[arg(value_name = "PIPELINE")] + pipeline: PathBuf, + + /// Pipeline name, required when the file defines multiple pipelines + #[arg(long)] + name: Option, + + /// Runner binary (e.g. tsx, vite-node, ts-node, python) + #[arg( + long, + short = 'r', + env = "BT_DATASET_PIPELINE_RUNNER", + value_name = "RUNNER" + )] + runner: Option, +} + +#[derive(Debug, Clone, Args)] +struct PipelineSourceArgs { + /// Override the source project name from the pipeline file + #[arg(long = "source-project")] + source_project: Option, + + /// Override the source project id from the pipeline file + #[arg(long = "source-project-id")] + source_project_id: Option, + + /// Override the source org name from the pipeline file + #[arg(long = "source-org")] + source_org: Option, + + /// Override the source filter from the pipeline file + #[arg(long = "source-filter")] + source_filter: Option, +} + +#[derive(Debug, Clone, Args)] +struct PipelineTargetArgs { + /// Override the target project name from the pipeline file + #[arg(long = "target-project")] + target_project: Option, + + /// Override the target project id from the pipeline file + #[arg(long = "target-project-id")] + target_project_id: Option, + + /// Override the target org name from the pipeline file + #[arg(long = "target-org")] + target_org: Option, + + /// Override the target dataset name from the pipeline file + #[arg(long = "target-dataset")] + 
target_dataset: Option, +} + +#[derive(Debug, Clone, Args)] +struct PipelinePullOptions { + /// Maximum number of source refs to discover + #[arg(long, default_value_t = 100, value_parser = parse_positive_usize)] + limit: usize, + + /// Restrict the source query to one or more root span ids + #[arg(long = "root-span-id")] + root_span_ids: Vec, + + /// Relative time window for source ref discovery when --root-span-id is not set + #[arg(long, env = "BT_DATASET_PIPELINE_WINDOW", default_value = "1d")] + window: String, + + /// Page size for discovery BTQL pagination + #[arg(long, default_value_t = 1000, value_parser = parse_positive_usize)] + page_size: usize, +} + +#[derive(Debug, Clone, Args)] +struct PipelineTransformOptions { + /// Maximum concurrent transform calls. Defaults to the logical CPU count. + #[arg(long, value_parser = parse_positive_usize)] + max_concurrency: Option, +} + +impl PipelineTransformOptions { + fn max_concurrency(&self) -> usize { + self.max_concurrency + .unwrap_or_else(default_transform_concurrency) + } +} + +#[derive(Debug, Clone, Args)] +struct PipelineArtifactArgs { + /// Root directory for pipeline artifacts. + #[arg(long, default_value = "bt-sync")] + root: PathBuf, +} + +#[derive(Debug, Clone, Args)] +struct PipelineRunArgs { + #[command(flatten)] + runner: PipelineRunnerArgs, + + #[command(flatten)] + source: PipelineSourceArgs, + + #[command(flatten)] + target: PipelineTargetArgs, + + #[command(flatten)] + pull: PipelinePullOptions, + + #[command(flatten)] + transform: PipelineTransformOptions, +} + +#[derive(Debug, Clone, Args)] +struct PipelinePullArgs { + #[command(flatten)] + runner: PipelineRunnerArgs, + + #[command(flatten)] + artifacts: PipelineArtifactArgs, + + #[command(flatten)] + source: PipelineSourceArgs, + + #[command(flatten)] + pull: PipelinePullOptions, + + /// Output JSONL file. Defaults to a managed path under --root. 
+ #[arg(long)] + out: Option, +} + +#[derive(Debug, Clone, Args)] +struct PipelineTransformArgs { + #[command(flatten)] + runner: PipelineRunnerArgs, + + #[command(flatten)] + artifacts: PipelineArtifactArgs, + + #[command(flatten)] + source: PipelineSourceArgs, + + #[command(flatten)] + transform: PipelineTransformOptions, + + /// Input candidate JSONL file. Defaults to the latest pull output under --root. + #[arg(long = "in")] + input: Option, + + /// Output proposed dataset row JSONL file. Defaults to a managed path under --root. + #[arg(long)] + out: Option, +} + +#[derive(Debug, Clone, Args)] +struct PipelinePushArgs { + #[command(flatten)] + runner: PipelineRunnerArgs, + + #[command(flatten)] + artifacts: PipelineArtifactArgs, + + #[command(flatten)] + target: PipelineTargetArgs, + + /// Input transformed dataset row JSONL file. Defaults to the latest transform output under --root. + #[arg(long = "in")] + input: Option, + + /// Ignore previous sync push state and upload from the beginning. 
+ #[arg(long)] + fresh: bool, +} + +pub async fn run(base: BaseArgs, args: PipelineArgs) -> Result<()> { + match args.command { + PipelineCommands::Run(args) => { + let inspect = inspect_with_overrides( + inspect_pipeline(&base, &args.runner).await?, + Some(&args.source), + Some(&args.target), + ); + let tempdir = + tempfile::tempdir().context("failed to create dataset pipeline temp dir")?; + let refs_path = tempdir.path().join("discovered.jsonl"); + print_pipeline_status(&base, "Pulling source refs..."); + let pull_result = discover_refs(&base, &inspect, &args.pull, &refs_path).await?; + print_pipeline_status( + &base, + format!( + "Pulled {} source ref(s) across {} page(s).", + pull_result.refs, pull_result.pages + ), + ); + + let (_, _, source_project) = + resolve_pipeline_source_context(&base, &inspect.source).await?; + let refs = read_jsonl_values(&refs_path)?; + let attachment_dir = tempdir.path().join("attachments"); + let transform_response = transform_source_refs( + &base, + &args.runner, + &source_project.id, + &inspect.source, + refs, + args.transform.max_concurrency(), + Some(&attachment_dir), + None, + ) + .await?; + let row_count = transform_response.rows.len(); + let inserted = + upload_dataset_rows(&base, &inspect.target, transform_response.rows).await?; + print_summary( + &base, + json!({ + "refs": transform_response.candidates, + "rows": row_count, + "inserted": inserted, + }), + false, + ) + } + PipelineCommands::Pull(args) => { + let inspect = inspect_with_overrides( + inspect_pipeline(&base, &args.runner).await?, + Some(&args.source), + None, + ); + pull_refs(&base, args, inspect).await + } + PipelineCommands::Transform(args) => transform_refs(&base, args).await, + PipelineCommands::Push(args) => push_rows(&base, args).await, + } +} + +#[derive(Debug, Clone, Deserialize)] +#[serde(rename_all = "camelCase")] +struct PipelineInspect { + source: PipelineSourceInspect, + target: PipelineTargetInspect, +} + +#[derive(Debug, Clone, Deserialize, 
Serialize)] +#[serde(rename_all = "camelCase")] +struct PipelineSourceInspect { + #[serde(skip_serializing_if = "Option::is_none")] + project_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + project_name: Option, + #[serde(skip_serializing_if = "Option::is_none")] + org_name: Option, + #[serde(skip_serializing_if = "Option::is_none")] + filter: Option, + #[serde(skip_serializing_if = "Option::is_none")] + scope: Option, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(rename_all = "camelCase")] +struct PipelineTargetInspect { + project_id: Option, + project_name: Option, + org_name: Option, + dataset_name: String, + description: Option, + metadata: Option, +} + +#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq)] +#[serde(rename_all = "lowercase")] +enum PipelineScope { + Span, + Trace, +} + +impl PipelineScope { + fn from_source(source: &PipelineSourceInspect) -> Self { + source.scope.unwrap_or(PipelineScope::Span) + } +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +struct PipelineTransformResponse { + candidates: usize, + row_count: usize, + rows: Vec, +} + +#[derive(Debug)] +enum PipelineRunnerEvent { + Response(Value), + Progress(PipelineProgressEvent), + Error { + message: String, + stack: Option, + status: Option, + }, + Console { + stream: String, + message: String, + }, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +struct PipelineProgressEvent { + #[serde(rename = "type")] + kind_type: String, + kind: String, + #[serde(default)] + rows: Option, +} + +#[derive(Debug, Deserialize)] +struct PipelineRunnerErrorPayload { + message: String, + #[serde(default)] + stack: Option, + #[serde(default)] + status: Option, +} + +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +enum PipelineArtifactStage { + Pull, + Transform, +} + +impl PipelineArtifactStage { + fn command(self) -> &'static str { + match self { + 
PipelineArtifactStage::Pull => "pull", + PipelineArtifactStage::Transform => "transform", + } + } + + fn output_file(self) -> &'static str { + match self { + PipelineArtifactStage::Pull => "pulled.jsonl", + PipelineArtifactStage::Transform => "transformed.jsonl", + } + } + + fn spec_file(self) -> &'static str { + match self { + PipelineArtifactStage::Pull => "pull.spec.json", + PipelineArtifactStage::Transform => "transform.spec.json", + } + } + + fn manifest_file(self) -> &'static str { + match self { + PipelineArtifactStage::Pull => "pull.manifest.json", + PipelineArtifactStage::Transform => "transform.manifest.json", + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +struct PipelinePullArtifactOptions { + limit: usize, + root_span_ids: Vec, + window: String, + page_size: usize, +} + +impl From<&PipelinePullOptions> for PipelinePullArtifactOptions { + fn from(options: &PipelinePullOptions) -> Self { + Self { + limit: options.limit, + root_span_ids: options.root_span_ids.clone(), + window: options.window.clone(), + page_size: options.page_size, + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +struct PipelineTransformArtifactOptions { + max_concurrency: usize, +} + +impl From<&PipelineTransformOptions> for PipelineTransformArtifactOptions { + fn from(options: &PipelineTransformOptions) -> Self { + Self { + max_concurrency: options.max_concurrency(), + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +struct PipelineArtifactSpec { + schema_version: u32, + kind: String, + pipeline: String, + #[serde(skip_serializing_if = "Option::is_none")] + name: Option, + #[serde(skip_serializing_if = "Option::is_none")] + cli_project: Option, + #[serde(skip_serializing_if = "Option::is_none")] + cli_org: Option, + stage: PipelineArtifactStage, + #[serde(skip_serializing_if = "Option::is_none")] + source: Option, + 
#[serde(skip_serializing_if = "Option::is_none")] + target: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pull: Option, + #[serde(skip_serializing_if = "Option::is_none")] + transform: Option, + #[serde(skip_serializing_if = "Option::is_none")] + input_path: Option, +} + +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +enum PipelineArtifactStatus { + Completed, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +struct PipelineArtifactManifest { + schema_version: u32, + spec_hash: String, + spec: PipelineArtifactSpec, + status: PipelineArtifactStatus, + stage: PipelineArtifactStage, + #[serde(skip_serializing_if = "Option::is_none")] + input_path: Option, + #[serde(skip_serializing_if = "Option::is_none")] + output_path: Option, + #[serde(skip_serializing_if = "Option::is_none")] + refs: Option, + #[serde(skip_serializing_if = "Option::is_none")] + candidates: Option, + #[serde(skip_serializing_if = "Option::is_none")] + rows: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pages: Option, + started_at: u64, + updated_at: u64, + completed_at: Option, +} + +#[derive(Debug, Clone)] +struct PipelineOutputArtifact { + spec_hash: String, + spec: PipelineArtifactSpec, + stage: PipelineArtifactStage, + spec_dir: PathBuf, + output_path: PathBuf, +} + +async fn inspect_pipeline(base: &BaseArgs, runner: &PipelineRunnerArgs) -> Result { + run_runner_json(base, "inspect", runner, None, |event| { + handle_pipeline_runner_event(None, event); + }) + .await +} + +fn inspect_with_overrides( + mut inspect: PipelineInspect, + source: Option<&PipelineSourceArgs>, + target: Option<&PipelineTargetArgs>, +) -> PipelineInspect { + if let Some(source) = source { + apply_source_overrides(&mut inspect.source, source); + } + if let Some(target) = target { + apply_target_overrides(&mut inspect.target, target); + } + inspect +} + +fn apply_source_overrides(source: &mut 
PipelineSourceInspect, args: &PipelineSourceArgs) { + if let Some(project_name) = args.source_project.as_deref() { + source.project_name = Some(project_name.to_string()); + source.project_id = None; + } + if let Some(project_id) = args.source_project_id.as_deref() { + source.project_id = Some(project_id.to_string()); + } + if let Some(org_name) = args.source_org.as_deref() { + source.org_name = Some(org_name.to_string()); + } + if let Some(filter) = args.source_filter.as_deref() { + source.filter = Some(filter.to_string()); + } +} + +fn source_with_resolved_project( + source: &PipelineSourceInspect, + project: &Project, + org_name: &str, +) -> PipelineSourceInspect { + let mut source = source.clone(); + source.project_id = Some(project.id.clone()); + source.project_name = Some(project.name.clone()); + if source.org_name.is_none() && !org_name.trim().is_empty() { + source.org_name = Some(org_name.to_string()); + } + source +} + +fn apply_target_overrides(target: &mut PipelineTargetInspect, args: &PipelineTargetArgs) { + if let Some(project_name) = args.target_project.as_deref() { + target.project_name = Some(project_name.to_string()); + target.project_id = None; + } + if let Some(project_id) = args.target_project_id.as_deref() { + target.project_id = Some(project_id.to_string()); + } + if let Some(org_name) = args.target_org.as_deref() { + target.org_name = Some(org_name.to_string()); + } + if let Some(dataset_name) = args.target_dataset.as_deref() { + target.dataset_name = dataset_name.to_string(); + } +} + +async fn build_runner_command( + base: &BaseArgs, + stage: &'static str, + runner: &PipelineRunnerArgs, + configure: F, +) -> Result +where + F: FnOnce(&mut Command, &'static str) -> Result<()>, +{ + let pipeline_file = runner.pipeline.clone(); + let files = vec![pipeline_file.clone()]; + let mut command = build_pipeline_runner_command(runner, &pipeline_file, &files)?; + + command.envs(resolved_runner_env(base).await?); + 
command.env("BT_DATASET_PIPELINE_STAGE", stage); + if let Some(name) = runner.name.as_deref() { + command.env("BT_DATASET_PIPELINE_NAME", name); + } + configure(&mut command, stage)?; + Ok(command) +} + +fn build_pipeline_runner_command( + runner: &PipelineRunnerArgs, + pipeline_file: &Path, + files: &[PathBuf], +) -> Result { + match pipeline_language(pipeline_file)? { + SourceLanguage::JsLike => { + let runner_script = materialize_dataset_pipeline_runner(RUNNER_FILE, RUNNER_SOURCE)?; + Ok(build_js_runner_command( + runner.runner.as_deref(), + &runner_script, + files, + )) + } + SourceLanguage::Python => { + let runner_script = + materialize_dataset_pipeline_runner(PY_RUNNER_FILE, PY_RUNNER_SOURCE)?; + let python = python_runner::resolve_python_interpreter_for_roots( + runner.runner.as_deref(), + &["BT_DATASET_PIPELINE_PYTHON"], + files, + ) + .context("No Python interpreter found. Install python, create a virtualenv, or pass --runner.")?; + let mut command = Command::new(python); + command.arg(runner_script).arg(pipeline_file); + Ok(command) + } + } +} + +fn materialize_dataset_pipeline_runner(file_name: &str, source: &str) -> Result { + materialize_runner_script(&dataset_pipeline_runner_cache_dir(), file_name, source) +} + +fn dataset_pipeline_runner_cache_dir() -> PathBuf { + let root = std::env::var_os("XDG_CACHE_HOME") + .map(PathBuf::from) + .or_else(|| std::env::var_os("HOME").map(|home| PathBuf::from(home).join(".cache"))) + .unwrap_or_else(std::env::temp_dir); + + root.join("bt") + .join("dataset-pipeline-runners") + .join(env!("CARGO_PKG_VERSION")) +} + +fn pipeline_language(pipeline_file: &Path) -> Result { + let extension = pipeline_file + .extension() + .and_then(|extension| extension.to_str()) + .with_context(|| { + format!( + "dataset pipeline file '{}' has no extension", + pipeline_file.display() + ) + })?; + classify_runtime_extension(extension).with_context(|| { + format!( + "unsupported dataset pipeline file extension '.{extension}'; expected 
.ts, .tsx, .js, .jsx, or .py" + ) + }) +} + +async fn pull_refs( + base: &BaseArgs, + args: PipelinePullArgs, + mut inspect: PipelineInspect, +) -> Result<()> { + let (_, source_client, source_project) = + resolve_pipeline_source_context(base, &inspect.source).await?; + inspect.source = + source_with_resolved_project(&inspect.source, &source_project, source_client.org_name()); + let spec = pipeline_pull_artifact_spec(base, &args.runner, &inspect.source, &args.pull); + let artifact = resolve_pipeline_output_artifact( + &args.artifacts.root, + &args.runner, + spec, + args.out.as_deref(), + None, + )?; + artifact.write_spec()?; + let started_at = epoch_seconds(); + let result = discover_refs(base, &inspect, &args.pull, &artifact.output_path).await?; + artifact.write_manifest(PipelineArtifactManifest { + schema_version: PIPELINE_ARTIFACT_SCHEMA_VERSION, + spec_hash: artifact.spec_hash.clone(), + spec: artifact.spec.clone(), + status: PipelineArtifactStatus::Completed, + stage: PipelineArtifactStage::Pull, + input_path: None, + output_path: Some(artifact.output_path.display().to_string()), + refs: Some(result.refs), + candidates: None, + rows: None, + pages: Some(result.pages), + started_at, + updated_at: epoch_seconds(), + completed_at: Some(epoch_seconds()), + })?; + print_summary( + base, + json!({ + "refs": result.refs, + "pages": result.pages, + "scope": match PipelineScope::from_source(&inspect.source) { PipelineScope::Trace => "trace", PipelineScope::Span => "span" }, + "source_project": source_project.name, + "source_project_id": source_project.id, + "out": artifact.output_path.display().to_string(), + }), + false, + ) +} + +fn pipeline_pull_artifact_spec( + base: &BaseArgs, + runner: &PipelineRunnerArgs, + source: &PipelineSourceInspect, + options: &PipelinePullOptions, +) -> PipelineArtifactSpec { + base_pipeline_artifact_spec(base, runner, PipelineArtifactStage::Pull) + .with_source(source.clone()) + .with_pull(options.into()) +} + +fn 
pipeline_transform_artifact_spec( + base: &BaseArgs, + runner: &PipelineRunnerArgs, + source: &PipelineSourceInspect, + options: &PipelineTransformOptions, + input_path: &Path, +) -> PipelineArtifactSpec { + base_pipeline_artifact_spec(base, runner, PipelineArtifactStage::Transform) + .with_source(source.clone()) + .with_transform(options.into()) + .with_input_path(input_path) +} + +fn base_pipeline_artifact_spec( + base: &BaseArgs, + runner: &PipelineRunnerArgs, + stage: PipelineArtifactStage, +) -> PipelineArtifactSpec { + PipelineArtifactSpec { + schema_version: PIPELINE_ARTIFACT_SCHEMA_VERSION, + kind: PIPELINE_ARTIFACT_OBJECT_TYPE.to_string(), + pipeline: runner.pipeline.display().to_string(), + name: runner.name.clone(), + cli_project: base.project.clone(), + cli_org: base.org_name.clone(), + stage, + source: None, + target: None, + pull: None, + transform: None, + input_path: None, + } +} + +impl PipelineArtifactSpec { + fn with_source(mut self, source: PipelineSourceInspect) -> Self { + self.source = Some(source); + self + } + + fn with_pull(mut self, pull: PipelinePullArtifactOptions) -> Self { + self.pull = Some(pull); + self + } + + fn with_transform(mut self, transform: PipelineTransformArtifactOptions) -> Self { + self.transform = Some(transform); + self + } + + fn with_input_path(mut self, input_path: &Path) -> Self { + self.input_path = Some(input_path.display().to_string()); + self + } +} + +fn resolve_pipeline_output_artifact( + root: &Path, + runner: &PipelineRunnerArgs, + spec: PipelineArtifactSpec, + explicit_out: Option<&Path>, + input_path: Option<&Path>, +) -> Result { + let spec_hash = stable_spec_hash(&spec)?; + let stage = spec.stage; + let hashed_spec_dir = artifact_spec_dir( + root, + PIPELINE_ARTIFACT_OBJECT_TYPE, + &pipeline_artifact_name(runner), + &spec_hash, + ); + let spec_dir = if matches!(stage, PipelineArtifactStage::Pull) { + hashed_spec_dir + } else { + input_path + .and_then(Path::parent) + .filter(|parent| 
!parent.as_os_str().is_empty()) + .map(Path::to_path_buf) + .unwrap_or(hashed_spec_dir) + }; + let output_path = explicit_out + .map(Path::to_path_buf) + .unwrap_or_else(|| spec_dir.join(stage.output_file())); + Ok(PipelineOutputArtifact { + spec_hash, + spec, + stage, + spec_dir, + output_path, + }) +} + +fn resolve_pipeline_input_path( + explicit_input: &Option, + root: &Path, + runner: &PipelineRunnerArgs, + stage: PipelineArtifactStage, +) -> Result { + if let Some(input) = explicit_input { + Ok(input.clone()) + } else { + resolve_latest_pipeline_stage_output(root, runner, stage) + } +} + +fn read_pipeline_stage_manifest_for_output( + output_path: &Path, + stage: PipelineArtifactStage, +) -> Result> { + let Some(parent) = output_path.parent() else { + return Ok(None); + }; + let manifest_path = parent.join(stage.manifest_file()); + if !manifest_path.exists() { + return Ok(None); + } + let manifest = read_json_file::(&manifest_path) + .with_context(|| format!("failed to read {}", manifest_path.display()))?; + if manifest.stage != stage || manifest.status != PipelineArtifactStatus::Completed { + return Ok(None); + } + Ok(Some(manifest)) +} + +fn base_with_pipeline_artifact_context( + base: &BaseArgs, + manifest: Option<&PipelineArtifactManifest>, +) -> BaseArgs { + let mut base = base.clone(); + if let Some(spec) = manifest.map(|manifest| &manifest.spec) { + if base.project.is_none() { + base.project = spec.cli_project.clone(); + } + if base.org_name.is_none() { + base.org_name = spec.cli_org.clone(); + } + } + base +} + +fn resolve_latest_pipeline_stage_output( + root: &Path, + runner: &PipelineRunnerArgs, + stage: PipelineArtifactStage, +) -> Result { + let base = artifact_base_dir( + root, + PIPELINE_ARTIFACT_OBJECT_TYPE, + &pipeline_artifact_name(runner), + ); + let mut best: Option<(u64, PathBuf)> = None; + if base.is_dir() { + for entry in + fs::read_dir(&base).with_context(|| format!("failed to read {}", base.display()))? 
+ { + let entry = entry?; + if !entry.file_type()?.is_dir() { + continue; + } + let manifest_path = entry.path().join(stage.manifest_file()); + if !manifest_path.exists() { + continue; + } + let manifest = read_json_file::(&manifest_path)?; + if manifest.stage != stage || manifest.status != PipelineArtifactStatus::Completed { + continue; + } + let Some(output_path) = manifest + .output_path + .as_ref() + .map(PathBuf::from) + .filter(|path| path.exists()) + else { + continue; + }; + if best + .as_ref() + .map(|(best_time, _)| manifest.updated_at > *best_time) + .unwrap_or(true) + { + best = Some((manifest.updated_at, output_path)); + } + } + } + + best.map(|(_, path)| path).ok_or_else(|| { + anyhow::anyhow!( + "no completed dataset pipeline {} output found for '{}'. run `bt datasets pipeline {} {}` first or pass --in", + stage.command(), + pipeline_artifact_name(runner), + stage.command(), + runner.pipeline.display() + ) + }) +} + +fn pipeline_artifact_name(runner: &PipelineRunnerArgs) -> String { + runner + .name + .clone() + .or_else(|| { + runner + .pipeline + .file_stem() + .and_then(|stem| stem.to_str()) + .map(ToString::to_string) + }) + .unwrap_or_else(|| "pipeline".to_string()) +} + +impl PipelineOutputArtifact { + fn write_spec(&self) -> Result<()> { + write_json_atomic(&self.spec_dir.join(self.stage.spec_file()), &self.spec) + } + + fn write_manifest(&self, manifest: PipelineArtifactManifest) -> Result<()> { + write_json_atomic(&self.spec_dir.join(self.stage.manifest_file()), &manifest) + } +} + +async fn run_runner_json( + base: &BaseArgs, + stage: &'static str, + runner: &PipelineRunnerArgs, + request: Option<&Value>, + mut on_event: F, +) -> Result +where + T: DeserializeOwned, + F: FnMut(PipelineRunnerEvent), +{ + let mut command = build_runner_command(base, stage, runner, |_, _| Ok(())).await?; + let (listener, sse_guard) = runner_sse::bind_sse_listener("bt-dataset-pipeline")?; + let (tx, rx) = mpsc::unbounded_channel::(); + let sse_connected = 
Arc::new(AtomicBool::new(false)); + + let tx_sse = tx.clone(); + let sse_connected_for_task = Arc::clone(&sse_connected); + let mut sse_task = tokio::spawn(async move { + if let Err(err) = runner_sse::accept_and_read_sse_stream( + listener, + || { + sse_connected_for_task.store(true, Ordering::Relaxed); + }, + |event, data| { + handle_pipeline_sse_event(event, data, &tx_sse); + }, + ) + .await + { + let _ = tx_sse.send(PipelineRunnerEvent::Error { + message: format!("SSE stream error: {err}"), + stack: None, + status: None, + }); + } + }); + + let (sse_env_name, sse_env_value) = sse_guard.env( + "BT_DATASET_PIPELINE_SSE_SOCK", + "BT_DATASET_PIPELINE_SSE_ADDR", + ); + command.env(sse_env_name, sse_env_value); + command.stdin(Stdio::piped()); + command.stdout(Stdio::piped()); + command.stderr(Stdio::piped()); + + let mut child = command + .spawn() + .context("failed to start dataset pipeline runner")?; + if let Some(request) = request { + let mut stdin = child + .stdin + .take() + .context("dataset pipeline runner stdin was not available")?; + serde_json::to_writer(&mut stdin, request) + .context("failed to write dataset pipeline runner request")?; + stdin + .write_all(b"\n") + .context("failed to finish dataset pipeline runner request")?; + } + + if let Some(stdout) = child.stdout.take() { + forward_blocking_stream(stdout, "stdout", tx.clone()); + } + if let Some(stderr) = child.stderr.take() { + forward_blocking_stream(stderr, "stderr", tx.clone()); + } + drop(tx); + + let wait_task = tokio::task::spawn_blocking(move || child.wait()); + let mut response: Option = None; + let mut errors = Vec::::new(); + let wait = Box::pin(async move { + wait_task + .await + .context("dataset pipeline runner wait task failed")? 
+ .context("dataset pipeline runner process failed") + }); + let status = runner_sse::drive_runner_events( + rx, + wait, + &mut sse_task, + &sse_connected, + "dataset pipeline runner exited without a status", + |event| match event { + PipelineRunnerEvent::Response(value) => { + response = Some(value); + } + PipelineRunnerEvent::Error { + message, + stack, + status: _, + } => { + errors.push(message.clone()); + if let Some(stack) = stack { + errors.push(stack); + } + on_event(PipelineRunnerEvent::Error { + message, + stack: None, + status: None, + }); + } + event => on_event(event), + }, + ) + .await?; + + let _sse_guard = sse_guard; + if !status.success() { + let detail = if errors.is_empty() { + String::new() + } else { + format!(": {}", errors.join("\n")) + }; + bail!( + "dataset pipeline runner failed with status {}{}", + status, + detail + ); + } + + let response = response.context("dataset pipeline runner did not send a response")?; + serde_json::from_value(response).context("failed to parse dataset pipeline runner response") +} + +fn handle_pipeline_sse_event( + event: Option, + data: String, + tx: &mpsc::UnboundedSender, +) { + match event.unwrap_or_default().as_str() { + "response" => { + if let Ok(value) = serde_json::from_str::(&data) { + let _ = tx.send(PipelineRunnerEvent::Response(value)); + } + } + "progress" => { + if let Ok(progress) = serde_json::from_str::(&data) { + if progress.kind_type == "dataset_pipeline_progress" { + let _ = tx.send(PipelineRunnerEvent::Progress(progress)); + } + } + } + "error" => { + if let Ok(payload) = serde_json::from_str::(&data) { + let _ = tx.send(PipelineRunnerEvent::Error { + message: payload.message, + stack: payload.stack, + status: payload.status, + }); + } else { + let _ = tx.send(PipelineRunnerEvent::Error { + message: data, + stack: None, + status: None, + }); + } + } + _ => {} + } +} + +fn forward_blocking_stream( + stream: T, + name: &'static str, + tx: mpsc::UnboundedSender, +) where + T: Read + Send + 
'static, +{ + std::thread::spawn(move || { + let lines = BufReader::new(stream).lines(); + for line in lines { + match line { + Ok(message) => { + let _ = tx.send(PipelineRunnerEvent::Console { + stream: name.to_string(), + message, + }); + } + Err(err) => { + let _ = tx.send(PipelineRunnerEvent::Error { + message: format!("failed to read dataset pipeline runner {name}: {err}"), + stack: None, + status: None, + }); + break; + } + } + } + }); +} + +async fn transform_refs(base: &BaseArgs, args: PipelineTransformArgs) -> Result<()> { + let input_path = resolve_pipeline_input_path( + &args.input, + &args.artifacts.root, + &args.runner, + PipelineArtifactStage::Pull, + )?; + let pull_manifest = + read_pipeline_stage_manifest_for_output(&input_path, PipelineArtifactStage::Pull)?; + let inspect = inspect_pipeline(base, &args.runner).await?; + let mut source = pull_manifest + .as_ref() + .and_then(|manifest| manifest.spec.source.clone()) + .unwrap_or(inspect.source); + apply_source_overrides(&mut source, &args.source); + let source_base = base_with_pipeline_artifact_context(base, pull_manifest.as_ref()); + let (_, source_client, source_project) = + resolve_pipeline_source_context(&source_base, &source).await?; + source = source_with_resolved_project(&source, &source_project, source_client.org_name()); + let refs = read_jsonl_values(&input_path)?; + let spec = pipeline_transform_artifact_spec( + &source_base, + &args.runner, + &source, + &args.transform, + &input_path, + ); + let artifact = resolve_pipeline_output_artifact( + &args.artifacts.root, + &args.runner, + spec, + args.out.as_deref(), + Some(&input_path), + )?; + artifact.write_spec()?; + let attachment_dir = artifact.spec_dir.join("attachments"); + let started_at = epoch_seconds(); + let mut writer = create_jsonl_file_writer(&artifact.output_path)?; + let response = transform_source_refs( + &source_base, + &args.runner, + &source_project.id, + &source, + refs, + args.transform.max_concurrency(), + 
Some(&attachment_dir), + Some(&mut writer as &mut dyn Write), + ) + .await?; + writer.flush().context("failed to flush transform output")?; + artifact.write_manifest(PipelineArtifactManifest { + schema_version: PIPELINE_ARTIFACT_SCHEMA_VERSION, + spec_hash: artifact.spec_hash.clone(), + spec: artifact.spec.clone(), + status: PipelineArtifactStatus::Completed, + stage: PipelineArtifactStage::Transform, + input_path: Some(input_path.display().to_string()), + output_path: Some(artifact.output_path.display().to_string()), + refs: None, + candidates: Some(response.candidates), + rows: Some(response.row_count), + pages: None, + started_at, + updated_at: epoch_seconds(), + completed_at: Some(epoch_seconds()), + })?; + let row_count = response.row_count; + print_summary( + base, + json!({ + "candidates": response.candidates, + "rows": row_count, + "out": args + .out + .as_deref() + .unwrap_or(&artifact.output_path) + .display() + .to_string(), + }), + false, + ) +} + +async fn transform_source_refs( + base: &BaseArgs, + runner: &PipelineRunnerArgs, + source_project_id: &str, + source: &PipelineSourceInspect, + refs: Vec, + max_concurrency: usize, + attachment_dir: Option<&Path>, + mut row_writer: Option<&mut dyn Write>, +) -> Result { + if let Some(attachment_dir) = attachment_dir { + fs::create_dir_all(attachment_dir) + .with_context(|| format!("failed to create {}", attachment_dir.display()))?; + } + let progress = pipeline_progress_bar(base, refs.len() as u64, "Transforming candidates"); + progress.set_message("output rows: 0"); + let mut combined = PipelineTransformResponse { + candidates: 0, + row_count: 0, + rows: Vec::new(), + }; + let batch_size = max_concurrency.max(1); + let mut completed_candidates = 0usize; + for batch in refs.chunks(batch_size) { + let request = json!({ + "sourceProjectId": source_project_id, + "source": source, + "refs": batch, + "attachmentDir": attachment_dir.map(|path| path.display().to_string()), + "maxConcurrency": max_concurrency, + }); 
+ let mut completed_in_batch = 0usize; + let mut batch_rows = 0usize; + let base_row_count = combined.row_count; + let response: PipelineTransformResponse = run_runner_json( + base, + "transform", + runner, + Some(&request), + |event| match event { + PipelineRunnerEvent::Progress(progress_event) + if progress_event.kind == "candidate" => + { + if completed_in_batch < batch.len() { + completed_in_batch += 1; + completed_candidates += 1; + batch_rows += progress_event.rows.unwrap_or(0); + progress.set_position(completed_candidates.min(refs.len()) as u64); + } + progress.set_message(format!("output rows: {}", base_row_count + batch_rows)); + } + event => handle_pipeline_runner_event(Some(&progress), event), + }, + ) + .await?; + validate_transform_response(&response)?; + if completed_in_batch < batch.len() { + completed_candidates += batch.len() - completed_in_batch; + progress.set_position(completed_candidates.min(refs.len()) as u64); + } + combined.candidates += response.candidates; + combined.row_count += response.row_count; + if let Some(writer) = row_writer.as_deref_mut() { + for row in response.rows { + write_jsonl_value(writer, &row).context("failed to write transform output row")?; + } + writer.flush().context("failed to flush transform output")?; + } else { + combined.rows.extend(response.rows); + } + progress.set_message(format!("output rows: {}", combined.row_count)); + } + progress.finish_and_clear(); + Ok(combined) +} + +fn pipeline_progress_bar(base: &BaseArgs, total: u64, label: &str) -> ProgressBar { + if base.json || base.quiet || !io::stderr().is_terminal() { + return ProgressBar::hidden(); + } + let pb = ProgressBar::new(total); + pb.set_style( + ProgressStyle::with_template( + "{spinner:.cyan} {prefix} [{bar:40.cyan/blue}] {pos}/{len} candidates ({percent:>3}%) | {msg}", + ) + .unwrap(), + ); + pb.set_prefix(label.to_string()); + pb.enable_steady_tick(Duration::from_millis(80)); + pb +} + +fn handle_pipeline_runner_event(progress: 
Option<&ProgressBar>, event: PipelineRunnerEvent) { + match event { + PipelineRunnerEvent::Console { stream, message } => { + let line = if stream == "stdout" { + format!("[pipeline stdout] {message}") + } else { + message + }; + if let Some(progress) = progress { + progress.suspend(|| eprintln!("{line}")); + } else { + eprintln!("{line}"); + } + } + PipelineRunnerEvent::Error { + message, + stack, + status, + } => { + let line = if let Some(status) = status { + format!("dataset pipeline runner error ({status}): {message}") + } else { + format!("dataset pipeline runner error: {message}") + }; + if let Some(progress) = progress { + progress.suspend(|| { + eprintln!("{line}"); + if let Some(stack) = stack { + eprintln!("{stack}"); + } + }); + } else { + eprintln!("{line}"); + if let Some(stack) = stack { + eprintln!("{stack}"); + } + } + } + PipelineRunnerEvent::Progress(_) | PipelineRunnerEvent::Response(_) => {} + } +} + +fn validate_transform_response(response: &PipelineTransformResponse) -> Result<()> { + if response.row_count != response.rows.len() { + bail!( + "dataset pipeline runner response rowCount {} did not match rows length {}", + response.row_count, + response.rows.len() + ); + } + Ok(()) +} + +async fn push_rows(base: &BaseArgs, args: PipelinePushArgs) -> Result<()> { + let inspect = inspect_with_overrides( + inspect_pipeline(base, &args.runner).await?, + None, + Some(&args.target), + ); + let input_path = resolve_pipeline_input_path( + &args.input, + &args.artifacts.root, + &args.runner, + PipelineArtifactStage::Transform, + )?; + let target_base = base_with_pipeline_target(base, &inspect.target); + let input_path = + materialize_deferred_attachments_for_push(base, &inspect.target, &input_path).await?; + + crate::sync::push_jsonl_file( + target_base, + SyncPushFileArgs { + object_ref: pipeline_target_dataset_ref(&inspect.target)?, + input: input_path, + root: args.artifacts.root, + fresh: args.fresh, + }, + ) + .await +} + +fn 
base_with_pipeline_target(base: &BaseArgs, target: &PipelineTargetInspect) -> BaseArgs { + let mut target_base = base.clone(); + if let Some(org_name) = target.org_name.as_deref() { + target_base.org_name = Some(org_name.to_string()); + } + if let Some(project_id) = target.project_id.as_deref() { + target_base.project = Some(project_id.to_string()); + } else if let Some(project_name) = target.project_name.as_deref() { + target_base.project = Some(project_name.to_string()); + } + target_base +} + +fn pipeline_target_dataset_ref(target: &PipelineTargetInspect) -> Result { + let dataset_name = target.dataset_name.trim(); + if dataset_name.is_empty() { + bail!("dataset pipeline target.datasetName cannot be empty"); + } + Ok(format!("dataset:{dataset_name}")) +} + +async fn materialize_deferred_attachments_for_push( + base: &BaseArgs, + target: &PipelineTargetInspect, + input_path: &Path, +) -> Result { + let rows = read_jsonl_values(input_path)?; + if !rows.iter().any(contains_deferred_attachment) { + return Ok(input_path.to_path_buf()); + } + + let target_ctx = resolve_target_context(base, target).await?; + let rows = + materialize_deferred_attachments(rows, &target_ctx.client, input_path.parent()).await?; + let output_path = input_path + .parent() + .filter(|parent| !parent.as_os_str().is_empty()) + .unwrap_or_else(|| Path::new(".")) + .join("materialized_for_push.jsonl"); + let mut writer = create_jsonl_file_writer(&output_path)?; + for row in rows { + write_jsonl_value(&mut writer, &row) + .with_context(|| format!("failed to write {}", output_path.display()))?; + } + writer + .flush() + .with_context(|| format!("failed to flush {}", output_path.display()))?; + Ok(output_path) +} + +async fn upload_dataset_rows( + base: &BaseArgs, + target: &PipelineTargetInspect, + rows: Vec, +) -> Result { + let target_ctx = resolve_target_context(base, target).await?; + let dataset = resolve_target_dataset(&target_ctx.client, target, &target_ctx.project).await?; + let rows = 
materialize_deferred_attachments(rows, &target_ctx.client, None).await?; + let records = prepare_pipeline_records(rows)?; + let inserted = records.len(); + + utils::submit_prepared_records( + &target_ctx, + &dataset.id, + &records, + false, + "Uploading dataset rows...", + "dataset pipeline upload failed", + ) + .await?; + + Ok(inserted) +} + +fn contains_deferred_attachment(value: &Value) -> bool { + match value { + Value::Object(object) => { + is_deferred_attachment_marker(object) + || object.values().any(contains_deferred_attachment) + } + Value::Array(items) => items.iter().any(contains_deferred_attachment), + _ => false, + } +} + +async fn materialize_deferred_attachments( + mut rows: Vec, + client: &ApiClient, + base_dir: Option<&Path>, +) -> Result> { + let mut specs = Vec::new(); + for row in &rows { + collect_deferred_attachment_specs(row, &mut specs)?; + } + if specs.is_empty() { + return Ok(rows); + } + + let mut replacements = HashMap::new(); + for spec in specs { + if replacements.contains_key(&spec.key) { + continue; + } + let reference = upload_deferred_attachment(client, &spec, base_dir) + .await + .with_context(|| format!("failed to upload deferred attachment {}", spec.filename))?; + replacements.insert(spec.key, reference); + } + + for row in &mut rows { + replace_deferred_attachment_specs(row, &replacements)?; + } + Ok(rows) +} + +#[derive(Debug, Clone)] +struct DeferredAttachmentSpec { + key: String, + filename: String, + content_type: String, + path: Option, + data: Option, + pretty: bool, +} + +fn collect_deferred_attachment_specs( + value: &Value, + specs: &mut Vec, +) -> Result<()> { + match value { + Value::Object(object) if is_deferred_attachment_marker(object) => { + specs.push(parse_deferred_attachment_spec(object)?); + } + Value::Object(object) => { + for value in object.values() { + collect_deferred_attachment_specs(value, specs)?; + } + } + Value::Array(items) => { + for value in items { + collect_deferred_attachment_specs(value, 
specs)?; + } + } + _ => {} + } + Ok(()) +} + +fn replace_deferred_attachment_specs( + value: &mut Value, + replacements: &HashMap, +) -> Result<()> { + match value { + Value::Object(object) if is_deferred_attachment_marker(object) => { + let spec = parse_deferred_attachment_spec(object)?; + let replacement = replacements + .get(&spec.key) + .with_context(|| format!("missing replacement for {}", spec.filename))?; + *value = replacement.clone(); + } + Value::Object(object) => { + for value in object.values_mut() { + replace_deferred_attachment_specs(value, replacements)?; + } + } + Value::Array(items) => { + for value in items { + replace_deferred_attachment_specs(value, replacements)?; + } + } + _ => {} + } + Ok(()) +} + +fn is_deferred_attachment_marker(object: &serde_json::Map) -> bool { + object + .get("type") + .and_then(Value::as_str) + .is_some_and(|value| value == "braintrust_deferred_attachment") +} + +fn parse_deferred_attachment_spec( + object: &serde_json::Map, +) -> Result { + let filename = object + .get("filename") + .and_then(Value::as_str) + .filter(|value| !value.trim().is_empty()) + .context("deferred attachment is missing filename")? 
+ .to_string(); + let content_type = object + .get("content_type") + .and_then(Value::as_str) + .filter(|value| !value.trim().is_empty()) + .unwrap_or("application/json") + .to_string(); + let path = object + .get("path") + .and_then(Value::as_str) + .filter(|value| !value.trim().is_empty()) + .map(PathBuf::from); + let data = object.get("data").cloned(); + if path.is_none() && data.is_none() { + bail!("deferred attachment {filename} is missing path or data"); + } + let pretty = object + .get("pretty") + .and_then(Value::as_bool) + .unwrap_or(false); + let key = serde_json::to_string(object) + .context("failed to build deferred attachment replacement key")?; + + Ok(DeferredAttachmentSpec { + key, + filename, + content_type, + path, + data, + pretty, + }) +} + +async fn upload_deferred_attachment( + client: &ApiClient, + spec: &DeferredAttachmentSpec, + base_dir: Option<&Path>, +) -> Result { + let key = uuid::Uuid::new_v4().to_string(); + let request = json!({ + "key": key, + "filename": spec.filename, + "content_type": spec.content_type, + "org_id": client.org_id(), + }); + let metadata: AttachmentUploadMetadata = client + .post("/attachment", &request) + .await + .context("failed to request signed URL from API server")?; + let data = deferred_attachment_bytes(spec, base_dir)?; + let upload_result = + crate::http::put_signed_url_with_headers(&metadata.signed_url, data, &metadata.headers) + .await; + + let status = match &upload_result { + Ok(()) => json!({ "upload_status": "done" }), + Err(err) => json!({ "upload_status": "error", "error_message": err.to_string() }), + }; + let status_request = json!({ + "key": key, + "org_id": client.org_id(), + "status": status, + }); + let _: Value = client + .post("/attachment/status", &status_request) + .await + .context("failed to log attachment status")?; + upload_result.context("failed to upload attachment to object store")?; + + Ok(json!({ + "type": "braintrust_attachment", + "filename": spec.filename, + "content_type": 
spec.content_type, + "key": key, + })) +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +struct AttachmentUploadMetadata { + signed_url: String, + #[serde(default)] + headers: HashMap, +} + +fn deferred_attachment_bytes( + spec: &DeferredAttachmentSpec, + base_dir: Option<&Path>, +) -> Result> { + if let Some(path) = spec.path.as_ref() { + let path = if path.is_absolute() { + path.clone() + } else { + base_dir.unwrap_or_else(|| Path::new(".")).join(path) + }; + return fs::read(&path).with_context(|| format!("failed to read {}", path.display())); + } + + let data = spec + .data + .as_ref() + .context("deferred attachment is missing data")?; + let text = if spec.pretty { + serde_json::to_string_pretty(data) + } else { + serde_json::to_string(data) + } + .context("failed to serialize deferred attachment data")?; + Ok(text.into_bytes()) +} + +fn prepare_pipeline_records(rows: Vec) -> Result> { + let mut objects = Vec::with_capacity(rows.len()); + for (index, row) in rows.into_iter().enumerate() { + match row { + Value::Object(row) => objects.push(row), + _ => bail!("dataset pipeline row {} must be a JSON object", index + 1), + } + } + + records::prepare_records(objects, "id", false) + .context("dataset pipeline transform produced invalid dataset rows") +} + +async fn resolve_target_context( + base: &BaseArgs, + target: &PipelineTargetInspect, +) -> Result { + let mut target_base = base.clone(); + if let Some(org_name) = target.org_name.as_deref() { + target_base.org_name = Some(org_name.to_string()); + } + let ctx = login(&target_base).await?; + let client = ApiClient::new(&ctx)?; + let project = resolve_target_project(&client, target).await?; + Ok(ResolvedContext { + client, + app_url: ctx.app_url, + project, + }) +} + +async fn resolve_target_project( + client: &ApiClient, + target: &PipelineTargetInspect, +) -> Result { + if let Some(project_id) = target.project_id.as_deref() { + return Ok(Project { + id: project_id.to_string(), + name: target 
+ .project_name + .clone() + .unwrap_or_else(|| project_id.to_string()), + org_id: String::new(), + description: None, + }); + } + let project_name = target + .project_name + .as_deref() + .context("dataset pipeline target requires projectName or projectId")?; + if let Some(project) = get_project_by_name(client, project_name).await? { + Ok(project) + } else { + create_project(client, project_name) + .await + .with_context(|| format!("project '{project_name}' not found, and creating it failed")) + } +} + +async fn resolve_target_dataset( + client: &ApiClient, + target: &PipelineTargetInspect, + project: &Project, +) -> Result { + let dataset_name = target.dataset_name.trim(); + if dataset_name.is_empty() { + bail!("dataset pipeline target.datasetName cannot be empty"); + } + + let datasets = datasets_api::list_datasets(client, &project.id).await?; + if let Some(dataset) = datasets + .iter() + .find(|dataset| dataset.id == dataset_name || dataset.name == dataset_name) + { + return Ok(dataset.clone()); + } + + if is_uuid_like(dataset_name) { + bail!( + "dataset id '{}' not found in project '{}'", + dataset_name, + project.name + ); + } + + datasets_api::create_dataset_with_metadata( + client, + &project.id, + dataset_name, + target.description.as_deref(), + target.metadata.as_ref(), + ) + .await + .with_context(|| format!("dataset '{dataset_name}' not found, and creating it failed")) +} + +async fn discover_refs( + base: &BaseArgs, + inspect: &PipelineInspect, + options: &PipelinePullOptions, + out: &Path, +) -> Result { + let (ctx, client, project) = resolve_pipeline_source_context(base, &inspect.source).await?; + let scope = PipelineScope::from_source(&inspect.source); + let limit = options.limit; + let filter = discovery_filter(&inspect.source, options)?; + + let mut writer = create_jsonl_file_writer(out)?; + + let result = discover_project_log_refs( + &client, + &ctx, + &project.id, + filter.as_ref(), + project_log_ref_scope(scope), + limit, + options.page_size, + 
|reference| { + write_jsonl_value(&mut writer, &reference.to_value())?; + writer.flush().context("failed to flush discovery output")?; + Ok(()) + }, + ) + .await?; + writer.flush().context("failed to flush discovery output")?; + + Ok(result) +} + +fn project_log_ref_scope(scope: PipelineScope) -> ProjectLogRefScope { + match scope { + PipelineScope::Trace => ProjectLogRefScope::Trace, + PipelineScope::Span => ProjectLogRefScope::Span, + } +} + +async fn resolve_pipeline_source_context( + base: &BaseArgs, + source: &PipelineSourceInspect, +) -> Result<(LoginContext, ApiClient, Project)> { + let mut source_base = base.clone(); + if let Some(org_name) = source.org_name.as_deref() { + source_base.org_name = Some(org_name.to_string()); + } + let ctx = login(&source_base).await?; + let client = ApiClient::new(&ctx)?; + let project = resolve_source_project(base, &client, source).await?; + Ok((ctx, client, project)) +} + +fn discovery_filter( + source: &PipelineSourceInspect, + options: &PipelinePullOptions, +) -> Result> { + let mut filters = Vec::new(); + if options.root_span_ids.is_empty() { + filters.push(discovery_window_filter(&options.window)?); + } + if let Some(filter) = source + .filter + .as_deref() + .map(str::trim) + .filter(|s| !s.is_empty()) + { + filters.push(json!({ "btql": filter })); + } + if !options.root_span_ids.is_empty() { + filters.push(root_span_id_filter(&options.root_span_ids)); + } + Ok(match filters.len() { + 0 => None, + 1 => filters.into_iter().next(), + _ => Some(json!({ "op": "and", "children": filters })), + }) +} + +fn discovery_window_filter(window: &str) -> Result { + let seconds = parse_duration_to_seconds(window) + .with_context(|| format!("invalid dataset pipeline pull window '{window}'"))?; + Ok(json!({ + "btql": format!("created >= NOW() - INTERVAL {seconds} SECOND") + })) +} + +fn root_span_id_filter(root_span_ids: &[String]) -> Value { + json!({ + "op": "in", + "left": { "op": "ident", "name": ["root_span_id"] }, + "right": { 
"op": "literal", "value": root_span_ids } + }) +} + +fn default_transform_concurrency() -> usize { + std::thread::available_parallelism() + .map(|parallelism| parallelism.get()) + .unwrap_or(16) +} + +async fn resolve_source_project( + base: &BaseArgs, + client: &ApiClient, + source: &PipelineSourceInspect, +) -> Result { + if let Some(project_id) = source.project_id.as_deref() { + return Ok(Project { + id: project_id.to_string(), + name: source + .project_name + .clone() + .unwrap_or_else(|| project_id.to_string()), + org_id: String::new(), + description: None, + }); + } + let configured_project = + crate::config::configured_project_for_context(base, Some(client.org_name())); + let project_name = source + .project_name + .as_deref() + .or(base.project.as_deref()) + .or(configured_project.as_deref()) + .context( + "dataset pipeline source requires projectName or projectId; pass --source-project or set an active project", + )?; + get_project_by_name(client, project_name) + .await? + .with_context(|| format!("project '{project_name}' not found")) +} + +fn print_summary(base: &BaseArgs, summary: Value, force_stderr: bool) -> Result<()> { + let object = summary + .as_object() + .context("dataset pipeline summary must be an object")?; + if base.json && !force_stderr { + println!("{}", serde_json::to_string(&summary)?); + return Ok(()); + } + let parts = object + .iter() + .map(|(key, value)| format!("{key}: {}", summary_value(value))) + .collect::>(); + eprintln!("{}", parts.join(", ")); + Ok(()) +} + +fn print_pipeline_status(base: &BaseArgs, message: impl AsRef) { + if !base.json && !base.quiet { + eprintln!("{}", message.as_ref()); + } +} + +fn summary_value(value: &Value) -> String { + match value { + Value::String(value) => value.clone(), + Value::Number(value) => value.to_string(), + Value::Bool(value) => value.to_string(), + Value::Null => "null".to_string(), + Value::Array(_) | Value::Object(_) => value.to_string(), + } +} + +fn is_uuid_like(value: &str) -> bool 
{ + let bytes = value.as_bytes(); + if bytes.len() != 36 { + return false; + } + for (index, byte) in bytes.iter().enumerate() { + match index { + 8 | 13 | 18 | 23 => { + if *byte != b'-' { + return false; + } + } + _ if !byte.is_ascii_hexdigit() => return false, + _ => {} + } + } + true +} + +fn parse_positive_usize(value: &str) -> std::result::Result { + let parsed = value + .parse::() + .map_err(|_| format!("invalid positive integer '{value}'"))?; + if parsed == 0 { + return Err("value must be greater than 0".to_string()); + } + Ok(parsed) +} + +#[cfg(test)] +mod tests { + use super::*; + + use std::path::PathBuf; + + use clap::Parser; + + #[derive(Debug, Parser)] + struct PipelineHarness { + #[command(subcommand)] + command: PipelineCommands, + } + + fn parse_pipeline(args: &[&str]) -> anyhow::Result { + let mut argv = vec!["bt"]; + argv.extend_from_slice(args); + Ok(PipelineHarness::try_parse_from(argv)?.command) + } + + fn test_base_args() -> BaseArgs { + BaseArgs { + json: false, + verbose: false, + verbose_source: None, + quiet: false, + quiet_source: None, + no_color: false, + no_input: false, + profile: None, + profile_explicit: false, + org_name: None, + project: None, + api_key: None, + api_key_source: None, + prefer_profile: false, + api_url: None, + app_url: None, + ca_cert: None, + env_file: None, + } + } + + fn test_source() -> PipelineSourceInspect { + PipelineSourceInspect { + project_id: None, + project_name: Some("Loop".to_string()), + org_name: None, + filter: None, + scope: Some(PipelineScope::Span), + } + } + + fn test_pull_options() -> PipelinePullOptions { + PipelinePullOptions { + limit: 100, + root_span_ids: Vec::new(), + window: "1d".to_string(), + page_size: 1000, + } + } + + #[test] + fn pipeline_pull_command_parses() { + let command = + parse_pipeline(&["pull", "pipeline.ts", "--limit", "7"]).expect("parse pull command"); + let PipelineCommands::Pull(pull) = command else { + panic!("expected pull command"); + }; + 
assert_eq!(pull.runner.pipeline, PathBuf::from("pipeline.ts")); + assert_eq!(pull.pull.limit, 7); + assert_eq!(pull.pull.window, "1d"); + } + + #[test] + fn pipeline_pull_window_parses() { + let command = + parse_pipeline(&["pull", "pipeline.ts", "--window", "2h"]).expect("parse pull window"); + let PipelineCommands::Pull(pull) = command else { + panic!("expected pull command"); + }; + assert_eq!(pull.pull.window, "2h"); + } + + #[test] + fn pipeline_pull_rejects_target_alias_for_limit() { + let err = parse_pipeline(&["pull", "pipeline.ts", "--target", "7"]) + .expect_err("target alias should not parse"); + + assert!(err.to_string().contains("unexpected argument '--target'")); + } + + #[test] + fn discovery_filter_adds_default_window() { + let filter = + discovery_filter(&test_source(), &test_pull_options()).expect("discovery filter"); + + assert_eq!( + filter, + Some(json!({ + "btql": "created >= NOW() - INTERVAL 86400 SECOND" + })) + ); + } + + #[test] + fn discovery_filter_combines_window_and_source_filter() { + let source = PipelineSourceInspect { + filter: Some("metadata.topic = 'test'".to_string()), + ..test_source() + }; + let filter = discovery_filter(&source, &test_pull_options()).expect("discovery filter"); + + assert_eq!( + filter, + Some(json!({ + "op": "and", + "children": [ + { "btql": "created >= NOW() - INTERVAL 86400 SECOND" }, + { "btql": "metadata.topic = 'test'" } + ] + })) + ); + } + + #[test] + fn discovery_filter_uses_root_span_filter_without_window() { + let options = PipelinePullOptions { + root_span_ids: vec!["root-1".to_string()], + ..test_pull_options() + }; + let filter = discovery_filter(&test_source(), &options).expect("discovery filter"); + + assert_eq!(filter, Some(root_span_id_filter(&["root-1".to_string()]))); + } + + #[test] + fn discovery_filter_rejects_invalid_window() { + let options = PipelinePullOptions { + window: "bad".to_string(), + ..test_pull_options() + }; + let err = discovery_filter(&test_source(), 
&options).expect_err("invalid window"); + + assert!(err + .to_string() + .contains("invalid dataset pipeline pull window 'bad'")); + } + + #[test] + fn prepare_pipeline_records_reuses_dataset_record_validation() { + let err = prepare_pipeline_records(vec![json!({ + "input": "hello", + "span_attributes": { "type": "llm" }, + })]) + .expect_err("unexpected dataset row fields should be rejected"); + + assert!(err + .to_string() + .contains("dataset pipeline transform produced invalid dataset rows")); + } + + #[test] + fn prepare_pipeline_records_uses_dataset_record_schema() { + let records = prepare_pipeline_records(vec![json!({ + "id": "row-1", + "input": { "question": "hello" }, + "expected": "world", + "tags": ["smoke"], + "metadata": { "source": "test" }, + "origin": { + "object_type": "project_logs", + "object_id": "source-project", + "id": "source-span" + } + })]) + .expect("valid dataset pipeline row should deserialize"); + + assert_eq!(records.len(), 1); + assert_eq!(records[0].id, "row-1"); + let upload = records[0].to_upload_row("target-dataset", false); + assert_eq!(upload.get("id"), Some(&json!("row-1"))); + assert_eq!(upload.get("dataset_id"), Some(&json!("target-dataset"))); + assert_eq!(upload.get("expected"), Some(&json!("world"))); + assert!(!upload.contains_key("span_id")); + assert!(!upload.contains_key("root_span_id")); + assert!(!upload.contains_key("project_id")); + } + + #[test] + fn transform_response_validation_rejects_row_count_mismatch() { + let response = PipelineTransformResponse { + candidates: 1, + row_count: 2, + rows: vec![json!({ "input": "one" })], + }; + + let err = + validate_transform_response(&response).expect_err("rowCount should match rows length"); + assert!(err.to_string().contains("rowCount 2")); + } + + #[test] + fn pipeline_target_dataset_ref_validates_dataset_name() { + let target = PipelineTargetInspect { + project_id: None, + project_name: Some("Target Project".to_string()), + org_name: None, + dataset_name: " Ground 
Truth ".to_string(), + description: None, + metadata: None, + }; + assert_eq!( + pipeline_target_dataset_ref(&target).expect("dataset ref"), + "dataset:Ground Truth" + ); + + let err = pipeline_target_dataset_ref(&PipelineTargetInspect { + dataset_name: " ".to_string(), + ..target + }) + .expect_err("empty dataset names should fail"); + assert!(err.to_string().contains("target.datasetName")); + } + + #[test] + fn pipeline_push_base_uses_target_org_and_project() { + let base = test_base_args(); + let target = PipelineTargetInspect { + project_id: Some("project-id".to_string()), + project_name: Some("Project Name".to_string()), + org_name: Some("target-org".to_string()), + dataset_name: "Dataset".to_string(), + description: None, + metadata: None, + }; + + let target_base = base_with_pipeline_target(&base, &target); + assert_eq!(target_base.org_name.as_deref(), Some("target-org")); + assert_eq!(target_base.project.as_deref(), Some("project-id")); + } + + #[test] + fn deferred_attachment_detection_finds_nested_marker() { + let row = json!({ + "id": "row-1", + "input": { + "full_trace": { + "type": "braintrust_deferred_attachment", + "kind": "json", + "filename": "trace.json", + "content_type": "application/json", + "path": "attachments/trace.json" + } + } + }); + + assert!(contains_deferred_attachment(&row)); + + let mut specs = Vec::new(); + collect_deferred_attachment_specs(&row, &mut specs).expect("collect specs"); + assert_eq!(specs.len(), 1); + assert_eq!(specs[0].filename, "trace.json"); + assert_eq!( + specs[0].path.as_deref(), + Some(Path::new("attachments/trace.json")) + ); + } + + #[test] + fn deferred_attachment_replacement_rewrites_nested_marker() { + let mut row = json!({ + "id": "row-1", + "input": { + "full_trace": { + "type": "braintrust_deferred_attachment", + "kind": "json", + "filename": "trace.json", + "content_type": "application/json", + "data": { "ok": true } + } + } + }); + let mut specs = Vec::new(); + collect_deferred_attachment_specs(&row, 
&mut specs).expect("collect specs"); + let replacement = json!({ + "type": "braintrust_attachment", + "filename": "trace.json", + "content_type": "application/json", + "key": "uploaded-key" + }); + let replacements = HashMap::from([(specs[0].key.clone(), replacement.clone())]); + + replace_deferred_attachment_specs(&mut row, &replacements).expect("replace specs"); + + assert_eq!(row["input"]["full_trace"], replacement); + } + + #[test] + fn typescript_runner_defers_json_attachments_during_transform() { + let Ok(strip_check) = Command::new("node") + .arg("--experimental-strip-types") + .arg("--eval") + .arg("") + .output() + else { + return; + }; + if !strip_check.status.success() { + return; + } + + let root = tempfile::tempdir().expect("tempdir"); + let node_modules = root.path().join("node_modules").join("braintrust"); + fs::create_dir_all(&node_modules).expect("create fake braintrust package"); + fs::write( + node_modules.join("package.json"), + r#"{"name":"braintrust","type":"module","exports":{".":{"import":"./index.mjs","require":"./index.cjs"}}}"#, + ) + .expect("write fake package.json"); + fs::write( + node_modules.join("index.cjs"), + r#" +const pipelines = []; + +class OriginalJSONAttachment { + constructor() { + throw new Error("original JSONAttachment should be shimmed"); + } +} + +module.exports = { + DatasetPipeline(definition) { + pipelines.push(definition); + return definition; + }, + getRegisteredDatasetPipelines() { + return pipelines; + }, + isDatasetPipelineDefinition(value) { + return !!value && typeof value.transform === "function"; + }, + LocalTrace: class { + constructor(options) { + this.options = options; + } + getConfiguration() { + return { root_span_id: this.options.rootSpanId }; + } + }, + _internalGetGlobalState() { + return { + loggedIn: true, + orgName: "source-org", + login: async function () { + return this; + }, + }; + }, + loginToState: async function ({ orgName }) { + return { + loggedIn: true, + orgName, + login: async 
function () { + return this; + }, + }; + }, + JSONAttachment: OriginalJSONAttachment, +}; +"#, + ) + .expect("write fake braintrust cjs module"); + fs::write( + node_modules.join("index.mjs"), + r#" +export function DatasetPipeline(definition) { + return definition; +} + +export class JSONAttachment { + constructor(data, options) { + const hook = globalThis.__BT_DATASET_PIPELINE_DEFER_JSON_ATTACHMENT__; + if (hook) { + return hook(data, options); + } + throw new Error("dataset pipeline deferred JSON hook was not installed"); + } +} +"#, + ) + .expect("write fake braintrust esm module"); + + let runner_path = root.path().join("dataset-pipeline-runner.ts"); + fs::write(&runner_path, RUNNER_SOURCE).expect("write runner source"); + let pipeline_path = root.path().join("pipeline.ts"); + fs::write( + &pipeline_path, + r#" +import { DatasetPipeline, JSONAttachment } from "braintrust"; + +export default DatasetPipeline({ + name: "ts-json-attachment-smoke", + source: { projectName: "source-project" }, + target: { projectName: "target-project", datasetName: "traces" }, + transform: () => ({ + input: { + full_trace: new JSONAttachment( + { ok: true }, + { filename: "trace.json", pretty: true }, + ), + }, + }), +}); +"#, + ) + .expect("write pipeline"); + + let attachment_dir = root.path().join("attachments"); + let request = json!({ + "refs": [{ + "root_span_id": "root-span", + "origin": { + "object_type": "project_logs", + "object_id": "source-project-id", + "id": "source-row", + "created": "2026-05-19T00:00:00Z", + "_xact_id": "123" + } + }], + "sourceProjectId": "source-project-id", + "attachmentDir": attachment_dir, + }); + let mut child = Command::new("node") + .arg("--experimental-strip-types") + .arg(&runner_path) + .arg(&pipeline_path) + .current_dir(root.path()) + .env("BT_DATASET_PIPELINE_STAGE", "transform") + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn() + .expect("spawn node runner"); + child + .stdin + .as_mut() + 
.expect("runner stdin") + .write_all(request.to_string().as_bytes()) + .expect("write runner request"); + let output = child.wait_with_output().expect("runner output"); + assert!( + output.status.success(), + "runner failed\nstdout:\n{}\nstderr:\n{}", + String::from_utf8_lossy(&output.stdout), + String::from_utf8_lossy(&output.stderr) + ); + + let response: Value = serde_json::from_slice(&output.stdout).expect("runner JSON response"); + assert_eq!(response["rowCount"], json!(1)); + let marker = &response["rows"][0]["input"]["full_trace"]; + assert_eq!(marker["type"], "braintrust_deferred_attachment"); + assert_eq!(marker["kind"], "json"); + assert_eq!(marker["filename"], "trace.json"); + assert_eq!(marker["content_type"], "application/json"); + assert!(marker.get("data").is_none()); + assert_eq!( + response["rows"][0]["origin"], + json!({ + "object_type": "project_logs", + "object_id": "source-project-id", + "id": "source-row", + "created": "2026-05-19T00:00:00Z", + "_xact_id": "123" + }) + ); + + let sidecar_path = marker["path"].as_str().expect("sidecar path"); + let sidecar = fs::read_to_string(sidecar_path).expect("read sidecar JSON"); + assert_eq!( + serde_json::from_str::(&sidecar).expect("parse sidecar"), + json!({ "ok": true }) + ); + assert!(sidecar.contains("\n \"ok\": true\n")); + } + + #[test] + fn pipeline_source_artifact_records_resolved_project() { + let source = PipelineSourceInspect { + project_id: None, + project_name: None, + org_name: None, + filter: Some("span_attributes.type = 'llm'".to_string()), + scope: Some(PipelineScope::Span), + }; + let project = Project { + id: "project-id".to_string(), + name: "Loop".to_string(), + org_id: "org-id".to_string(), + description: None, + }; + + let resolved = source_with_resolved_project(&source, &project, "braintrustdata.com"); + + assert_eq!(resolved.project_id.as_deref(), Some("project-id")); + assert_eq!(resolved.project_name.as_deref(), Some("Loop")); + assert_eq!(resolved.org_name.as_deref(), 
Some("braintrustdata.com")); + assert_eq!(resolved.filter, source.filter); + assert_eq!(resolved.scope, source.scope); + } + + #[test] + fn pipeline_transform_base_inherits_pull_artifact_context() { + let base = test_base_args(); + let manifest = PipelineArtifactManifest { + schema_version: PIPELINE_ARTIFACT_SCHEMA_VERSION, + spec_hash: "hash".to_string(), + spec: PipelineArtifactSpec { + schema_version: PIPELINE_ARTIFACT_SCHEMA_VERSION, + kind: PIPELINE_ARTIFACT_OBJECT_TYPE.to_string(), + pipeline: "facet_pipeline.py".to_string(), + name: None, + cli_project: Some("Loop".to_string()), + cli_org: Some("braintrustdata.com".to_string()), + stage: PipelineArtifactStage::Pull, + source: None, + target: None, + pull: None, + transform: None, + input_path: None, + }, + status: PipelineArtifactStatus::Completed, + stage: PipelineArtifactStage::Pull, + input_path: None, + output_path: None, + refs: Some(1), + candidates: None, + rows: None, + pages: Some(1), + started_at: 1, + updated_at: 2, + completed_at: Some(2), + }; + + let inherited = base_with_pipeline_artifact_context(&base, Some(&manifest)); + + assert_eq!(inherited.project.as_deref(), Some("Loop")); + assert_eq!(inherited.org_name.as_deref(), Some("braintrustdata.com")); + } + + #[test] + fn pipeline_artifacts_default_to_sync_root_shape() { + let root = tempfile::tempdir().expect("tempdir"); + let runner = PipelineRunnerArgs { + pipeline: PathBuf::from("facet_pipeline.py"), + name: None, + runner: None, + }; + let spec = + base_pipeline_artifact_spec(&test_base_args(), &runner, PipelineArtifactStage::Pull); + + let artifact = resolve_pipeline_output_artifact(root.path(), &runner, spec, None, None) + .expect("artifact path"); + + assert!(artifact + .output_path + .starts_with(root.path().join("dataset_pipeline_facet_pipeline"))); + assert_eq!(artifact.output_path.file_name().unwrap(), "pulled.jsonl"); + } + + #[test] + fn pipeline_input_defaults_to_latest_completed_stage_output() { + let root = 
tempfile::tempdir().expect("tempdir"); + let runner = PipelineRunnerArgs { + pipeline: PathBuf::from("facet_pipeline.py"), + name: None, + runner: None, + }; + let spec = + base_pipeline_artifact_spec(&test_base_args(), &runner, PipelineArtifactStage::Pull); + let artifact = resolve_pipeline_output_artifact(root.path(), &runner, spec, None, None) + .expect("artifact path"); + + artifact.write_spec().expect("write spec"); + crate::sync::write_jsonl_values( + Some(&artifact.output_path), + &[json!({ "root_span_id": "root-1" })], + ) + .expect("write output"); + artifact + .write_manifest(PipelineArtifactManifest { + schema_version: PIPELINE_ARTIFACT_SCHEMA_VERSION, + spec_hash: artifact.spec_hash.clone(), + spec: artifact.spec.clone(), + status: PipelineArtifactStatus::Completed, + stage: PipelineArtifactStage::Pull, + input_path: None, + output_path: Some(artifact.output_path.display().to_string()), + refs: Some(1), + candidates: None, + rows: None, + pages: Some(1), + started_at: 1, + updated_at: 2, + completed_at: Some(2), + }) + .expect("write manifest"); + + let resolved = + resolve_pipeline_input_path(&None, root.path(), &runner, PipelineArtifactStage::Pull) + .expect("default input"); + + assert_eq!(resolved, artifact.output_path); + } + + #[test] + fn pipeline_transform_output_defaults_to_pull_artifact_dir() { + let root = tempfile::tempdir().expect("tempdir"); + let runner = PipelineRunnerArgs { + pipeline: PathBuf::from("facet_pipeline.py"), + name: None, + runner: None, + }; + let source = PipelineSourceInspect { + project_id: None, + project_name: Some("Loop".to_string()), + org_name: None, + filter: None, + scope: Some(PipelineScope::Span), + }; + let pull_spec = + base_pipeline_artifact_spec(&test_base_args(), &runner, PipelineArtifactStage::Pull); + let pull_artifact = + resolve_pipeline_output_artifact(root.path(), &runner, pull_spec, None, None) + .expect("pull artifact"); + let transform_spec = pipeline_transform_artifact_spec( + &test_base_args(), 
+ &runner, + &source, + &PipelineTransformOptions { + max_concurrency: Some(16), + }, + &pull_artifact.output_path, + ); + + let transform_artifact = resolve_pipeline_output_artifact( + root.path(), + &runner, + transform_spec, + None, + Some(&pull_artifact.output_path), + ) + .expect("transform artifact"); + + assert_eq!(transform_artifact.spec_dir, pull_artifact.spec_dir); + assert_eq!( + transform_artifact.output_path, + pull_artifact.spec_dir.join("transformed.jsonl") + ); + } +} diff --git a/src/datasets/records.rs b/src/datasets/records.rs index de84999a..092afe3a 100644 --- a/src/datasets/records.rs +++ b/src/datasets/records.rs @@ -206,7 +206,7 @@ fn expect_record_object(value: Value, record_number: Option) -> Result>, id_field: &str, require_ids: bool, diff --git a/src/eval.rs b/src/eval.rs index 29b01120..6402e57f 100644 --- a/src/eval.rs +++ b/src/eval.rs @@ -3,9 +3,9 @@ use std::ffi::{OsStr, OsString}; use std::io::IsTerminal; use std::path::{Path, PathBuf}; use std::process::{ExitStatus, Stdio}; -use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; +use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; -use std::time::{Duration, SystemTime, UNIX_EPOCH}; +use std::time::{Duration, SystemTime}; use actix_web::dev::Service; use actix_web::http::header::{ @@ -27,8 +27,6 @@ use reqwest::Client; use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; use strip_ansi_escapes::strip; -use tokio::io::{AsyncBufReadExt, BufReader}; -use tokio::net::UnixListener; use tokio::process::Command; use tokio::sync::mpsc; use unicode_width::UnicodeWidthStr; @@ -41,8 +39,10 @@ use ratatui::widgets::{Cell, Row, Table}; use ratatui::Terminal; use crate::args::BaseArgs; -use crate::auth::resolved_auth_env; +use crate::auth::resolved_runner_env; +use crate::js_runner; use crate::python_runner; +use crate::runner_sse; use crate::ui::{animations_enabled, is_quiet}; const MAX_NAME_LENGTH: usize = 40; @@ -57,7 +57,6 @@ const HEADER_BT_AUTH_TOKEN: &str 
= "x-bt-auth-token"; const HEADER_BT_ORG_NAME: &str = "x-bt-org-name"; const HEADER_CORS_REQ_PRIVATE_NETWORK: &str = "access-control-request-private-network"; const HEADER_CORS_ALLOW_PRIVATE_NETWORK: &str = "access-control-allow-private-network"; -const SSE_SOCKET_BIND_MAX_ATTEMPTS: u8 = 16; const EVAL_NODE_MAX_OLD_SPACE_SIZE_MB: usize = 8192; const MAX_DEFERRED_EVAL_ERRORS: usize = 8; const DEFAULT_EVAL_SAMPLE_SEED: u64 = 0; @@ -82,8 +81,6 @@ fn parse_positive_usize(value: &str) -> std::result::Result { } Ok(parsed) } -static SSE_SOCKET_COUNTER: AtomicU64 = AtomicU64::new(0); - struct EvalRunOutput { status: ExitStatus, dependencies: Vec, @@ -94,7 +91,7 @@ struct EvalRunnerProcess { rx: mpsc::UnboundedReceiver, sse_task: tokio::task::JoinHandle<()>, sse_connected: Arc, - _socket_cleanup_guard: SocketCleanupGuard, + _sse_guard: runner_sse::SseListenerGuard, } struct EvalProcessOutput { @@ -219,22 +216,6 @@ const JS_RUNNER_SOURCE: &str = include_str!("../scripts/eval-runner.ts"); const PY_RUNNER_SOURCE: &str = include_str!("../scripts/eval-runner.py"); const PYTHON_INTERPRETER_ENV_OVERRIDES: &[&str] = &["BT_EVAL_PYTHON_RUNNER", "BT_EVAL_PYTHON"]; -struct SocketCleanupGuard { - path: PathBuf, -} - -impl SocketCleanupGuard { - fn new(path: PathBuf) -> Self { - Self { path } - } -} - -impl Drop for SocketCleanupGuard { - fn drop(&mut self) { - let _ = std::fs::remove_file(&self.path); - } -} - #[derive(Debug, Copy, Clone, Eq, PartialEq, ValueEnum)] pub enum EvalLanguage { #[value(alias = "js")] @@ -771,32 +752,30 @@ async fn spawn_eval_runner( let (js_runner, py_runner) = prepare_eval_runners()?; let force_esm = matches!(js_mode, JsMode::ForceEsm); - let (listener, socket_path, socket_cleanup_guard) = bind_sse_listener()?; + let (listener, sse_guard) = runner_sse::bind_sse_listener("bt-eval")?; let (tx, rx) = mpsc::unbounded_channel(); let sse_connected = Arc::new(AtomicBool::new(false)); let tx_sse = tx.clone(); let sse_connected_for_task = Arc::clone(&sse_connected); 
let sse_task = tokio::spawn(async move { - match listener.accept().await { - Ok((stream, _)) => { + if let Err(err) = runner_sse::accept_and_read_sse_stream( + listener, + || { sse_connected_for_task.store(true, Ordering::Relaxed); - if let Err(err) = read_sse_stream(stream, tx_sse.clone()).await { - let _ = tx_sse.send(EvalEvent::Error { - message: format!("SSE stream error: {err}"), - stack: None, - status: None, - }); - } - } - Err(err) => { - let _ = tx_sse.send(EvalEvent::Error { - message: format!("Failed to accept SSE connection: {err}"), - stack: None, - status: None, - }); - } - }; + }, + |event, data| { + handle_sse_event(event, data, &tx_sse); + }, + ) + .await + { + let _ = tx_sse.send(EvalEvent::Error { + message: format!("SSE stream error: {err}"), + stack: None, + status: None, + }); + } }); let (mut cmd, runner_kind) = match language { @@ -820,7 +799,7 @@ async fn spawn_eval_runner( set_node_heap_size_env(&mut cmd); } - cmd.envs(build_env(base).await?); + cmd.envs(resolved_runner_env(base).await?); for (key, value) in extra_env { cmd.env(key, value); } @@ -895,10 +874,8 @@ async fn spawn_eval_runner( serde_json::to_string(&payload).context("failed to serialize matrix axes")?; cmd.env("BT_EVAL_MATRIX_JSON", serialized); } - cmd.env( - "BT_EVAL_SSE_SOCK", - socket_path.to_string_lossy().to_string(), - ); + let (sse_env_name, sse_env_value) = sse_guard.env("BT_EVAL_SSE_SOCK", "BT_EVAL_SSE_ADDR"); + cmd.env(sse_env_name, sse_env_value); cmd.stdout(Stdio::piped()); cmd.stderr(Stdio::piped()); @@ -910,7 +887,14 @@ async fn spawn_eval_runner( if let Some(stdout) = stdout { let tx_stdout = tx.clone(); tokio::spawn(async move { - if let Err(err) = forward_stream(stdout, "stdout", tx_stdout).await { + if let Err(err) = runner_sse::forward_stream(stdout, "stdout", |stream, message| { + let _ = tx_stdout.send(EvalEvent::Console { + stream: stream.to_string(), + message, + }); + }) + .await + { eprintln!("Failed to read eval stdout: {err}"); } }); @@ -919,7 
+903,14 @@ async fn spawn_eval_runner( if let Some(stderr) = stderr { let tx_stderr = tx.clone(); tokio::spawn(async move { - if let Err(err) = forward_stream(stderr, "stderr", tx_stderr).await { + if let Err(err) = runner_sse::forward_stream(stderr, "stderr", |stream, message| { + let _ = tx_stderr.send(EvalEvent::Console { + stream: stream.to_string(), + message, + }); + }) + .await + { eprintln!("Failed to read eval stderr: {err}"); } }); @@ -933,77 +924,71 @@ async fn spawn_eval_runner( rx, sse_task, sse_connected, - _socket_cleanup_guard: socket_cleanup_guard, + _sse_guard: sse_guard, }, runner_kind, }) } async fn drive_eval_runner( - mut process: EvalRunnerProcess, + process: EvalRunnerProcess, console_policy: ConsolePolicy, mut on_event: F, ) -> Result where F: FnMut(EvalEvent), { - let mut status = None; + let EvalRunnerProcess { + mut child, + rx, + mut sse_task, + sse_connected, + _sse_guard, + } = process; let mut dependency_files: Vec = Vec::new(); let mut error_messages: Vec = Vec::new(); let mut stderr_lines: Vec = Vec::new(); - - loop { - tokio::select! 
{ - event = process.rx.recv() => { - match event { - Some(EvalEvent::Dependencies { files }) => { - dependency_files.extend(files.clone()); - on_event(EvalEvent::Dependencies { files }); - } - Some(EvalEvent::Error { message, stack, status }) => { - error_messages.push(message.clone()); - if let Some(stack) = stack.as_ref() { - error_messages.push(stack.clone()); - } - on_event(EvalEvent::Error { message, stack, status }); - } - Some(EvalEvent::Console { stream, message }) => { - if stream == "stderr" && matches!(console_policy, ConsolePolicy::BufferStderr) - { - stderr_lines.push(message); - } else { - on_event(EvalEvent::Console { stream, message }); - } - } - Some(event) => on_event(event), - None => { - if status.is_none() { - status = Some(process.child.wait().await.context("eval runner process failed")?); - if !process.sse_connected.load(Ordering::Relaxed) { - process.sse_task.abort(); - } - } - break; - } + let wait = Box::pin(async { child.wait().await.context("eval runner process failed") }); + let status = runner_sse::drive_runner_events( + rx, + wait, + &mut sse_task, + &sse_connected, + "eval runner process exited without a status", + |event| match event { + EvalEvent::Dependencies { files } => { + dependency_files.extend(files.clone()); + on_event(EvalEvent::Dependencies { files }); + } + EvalEvent::Error { + message, + stack, + status, + } => { + error_messages.push(message.clone()); + if let Some(stack) = stack.as_ref() { + error_messages.push(stack.clone()); } + on_event(EvalEvent::Error { + message, + stack, + status, + }); } - exit_status = process.child.wait(), if status.is_none() => { - status = Some(exit_status.context("eval runner process failed")?); - if !process.sse_connected.load(Ordering::Relaxed) { - process.sse_task.abort(); + EvalEvent::Console { stream, message } => { + if stream == "stderr" && matches!(console_policy, ConsolePolicy::BufferStderr) { + stderr_lines.push(message); + } else { + on_event(EvalEvent::Console { stream, 
message }); } } - } - - if status.is_some() && process.rx.is_closed() { - break; - } - } - - let _ = process.sse_task.await; + event => on_event(event), + }, + ) + .await?; Ok(EvalProcessOutput { - status: status.context("eval runner process exited without a status")?, + status, dependency_files, error_messages, stderr_lines, @@ -2133,21 +2118,6 @@ fn format_watch_paths(paths: &[PathBuf]) -> String { } } -async fn build_env(base: &BaseArgs) -> Result> { - let mut envs = resolved_auth_env(base).await?; - let resolved_org = envs - .iter() - .find_map(|(key, value)| (key == "BRAINTRUST_ORG_NAME").then_some(value.as_str())); - let project = base - .project - .clone() - .or_else(|| crate::config::configured_project_for_context(base, resolved_org)); - if let Some(project) = &project { - envs.push(("BRAINTRUST_DEFAULT_PROJECT".to_string(), project.clone())); - } - Ok(envs) -} - fn detect_eval_language( files: &[String], language_override: Option, @@ -2348,7 +2318,7 @@ fn build_js_plan( ) -> Result { if let Some(explicit) = runner_override { let resolved_runner = resolve_js_runner_command(explicit, files); - if is_deno_runner(explicit) || is_deno_runner_path(resolved_runner.as_ref()) { + if is_deno_runner(explicit) || js_runner::is_deno_runner_path(resolved_runner.as_ref()) { let runner_script = prepare_js_runner_in_cwd()?; return Ok(JsRunnerPlan { cmd: build_deno_js_command(resolved_runner.as_os_str(), &runner_script, files), @@ -2363,7 +2333,7 @@ fn build_js_plan( } if let Some(auto_runner) = find_js_runner_binary(files) { - if is_deno_runner_path(&auto_runner) { + if js_runner::is_deno_runner_path(&auto_runner) { let runner_script = prepare_js_runner_in_cwd()?; return Ok(JsRunnerPlan { cmd: build_deno_js_command(auto_runner.as_os_str(), &runner_script, files), @@ -2387,7 +2357,7 @@ fn build_js_plan( fn build_vite_node_fallback_command(runner: &Path, files: &[String]) -> Result { if let Some(path) = find_node_module_bin_for_files("vite-node", files) - .or_else(|| 
find_binary_in_path(&["vite-node"])) + .or_else(|| js_runner::find_binary_in_path(&["vite-node"])) { let mut command = Command::new(path); command.arg(runner).args(files); @@ -2414,15 +2384,7 @@ fn build_deno_js_command( } fn deno_js_command_args(runner: &Path, files: &[String]) -> Vec { - let mut args = vec![ - OsString::from("run"), - OsString::from("-A"), - OsString::from("--node-modules-dir=auto"), - OsString::from("--unstable-detect-cjs"), - runner.as_os_str().to_os_string(), - ]; - args.extend(files.iter().map(OsString::from)); - args + js_runner::deno_runner_args(runner, &files_as_paths(files)) } fn build_python_command( @@ -2466,104 +2428,38 @@ fn python_runner_search_roots(files: &[String]) -> Vec { } fn find_js_runner_binary(files: &[String]) -> Option { - // Prefer local project bins first, then PATH. `tsx` remains the preferred - // default, with other common TS runners as fallback. - const RUNNER_CANDIDATES: &[&str] = &["tsx", "vite-node", "ts-node", "ts-node-esm", "deno"]; - - for candidate in RUNNER_CANDIDATES { - if let Some(path) = find_node_module_bin_for_files(candidate, files) { - return Some(path); - } - } - - find_binary_in_path(RUNNER_CANDIDATES) + js_runner::find_js_runner_binary(&files_as_paths(files)) } fn resolve_js_runner_command(runner: &str, files: &[String]) -> PathBuf { - if is_path_like_runner(runner) { - return PathBuf::from(runner); - } - - find_node_module_bin_for_files(runner, files) - .or_else(|| find_binary_in_path(&[runner])) - .unwrap_or_else(|| PathBuf::from(runner)) -} - -fn is_path_like_runner(runner: &str) -> bool { - let path = Path::new(runner); - path.is_absolute() || runner.contains('/') || runner.contains('\\') || runner.starts_with('.') + js_runner::resolve_js_runner_command(runner, &files_as_paths(files)) } fn find_node_module_bin_for_files(binary: &str, files: &[String]) -> Option { - let search_roots = js_runner_search_roots(files); - for root in &search_roots { - if let Some(path) = find_node_module_bin(binary, 
root) { - return Some(path); - } - } - None + js_runner::find_node_module_bin_for_files(binary, &files_as_paths(files)) } -fn js_runner_search_roots(files: &[String]) -> Vec { - let mut search_roots = Vec::new(); - if let Ok(cwd) = std::env::current_dir() { - search_roots.push(cwd.clone()); - for file in files { - let path = PathBuf::from(file); - let absolute = if path.is_absolute() { - path - } else { - cwd.join(path) - }; - if let Some(parent) = absolute.parent() { - search_roots.push(parent.to_path_buf()); - } - } - } - search_roots +fn files_as_paths(files: &[String]) -> Vec { + files.iter().map(PathBuf::from).collect() } fn is_deno_runner(runner: &str) -> bool { - let file_name = Path::new(runner) - .file_name() - .and_then(|value| value.to_str()) - .unwrap_or(runner); - file_name.eq_ignore_ascii_case("deno") || file_name.eq_ignore_ascii_case("deno.exe") -} - -fn is_deno_runner_path(runner: &Path) -> bool { - runner - .file_name() - .and_then(|value| value.to_str()) - .map(|name| name.eq_ignore_ascii_case("deno") || name.eq_ignore_ascii_case("deno.exe")) - .unwrap_or(false) + js_runner::is_deno_runner_path(Path::new(runner)) } fn select_js_runner_entrypoint(default_runner: &Path, runner_command: &Path) -> Result { - if is_ts_node_runner(runner_command) { + if js_runner::is_ts_node_runner_path(runner_command) { return prepare_js_runner_in_cwd(); } Ok(default_runner.to_path_buf()) } fn prepare_js_runner_in_cwd() -> Result { - let cwd = std::env::current_dir().context("failed to resolve current working directory")?; - let cache_dir = cwd - .join(".bt") - .join("eval-runners") - .join(env!("CARGO_PKG_VERSION")); - std::fs::create_dir_all(&cache_dir).with_context(|| { - format!( - "failed to create eval runner cache dir {}", - cache_dir.display() - ) - })?; - materialize_runner_script(&cache_dir, JS_RUNNER_FILE, JS_RUNNER_SOURCE) + js_runner::materialize_runner_script_in_cwd("eval-runners", JS_RUNNER_FILE, JS_RUNNER_SOURCE) } fn runner_bin_name(runner_command: 
&Path) -> Option { - let name = runner_command.file_name()?.to_str()?.to_ascii_lowercase(); - Some(name.strip_suffix(".cmd").unwrap_or(&name).to_string()) + js_runner::runner_bin_name(runner_command) } fn runner_kind_for_bin(runner_command: &Path) -> RunnerKind { @@ -2595,90 +2491,6 @@ fn set_node_heap_size_env(command: &mut Command) { command.env("NODE_OPTIONS", merged); } -fn is_ts_node_runner(runner_command: &Path) -> bool { - runner_bin_name(runner_command).is_some_and(|n| n == "ts-node" || n == "ts-node-esm") -} - -fn find_node_module_bin(binary: &str, start: &Path) -> Option { - let mut current = Some(start); - while let Some(dir) = current { - let base = dir.join("node_modules").join(".bin").join(binary); - if base.is_file() { - return Some(base); - } - if cfg!(windows) { - let cmd = base.with_extension("cmd"); - if cmd.is_file() { - return Some(cmd); - } - } - current = dir.parent(); - } - None -} - -fn find_binary_in_path(candidates: &[&str]) -> Option { - let paths = std::env::var_os("PATH")?; - for dir in std::env::split_paths(&paths) { - for candidate in candidates { - let path = dir.join(candidate); - if path.is_file() { - return Some(path); - } - if cfg!(windows) { - let cmd = path.with_extension("cmd"); - if cmd.is_file() { - return Some(cmd); - } - } - } - } - None -} - -fn build_sse_socket_path() -> Result { - let pid = std::process::id(); - let serial = SSE_SOCKET_COUNTER.fetch_add(1, Ordering::Relaxed); - let now = SystemTime::now() - .duration_since(UNIX_EPOCH) - .context("failed to read system time")? 
- .as_nanos(); - Ok(std::env::temp_dir().join(format!("bt-eval-{pid}-{now}-{serial}.sock"))) -} - -fn bind_sse_listener() -> Result<(UnixListener, PathBuf, SocketCleanupGuard)> { - let mut last_bind_err: Option = None; - for _ in 0..SSE_SOCKET_BIND_MAX_ATTEMPTS { - let socket_path = build_sse_socket_path()?; - let socket_cleanup_guard = SocketCleanupGuard::new(socket_path.clone()); - let _ = std::fs::remove_file(&socket_path); - match UnixListener::bind(&socket_path) { - Ok(listener) => return Ok((listener, socket_path, socket_cleanup_guard)), - Err(err) - if matches!( - err.kind(), - std::io::ErrorKind::AlreadyExists | std::io::ErrorKind::AddrInUse - ) => - { - last_bind_err = Some(err); - continue; - } - Err(err) => { - return Err(err).context("failed to bind SSE unix socket"); - } - } - } - let err = last_bind_err.unwrap_or_else(|| { - std::io::Error::new( - std::io::ErrorKind::AddrInUse, - "failed to allocate a unique SSE socket path", - ) - }); - Err(err).context(format!( - "failed to bind SSE unix socket after {SSE_SOCKET_BIND_MAX_ATTEMPTS} attempts" - )) -} - fn eval_runner_cache_dir() -> PathBuf { let root = std::env::var_os("XDG_CACHE_HOME") .map(PathBuf::from) @@ -2708,13 +2520,7 @@ fn prepare_eval_runners_in_dir(cache_dir: &Path) -> Result<(PathBuf, PathBuf)> { } fn materialize_runner_script(cache_dir: &Path, file_name: &str, source: &str) -> Result { - let path = cache_dir.join(file_name); - let current = std::fs::read_to_string(&path).ok(); - if current.as_deref() != Some(source) { - std::fs::write(&path, source) - .with_context(|| format!("failed to write eval runner script {}", path.display()))?; - } - Ok(path) + js_runner::materialize_runner_script(cache_dir, file_name, source) } #[derive(Debug)] @@ -2849,57 +2655,6 @@ struct SseDependenciesEventData { files: Vec, } -async fn forward_stream( - stream: T, - name: &'static str, - tx: mpsc::UnboundedSender, -) -> Result<()> -where - T: tokio::io::AsyncRead + Unpin, -{ - let mut lines = 
BufReader::new(stream).lines(); - while let Some(line) = lines.next_line().await? { - let _ = tx.send(EvalEvent::Console { - stream: name.to_string(), - message: line, - }); - } - Ok(()) -} - -async fn read_sse_stream(stream: T, tx: mpsc::UnboundedSender) -> Result<()> -where - T: tokio::io::AsyncRead + Unpin, -{ - let mut lines = BufReader::new(stream).lines(); - let mut event: Option = None; - let mut data_lines: Vec = Vec::new(); - - while let Some(line) = lines.next_line().await? { - if line.is_empty() { - if event.is_some() || !data_lines.is_empty() { - let data = data_lines.join("\n"); - handle_sse_event(event.take(), data, &tx); - data_lines.clear(); - } - continue; - } - - if let Some(value) = line.strip_prefix("event:") { - event = Some(value.trim().to_string()); - } else if let Some(value) = line.strip_prefix("data:") { - data_lines.push(value.trim_start().to_string()); - } - } - - if event.is_some() || !data_lines.is_empty() { - let data = data_lines.join("\n"); - handle_sse_event(event.take(), data, &tx); - } - - Ok(()) -} - fn handle_sse_event(event: Option, data: String, tx: &mpsc::UnboundedSender) { let event_name = event.unwrap_or_default(); match event_name.as_str() { @@ -4446,10 +4201,11 @@ mod tests { assert!(message.contains("pnpm add -D vite-node")); } + #[cfg(unix)] #[test] fn build_sse_socket_path_is_unique_for_consecutive_calls() { - let first = build_sse_socket_path().expect("first socket path"); - let second = build_sse_socket_path().expect("second socket path"); + let first = runner_sse::build_sse_socket_path("bt-eval").expect("first socket path"); + let second = runner_sse::build_sse_socket_path("bt-eval").expect("second socket path"); assert_ne!(first, second); } diff --git a/src/http.rs b/src/http.rs index 4616ec8a..09bdceb8 100644 --- a/src/http.rs +++ b/src/http.rs @@ -1,3 +1,5 @@ +use std::collections::HashMap; + use anyhow::{Context, Result}; use reqwest::header::HeaderValue; use reqwest::{Client, ClientBuilder}; @@ -273,15 +275,31 
@@ pub async fn put_signed_url( url: &str, body: Vec, content_encoding: Option<&str>, +) -> Result<()> { + let mut headers = HashMap::new(); + if let Some(encoding) = content_encoding { + headers.insert("Content-Encoding".to_string(), encoding.to_string()); + } + put_signed_url_with_headers(url, body, &headers).await +} + +pub async fn put_signed_url_with_headers( + url: &str, + body: Vec, + headers: &HashMap, ) -> Result<()> { let client = build_http_client(UPLOAD_HTTP_TIMEOUT).context("failed to build signed-url HTTP client")?; let mut request = client.put(url).body(body); - if let Some(encoding) = content_encoding { - request = request.header("Content-Encoding", encoding); + let mut has_azure_blob_type = false; + for (key, value) in headers { + if key.eq_ignore_ascii_case("x-ms-blob-type") { + has_azure_blob_type = true; + } + request = request.header(key.as_str(), value.as_str()); } - if url.contains(".blob.core.windows.net") { + if url.contains(".blob.core.windows.net") && !has_azure_blob_type { request = request.header("x-ms-blob-type", HeaderValue::from_static("BlockBlob")); } diff --git a/src/js_runner.rs b/src/js_runner.rs index cd4c94dc..31bba482 100644 --- a/src/js_runner.rs +++ b/src/js_runner.rs @@ -1,4 +1,4 @@ -use std::ffi::OsStr; +use std::ffi::{OsStr, OsString}; use std::path::{Path, PathBuf}; use std::process::Command; @@ -105,16 +105,20 @@ pub fn resolve_js_runner_command(runner: &str, files: &[PathBuf]) -> PathBuf { fn build_deno_command(deno_runner: &OsStr, runner_script: &Path, files: &[PathBuf]) -> Command { let mut command = Command::new(deno_runner); + command.args(deno_runner_args(runner_script, files)); command - .arg("run") - .arg("-A") - .arg("--node-modules-dir=auto") - .arg("--unstable-detect-cjs") - .arg(runner_script); - for file in files { - command.arg(file); - } - command +} + +pub fn deno_runner_args(runner_script: &Path, files: &[PathBuf]) -> Vec { + let mut args = vec![ + OsString::from("run"), + OsString::from("-A"), + 
OsString::from("--node-modules-dir=auto"), + OsString::from("--unstable-detect-cjs"), + runner_script.as_os_str().to_os_string(), + ]; + args.extend(files.iter().map(|file| file.as_os_str().to_os_string())); + args } fn is_path_like_runner(runner: &str) -> bool { @@ -122,7 +126,7 @@ fn is_path_like_runner(runner: &str) -> bool { path.is_absolute() || runner.contains('/') || runner.contains('\\') || runner.starts_with('.') } -fn is_deno_runner_path(runner: &Path) -> bool { +pub fn is_deno_runner_path(runner: &Path) -> bool { runner .file_name() .and_then(|value| value.to_str()) @@ -130,7 +134,7 @@ fn is_deno_runner_path(runner: &Path) -> bool { .unwrap_or(false) } -fn find_node_module_bin_for_files(binary: &str, files: &[PathBuf]) -> Option { +pub fn find_node_module_bin_for_files(binary: &str, files: &[PathBuf]) -> Option { for root in js_runner_search_roots(files) { if let Some(path) = find_node_module_bin(binary, &root) { return Some(path); @@ -176,7 +180,7 @@ fn find_node_module_bin(binary: &str, start: &Path) -> Option { None } -fn find_binary_in_path(candidates: &[&str]) -> Option { +pub fn find_binary_in_path(candidates: &[&str]) -> Option { let paths = std::env::var_os("PATH")?; for dir in std::env::split_paths(&paths) { for candidate in candidates { @@ -196,6 +200,15 @@ fn find_binary_in_path(candidates: &[&str]) -> Option { None } +pub fn runner_bin_name(runner_command: &Path) -> Option { + let name = runner_command.file_name()?.to_str()?.to_ascii_lowercase(); + Some(name.strip_suffix(".cmd").unwrap_or(&name).to_string()) +} + +pub fn is_ts_node_runner_path(runner_command: &Path) -> bool { + runner_bin_name(runner_command).is_some_and(|name| name == "ts-node" || name == "ts-node-esm") +} + #[cfg(windows)] fn with_windows_extensions(path: &Path) -> [PathBuf; 2] { [path.with_extension("exe"), path.with_extension("cmd")] diff --git a/src/main.rs b/src/main.rs index 0a34fb4d..8de7269f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -19,6 +19,7 @@ mod 
project_context; mod projects; mod prompts; mod python_runner; +mod runner_sse; mod scorers; mod self_update; mod setup; @@ -71,6 +72,7 @@ Projects & resources experiments Manage experiments Data & evaluation + datasets Manage datasets eval Run eval files sql Run SQL queries against Braintrust sync Synchronize project logs between Braintrust and local NDJSON files diff --git a/src/runner_sse.rs b/src/runner_sse.rs new file mode 100644 index 00000000..b0b2b593 --- /dev/null +++ b/src/runner_sse.rs @@ -0,0 +1,293 @@ +use std::pin::Pin; +use std::process::ExitStatus; +use std::sync::atomic::{AtomicBool, Ordering}; + +use anyhow::{Context, Result}; +use std::future::Future; +use tokio::io::{AsyncBufReadExt, BufReader}; +use tokio::sync::mpsc; + +#[cfg(unix)] +use std::path::PathBuf; +#[cfg(unix)] +use std::sync::atomic::AtomicU64; +#[cfg(unix)] +use std::time::{SystemTime, UNIX_EPOCH}; +#[cfg(not(unix))] +use tokio::net::TcpListener; +#[cfg(unix)] +use tokio::net::UnixListener; + +#[cfg(unix)] +const SOCKET_BIND_MAX_ATTEMPTS: u8 = 16; +#[cfg(unix)] +static SOCKET_COUNTER: AtomicU64 = AtomicU64::new(0); + +pub(crate) enum SseListener { + #[cfg(unix)] + Unix(UnixListener), + #[cfg(not(unix))] + Tcp(TcpListener), +} + +pub(crate) struct SseListenerGuard { + endpoint: SseEndpoint, + #[cfg(unix)] + _socket_cleanup_guard: SocketCleanupGuard, +} + +enum SseEndpoint { + #[cfg(unix)] + Unix(PathBuf), + #[cfg(not(unix))] + Tcp(std::net::SocketAddr), +} + +#[cfg(unix)] +struct SocketCleanupGuard { + path: PathBuf, +} + +#[cfg(unix)] +impl SocketCleanupGuard { + fn new(path: PathBuf) -> Self { + Self { path } + } +} + +#[cfg(unix)] +impl Drop for SocketCleanupGuard { + fn drop(&mut self) { + let _ = std::fs::remove_file(&self.path); + } +} + +impl SseListenerGuard { + pub(crate) fn env<'a>(&self, socket_var: &'a str, addr_var: &'a str) -> (&'a str, String) { + #[cfg(unix)] + let _ = addr_var; + #[cfg(not(unix))] + let _ = socket_var; + match &self.endpoint { + #[cfg(unix)] + 
SseEndpoint::Unix(path) => (socket_var, path.to_string_lossy().to_string()), + #[cfg(not(unix))] + SseEndpoint::Tcp(addr) => (addr_var, addr.to_string()), + } + } +} + +pub(crate) fn bind_sse_listener(prefix: &str) -> Result<(SseListener, SseListenerGuard)> { + #[cfg(unix)] + { + bind_unix_sse_listener(prefix) + } + + #[cfg(not(unix))] + { + let _ = prefix; + bind_tcp_sse_listener() + } +} + +#[cfg(unix)] +fn bind_unix_sse_listener(prefix: &str) -> Result<(SseListener, SseListenerGuard)> { + let mut last_bind_err: Option = None; + for _ in 0..SOCKET_BIND_MAX_ATTEMPTS { + let socket_path = build_sse_socket_path(prefix)?; + let socket_cleanup_guard = SocketCleanupGuard::new(socket_path.clone()); + let _ = std::fs::remove_file(&socket_path); + match UnixListener::bind(&socket_path) { + Ok(listener) => { + return Ok(( + SseListener::Unix(listener), + SseListenerGuard { + endpoint: SseEndpoint::Unix(socket_path), + _socket_cleanup_guard: socket_cleanup_guard, + }, + )) + } + Err(err) + if matches!( + err.kind(), + std::io::ErrorKind::AlreadyExists | std::io::ErrorKind::AddrInUse + ) => + { + last_bind_err = Some(err); + continue; + } + Err(err) => { + return Err(err).context("failed to bind SSE unix socket"); + } + } + } + let err = last_bind_err.unwrap_or_else(|| { + std::io::Error::new( + std::io::ErrorKind::AddrInUse, + "failed to allocate a unique SSE socket path", + ) + }); + Err(err).context(format!( + "failed to bind SSE unix socket after {SOCKET_BIND_MAX_ATTEMPTS} attempts" + )) +} + +#[cfg(not(unix))] +fn bind_tcp_sse_listener() -> Result<(SseListener, SseListenerGuard)> { + let std_listener = + std::net::TcpListener::bind(("127.0.0.1", 0)).context("failed to bind SSE TCP listener")?; + std_listener + .set_nonblocking(true) + .context("failed to configure SSE TCP listener")?; + let addr = std_listener + .local_addr() + .context("failed to read SSE TCP listener address")?; + let listener = + TcpListener::from_std(std_listener).context("failed to create SSE TCP 
listener")?; + Ok(( + SseListener::Tcp(listener), + SseListenerGuard { + endpoint: SseEndpoint::Tcp(addr), + }, + )) +} + +#[cfg(unix)] +pub(crate) fn build_sse_socket_path(prefix: &str) -> Result { + let pid = std::process::id(); + let serial = SOCKET_COUNTER.fetch_add(1, Ordering::Relaxed); + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .context("failed to read system time")? + .as_nanos(); + Ok(std::env::temp_dir().join(format!("{prefix}-{pid}-{now}-{serial}.sock"))) +} + +pub(crate) async fn accept_and_read_sse_stream( + listener: SseListener, + on_connected: C, + on_event: F, +) -> Result<()> +where + C: FnOnce(), + F: FnMut(Option, String), +{ + match listener { + #[cfg(unix)] + SseListener::Unix(listener) => { + let (stream, _) = listener + .accept() + .await + .context("failed to accept SSE unix socket connection")?; + on_connected(); + read_sse_stream(stream, on_event).await + } + #[cfg(not(unix))] + SseListener::Tcp(listener) => { + let (stream, _) = listener + .accept() + .await + .context("failed to accept SSE TCP connection")?; + on_connected(); + read_sse_stream(stream, on_event).await + } + } +} + +pub(crate) async fn forward_stream( + stream: T, + name: &'static str, + mut on_line: F, +) -> Result<()> +where + T: tokio::io::AsyncRead + Unpin, + F: FnMut(&'static str, String), +{ + let mut lines = BufReader::new(stream).lines(); + while let Some(line) = lines.next_line().await? { + on_line(name, line); + } + Ok(()) +} + +pub(crate) async fn read_sse_stream(stream: T, mut on_event: F) -> Result<()> +where + T: tokio::io::AsyncRead + Unpin, + F: FnMut(Option, String), +{ + let mut lines = BufReader::new(stream).lines(); + let mut event: Option = None; + let mut data_lines: Vec = Vec::new(); + + while let Some(line) = lines.next_line().await? 
{ + if line.is_empty() { + if event.is_some() || !data_lines.is_empty() { + let data = data_lines.join("\n"); + on_event(event.take(), data); + data_lines.clear(); + } + continue; + } + + if let Some(value) = line.strip_prefix("event:") { + event = Some(value.trim().to_string()); + } else if let Some(value) = line.strip_prefix("data:") { + data_lines.push(value.trim_start().to_string()); + } + } + + if event.is_some() || !data_lines.is_empty() { + let data = data_lines.join("\n"); + on_event(event.take(), data); + } + + Ok(()) +} + +pub(crate) async fn drive_runner_events( + mut rx: mpsc::UnboundedReceiver, + mut wait: Pin> + Send + '_>>, + sse_task: &mut tokio::task::JoinHandle<()>, + sse_connected: &AtomicBool, + missing_status_message: &'static str, + mut on_event: F, +) -> Result +where + F: FnMut(E), +{ + let mut status: Option = None; + + loop { + tokio::select! { + event = rx.recv() => { + match event { + Some(event) => on_event(event), + None => { + if status.is_none() { + status = Some(wait.as_mut().await?); + abort_unconnected_sse(sse_task, sse_connected); + } + break; + } + } + } + wait_result = wait.as_mut(), if status.is_none() => { + status = Some(wait_result?); + abort_unconnected_sse(sse_task, sse_connected); + } + } + + if status.is_some() && rx.is_closed() { + break; + } + } + + let _ = sse_task.await; + status.context(missing_status_message) +} + +fn abort_unconnected_sse(sse_task: &mut tokio::task::JoinHandle<()>, sse_connected: &AtomicBool) { + if !sse_connected.load(Ordering::Relaxed) { + sse_task.abort(); + } +} diff --git a/src/sync.rs b/src/sync.rs index 9bf07f82..389b4739 100644 --- a/src/sync.rs +++ b/src/sync.rs @@ -28,6 +28,8 @@ use crate::projects::api::{create_project, list_projects, Project}; use crate::ui::{animations_enabled, fuzzy_select, is_quiet}; use crate::utils::parse_duration_to_seconds; +pub(crate) mod discovery; + const STATE_SCHEMA_VERSION: u32 = 1; const DEFAULT_PULL_LIMIT: usize = 100; const DEFAULT_PAGE_SIZE: usize = 
1000; @@ -49,6 +51,14 @@ pub(crate) fn default_workers() -> usize { .unwrap_or(DEFAULT_WORKERS_FALLBACK) } +#[derive(Debug, Clone)] +pub(crate) struct SyncPushFileArgs { + pub object_ref: String, + pub input: PathBuf, + pub root: PathBuf, + pub fresh: bool, +} + #[derive(Debug, Clone, Args)] #[command(after_help = "\ Examples: @@ -420,6 +430,8 @@ struct ResolvedPushDestination { object: ObjectRef, object_ref: String, project_id: String, + project_name: String, + object_name: String, run_id: Option, } @@ -428,13 +440,17 @@ struct ResolvedDestination { object: ObjectRef, object_ref: String, project_id: String, + project_name: String, + object_name: String, run_id: Option, } #[derive(Debug, Clone)] struct ResolvedNamedObjectTarget { object_id: String, + object_name: String, project_id: String, + project_name: String, } #[derive(Debug, Clone, Copy)] @@ -577,6 +593,36 @@ pub async fn run(base: BaseArgs, args: SyncArgs) -> Result<()> { } } +pub(crate) async fn push_jsonl_file(base: BaseArgs, args: SyncPushFileArgs) -> Result<()> { + let json_output = base.json; + let ctx = login(&base).await?; + let client = ApiClient::new(&ctx)?; + let project = base.project.clone().or_else(|| { + crate::config::configured_project_for_context(&base, ctx.login.org_name().as_deref()) + }); + + run_push( + json_output, + &ctx, + &client, + project.as_deref(), + PushArgs { + object_ref: args.object_ref, + input: Some(args.input), + filter: None, + traces: None, + spans: None, + page_size: DEFAULT_PAGE_SIZE, + fresh: args.fresh, + root: args.root, + workers: default_workers(), + max_batch_bytes: PUSH_BATCH_MAX_INPUT_BYTES, + max_in_flight_bytes: PUSH_MAX_IN_FLIGHT_INPUT_BYTES, + }, + ) + .await +} + async fn run_pull( json_output: bool, verbose: bool, @@ -1494,8 +1540,7 @@ async fn process_trace_chunk( let serialized = rows .iter() .map(|row| { - let line = - serde_json::to_string(row).context("failed to serialize trace row")?; + let line = serialize_jsonl_value(row)?; bytes_written += 
(line.len() + 1) as u64; Result::::Ok(line) }) @@ -1600,6 +1645,7 @@ async fn run_push( args: PushArgs, ) -> Result<()> { let destination = resolve_push_destination(client, &args.object_ref, project_selector).await?; + let object_url = push_destination_url(&ctx.app_url, client.org_name(), &destination); let object = destination.object.clone(); let (scope, limit) = resolve_push_scope_and_limit(args.traces, args.spans)?; @@ -1674,6 +1720,7 @@ async fn run_push( "status": "completed", "message": "already completed for this spec", "source_path": state.source_path, + "object_url": object_url, "items_done": state.items_done, "pages_done": state.pages_done }))? @@ -1683,6 +1730,7 @@ async fn run_push( "Sync already completed for this spec. input={} items={} pages={}", state.source_path, state.items_done, state.pages_done ); + println!(" URL: {object_url}"); } return Ok(()); } @@ -1976,6 +2024,7 @@ async fn run_push( "status": "interrupted", "spec_dir": spec_dir, "input_path": input_path, + "object_url": object_url, "rows_uploaded": state.items_done, "pages_done": state.pages_done, "bytes_sent": state.bytes_sent, @@ -1991,6 +2040,7 @@ async fn run_push( format_u64_commas(state.bytes_sent) ); println!(" Resume: rerun the same command (use --fresh to restart)"); + println!(" URL: {object_url}"); } return Ok(()); } @@ -2015,6 +2065,7 @@ async fn run_push( "status": "completed", "spec_dir": spec_dir, "input_path": input_path, + "object_url": object_url, "rows_uploaded": state.items_done, "pages_done": state.pages_done, "bytes_sent": state.bytes_sent @@ -2028,6 +2079,7 @@ async fn run_push( let spans_per_sec = spans_done as f64 / elapsed_secs as f64; let bytes_per_sec = state.bytes_sent as f64 / elapsed_secs as f64; println!("Push complete"); + println!(" URL: {object_url}"); println!(" Input: {}", input_path.display()); println!(" Time: {}", format_duration(elapsed_secs)); println!(" Traces: {}", format_usize_commas(traces_done)); @@ -2048,6 +2100,35 @@ async fn run_push( 
Ok(()) } +fn push_destination_url( + app_url: &str, + org_name: &str, + destination: &ResolvedPushDestination, +) -> String { + let project_name = if destination.project_name.trim().is_empty() { + destination.project_id.as_str() + } else { + destination.project_name.as_str() + }; + let object_name = if destination.object_name.trim().is_empty() { + destination.object.object_name.as_str() + } else { + destination.object_name.as_str() + }; + let path = match destination.object.object_type { + ObjectType::ProjectLogs => "logs".to_string(), + ObjectType::Experiment => format!("experiments/{}", encode(object_name)), + ObjectType::Dataset => format!("datasets/{}", encode(object_name)), + }; + format!( + "{}/app/{}/p/{}/{}", + app_url.trim_end_matches('/'), + encode(org_name), + encode(project_name), + path + ) +} + fn run_status(json_output: bool, args: StatusArgs) -> Result<()> { let object = parse_object_ref(&args.object_ref)?; let (scope, limit) = resolve_status_scope_and_limit(args.traces, args.spans)?; @@ -2124,16 +2205,21 @@ fn run_status(json_output: bool, args: StatusArgs) -> Result<()> { Ok(()) } -async fn execute_btql_query( +async fn execute_btql_request( client: &ApiClient, ctx: &LoginContext, - query: &str, + query: &Q, + query_source: &str, btql_retry_tracker: Option>, -) -> Result { +) -> Result +where + Q: Serialize + ?Sized, + T: DeserializeOwned, +{ let body = json!({ "query": query, "fmt": "json", - "query_source": "bt_sync_9f4b1e6d7c2a4a7b8d4f9a6c2b1e7f3d", + "query_source": query_source, }); let org_name = ctx.login.org_name().unwrap_or_default(); let client = client.clone(); @@ -2168,7 +2254,7 @@ async fn execute_btql_query( Ok(response) => { let status = response.status(); if status.is_success() { - return response.json::().await.map_err(|err| { + return response.json::().await.map_err(|err| { BackoffError::permanent(anyhow!("failed to parse BTQL response: {err}")) }); } @@ -2215,6 +2301,31 @@ async fn execute_btql_query( }) } +async fn 
execute_btql_query( + client: &ApiClient, + ctx: &LoginContext, + query: &str, + btql_retry_tracker: Option>, +) -> Result { + execute_btql_request( + client, + ctx, + query, + "bt_sync_9f4b1e6d7c2a4a7b8d4f9a6c2b1e7f3d", + btql_retry_tracker, + ) + .await +} + +async fn execute_btql_json_query( + client: &ApiClient, + ctx: &LoginContext, + query: &Value, + query_source: &str, +) -> Result { + execute_btql_request(client, ctx, query, query_source, None).await +} + async fn execute_btql_query_timed( client: &ApiClient, ctx: &LoginContext, @@ -3093,7 +3204,9 @@ async fn resolve_destination( Ok(ResolvedDestination { object_ref: format!("project_logs:{}", project.id), object, - project_id: project.id, + project_id: project.id.clone(), + project_name: project.name.clone(), + object_name: project.name, run_id: None, }) } @@ -3127,6 +3240,8 @@ async fn resolve_destination( object_ref: format!("experiment:{}", resolved.object_id), object, project_id: resolved.project_id, + project_name: resolved.project_name, + object_name: resolved.object_name, run_id, }) } @@ -3151,6 +3266,8 @@ async fn resolve_destination( object_ref: format!("dataset:{}", resolved.object_id), object, project_id: resolved.project_id, + project_name: resolved.project_name, + object_name: resolved.object_name, run_id: None, }) } @@ -3207,7 +3324,9 @@ async fn resolve_named_object_target( )?; return Ok(ResolvedNamedObjectTarget { object_id: object.id.clone(), + object_name: object.name.clone(), project_id: project.id.clone(), + project_name: project.name.clone(), }); } @@ -3223,7 +3342,9 @@ async fn resolve_named_object_target( if let Some(object) = objects.iter().find(|value| value.id == object_selector) { return Ok(ResolvedNamedObjectTarget { object_id: object.id.clone(), + object_name: object.name.clone(), project_id: project.id.clone(), + project_name: project.name.clone(), }); } } @@ -3273,7 +3394,9 @@ async fn resolve_push_experiment_target( ) { return Ok(ResolvedNamedObjectTarget { object_id: 
object.id.clone(), - project_id: project.id, + object_name: object.name.clone(), + project_id: project.id.clone(), + project_name: project.name.clone(), }); } @@ -3300,7 +3423,9 @@ async fn resolve_push_experiment_target( Ok(ResolvedNamedObjectTarget { object_id: created.id, + object_name: created.name, project_id: project.id, + project_name: project.name, }) } @@ -3336,7 +3461,9 @@ async fn resolve_push_dataset_target( ) { return Ok(ResolvedNamedObjectTarget { object_id: object.id.clone(), - project_id: project.id, + object_name: object.name.clone(), + project_id: project.id.clone(), + project_name: project.name.clone(), }); } @@ -3363,7 +3490,9 @@ async fn resolve_push_dataset_target( Ok(ResolvedNamedObjectTarget { object_id: created.id, + object_name: created.name, project_id: project.id, + project_name: project.name, }) } @@ -3411,6 +3540,8 @@ async fn resolve_push_destination( object: resolved.object, object_ref: resolved.object_ref, project_id: resolved.project_id, + project_name: resolved.project_name, + object_name: resolved.object_name, run_id: resolved.run_id, }) } @@ -3575,13 +3706,26 @@ fn sanitize_segment(value: &str) -> String { } } -fn spec_dir(root: &Path, object: &ObjectRef, hash: &str) -> PathBuf { +pub(crate) fn artifact_base_dir(root: &Path, object_type: &str, object_name: &str) -> PathBuf { let object_key = format!( "{}_{}", - sanitize_segment(object.object_type.as_str()), - sanitize_segment(&object.object_name) + sanitize_segment(object_type), + sanitize_segment(object_name) ); - root.join(object_key).join(&hash[..12]) + root.join(object_key) +} + +pub(crate) fn artifact_spec_dir( + root: &Path, + object_type: &str, + object_name: &str, + hash: &str, +) -> PathBuf { + artifact_base_dir(root, object_type, object_name).join(&hash[..12]) +} + +fn spec_dir(root: &Path, object: &ObjectRef, hash: &str) -> PathBuf { + artifact_spec_dir(root, object.object_type.as_str(), &object.object_name, hash) } fn legacy_spec_dir( @@ -3630,8 +3774,8 @@ fn 
resolve_spec_dir( } } -fn spec_hash(spec: &SyncSpec) -> Result { - let canonical = serde_json::to_vec(spec).context("failed to serialize sync spec")?; +pub(crate) fn stable_spec_hash(spec: &T) -> Result { + let canonical = serde_json::to_vec(spec).context("failed to serialize spec")?; let mut hasher = Sha256::new(); hasher.update(&canonical); let digest = hasher.finalize(); @@ -3642,7 +3786,11 @@ fn spec_hash(spec: &SyncSpec) -> Result { Ok(out) } -fn epoch_seconds() -> u64 { +fn spec_hash(spec: &SyncSpec) -> Result { + stable_spec_hash(spec) +} + +pub(crate) fn epoch_seconds() -> u64 { SystemTime::now() .duration_since(UNIX_EPOCH) .map(|d| d.as_secs()) @@ -3888,8 +4036,73 @@ fn open_jsonl_part_writer(base_dir: &Path, append: bool) -> Result Result> { + let file = File::open(path).with_context(|| format!("failed to open {}", path.display()))?; + let reader = BufReader::new(file); + let mut values = Vec::new(); + for (index, line) in reader.lines().enumerate() { + let line = line.with_context(|| { + format!("failed to read line {} from {}", index + 1, path.display()) + })?; + let trimmed = line.trim(); + if trimmed.is_empty() { + continue; + } + values.push(serde_json::from_str(trimmed).with_context(|| { + format!( + "failed to parse JSON on line {} from {}", + index + 1, + path.display() + ) + })?); + } + Ok(values) +} + +fn serialize_jsonl_value(value: &T) -> Result { + serde_json::to_string(value).context("failed to serialize row to JSONL") +} + +pub(crate) fn write_jsonl_value( + writer: &mut dyn Write, + value: &T, +) -> Result { + let encoded = serialize_jsonl_value(value)?; + writer + .write_all(encoded.as_bytes()) + .context("failed to write JSONL row")?; + writer + .write_all(b"\n") + .context("failed to write JSONL newline")?; + Ok(encoded.len() + 1) +} + +pub(crate) fn create_jsonl_file_writer(path: &Path) -> Result> { + if let Some(parent) = path.parent().filter(|p| !p.as_os_str().is_empty()) { + fs::create_dir_all(parent) + .with_context(|| 
format!("failed to create {}", parent.display()))?; + } + Ok(BufWriter::new(File::create(path).with_context(|| { + format!("failed to create {}", path.display()) + })?)) +} + +#[cfg(test)] +pub(crate) fn write_jsonl_values(out: Option<&Path>, values: &[T]) -> Result<()> { + let mut writer: Box = if let Some(path) = out { + Box::new(create_jsonl_file_writer(path)?) + } else { + Box::new(BufWriter::new(std::io::stdout())) + }; + + for value in values { + write_jsonl_value(writer.as_mut(), value)?; + } + writer.flush().context("failed to flush JSONL output") +} + fn write_jsonl_row(writer: &mut JsonlPartWriter, row: &Map) -> Result { - let encoded = serde_json::to_string(row).context("failed to serialize row to JSONL")?; + let encoded = serialize_jsonl_value(row)?; writer .write_line(&encoded) .context("failed to write JSONL row") @@ -4348,7 +4561,7 @@ fn value_as_string(value: Option<&Value>) -> Option { } } -fn write_json_atomic(path: &Path, value: &T) -> Result<()> { +pub(crate) fn write_json_atomic(path: &Path, value: &T) -> Result<()> { let parent = path .parent() .ok_or_else(|| anyhow!("path has no parent: {}", path.display()))?; @@ -4367,7 +4580,7 @@ fn write_json_atomic(path: &Path, value: &T) -> Result<()> { Ok(()) } -fn read_json_file(path: &Path) -> Result { +pub(crate) fn read_json_file(path: &Path) -> Result { let bytes = fs::read(path).with_context(|| format!("failed to read {}", path.display()))?; serde_json::from_slice(&bytes).with_context(|| format!("failed to parse {}", path.display())) } @@ -4491,6 +4704,87 @@ fn spinner_bar(message: &str) -> ProgressBar { mod tests { use super::*; + #[test] + fn write_jsonl_value_serializes_one_line_and_reports_bytes() -> Result<()> { + let mut output = Vec::new(); + + let bytes = write_jsonl_value(&mut output, &json!({ "id": "row-1" }))?; + + assert_eq!(bytes, output.len()); + assert_eq!(String::from_utf8(output)?, "{\"id\":\"row-1\"}\n"); + Ok(()) + } + + #[test] + fn read_jsonl_values_skips_blank_lines() -> 
Result<()> { + let unique = SystemTime::now() + .duration_since(UNIX_EPOCH) + .map(|d| d.as_nanos()) + .unwrap_or_default(); + let path = std::env::temp_dir().join(format!( + "bt-sync-read-jsonl-values-{}-{}.jsonl", + std::process::id(), + unique + )); + + fs::write(&path, "{\"id\":\"row-1\"}\n\n{\"id\":\"row-2\"}\n")?; + let values = read_jsonl_values(&path)?; + + assert_eq!( + values, + vec![json!({ "id": "row-1" }), json!({ "id": "row-2" })] + ); + let _ = fs::remove_file(&path); + Ok(()) + } + + #[test] + fn write_jsonl_values_writes_file() -> Result<()> { + let unique = SystemTime::now() + .duration_since(UNIX_EPOCH) + .map(|d| d.as_nanos()) + .unwrap_or_default(); + let path = std::env::temp_dir().join(format!( + "bt-sync-write-jsonl-values-{}-{}.jsonl", + std::process::id(), + unique + )); + + write_jsonl_values( + Some(&path), + &[json!({ "id": "row-1" }), json!({ "id": "row-2" })], + )?; + + let content = fs::read_to_string(&path)?; + assert_eq!(content, "{\"id\":\"row-1\"}\n{\"id\":\"row-2\"}\n"); + let _ = fs::remove_file(&path); + Ok(()) + } + + #[test] + fn push_destination_url_links_to_dataset_object() { + let destination = ResolvedPushDestination { + object: ObjectRef { + object_type: ObjectType::Dataset, + object_name: "dataset-id".to_string(), + }, + object_ref: "dataset:dataset-id".to_string(), + project_id: "project-id".to_string(), + project_name: "Facet Optimizer".to_string(), + object_name: "Loop Facet Ground Truth".to_string(), + run_id: None, + }; + + assert_eq!( + push_destination_url( + "https://www.braintrust.dev/", + "braintrustdata.com", + &destination + ), + "https://www.braintrust.dev/app/braintrustdata.com/p/Facet%20Optimizer/datasets/Loop%20Facet%20Ground%20Truth" + ); + } + #[test] fn push_checkpoint_line_offset_advances_only_after_commit() { let mut state = diff --git a/src/sync/discovery.rs b/src/sync/discovery.rs new file mode 100644 index 00000000..77581a76 --- /dev/null +++ b/src/sync/discovery.rs @@ -0,0 +1,437 @@ +use 
std::collections::{HashMap, HashSet}; + +use anyhow::Result; +use serde::Deserialize; +use serde_json::{json, Map, Value}; + +use crate::auth::LoginContext; +use crate::http::ApiClient; +use crate::sync::execute_btql_json_query; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) enum ProjectLogRefScope { + Trace, + Span, +} + +#[derive(Debug, Clone, PartialEq)] +pub(crate) struct ProjectLogRef { + pub(crate) root_span_id: String, + pub(crate) id: Option, + pub(crate) origin: Option, + origin_is_root: bool, + origin_created: Option, +} + +impl ProjectLogRef { + pub(crate) fn to_value(&self) -> Value { + let mut reference = Map::new(); + reference.insert( + "root_span_id".to_string(), + Value::String(self.root_span_id.clone()), + ); + if let Some(id) = self.id.as_deref() { + reference.insert("id".to_string(), Value::String(id.to_string())); + } + if let Some(origin) = self.origin.as_ref() { + reference.insert("origin".to_string(), origin.clone()); + } + Value::Object(reference) + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) struct ProjectLogRefDiscoveryResult { + pub(crate) refs: usize, + pub(crate) pages: usize, +} + +#[derive(Debug, Deserialize)] +struct DiscoveryBtqlResponse { + data: Vec>, + #[serde(default)] + cursor: Option, +} + +pub(crate) async fn discover_project_log_refs( + client: &ApiClient, + ctx: &LoginContext, + project_id: &str, + filter: Option<&Value>, + scope: ProjectLogRefScope, + target: usize, + page_size: usize, + mut on_ref: F, +) -> Result +where + F: FnMut(ProjectLogRef) -> Result<()>, +{ + let mut seen = HashSet::new(); + let mut trace_roots = Vec::new(); + let mut trace_refs_by_root_span_id = HashMap::new(); + let mut cursor: Option = None; + let mut pages = 0usize; + while discovered_ref_count(scope, seen.len(), trace_roots.len()) < target { + let remaining = target - discovered_ref_count(scope, seen.len(), trace_roots.len()); + let limit = discovery_page_limit(scope, remaining, page_size); + let query = + 
build_project_log_ref_query(project_id, filter, limit, cursor.as_deref(), scope); + let response = execute_discovery_btql(client, ctx, &query).await?; + let row_count = response.data.len(); + + for row in response.data { + if matches!(scope, ProjectLogRefScope::Span) && seen.len() >= target { + break; + } + let Some(reference) = project_log_ref_from_row(project_id, &row, scope) else { + continue; + }; + match scope { + ProjectLogRefScope::Span => { + if seen.insert(project_log_ref_key(&reference)) { + on_ref(reference)?; + } + } + ProjectLogRefScope::Trace => { + let root_span_id = reference.root_span_id.clone(); + if !trace_refs_by_root_span_id.contains_key(&root_span_id) { + if trace_roots.len() >= target { + continue; + } + trace_roots.push(root_span_id.clone()); + trace_refs_by_root_span_id.insert(root_span_id, reference); + continue; + } + let should_replace = trace_refs_by_root_span_id + .get(&root_span_id) + .is_none_or(|current| better_trace_origin_ref(current, &reference)); + if should_replace { + trace_refs_by_root_span_id.insert(root_span_id, reference); + } + } + } + } + + pages += 1; + cursor = response.cursor.filter(|c| !c.is_empty()); + if row_count == 0 || cursor.is_none() { + break; + } + } + if matches!(scope, ProjectLogRefScope::Trace) { + for root_span_id in &trace_roots { + if let Some(reference) = trace_refs_by_root_span_id.remove(root_span_id) { + on_ref(reference)?; + } + } + } + + Ok(ProjectLogRefDiscoveryResult { + refs: discovered_ref_count(scope, seen.len(), trace_roots.len()), + pages, + }) +} + +fn discovered_ref_count(scope: ProjectLogRefScope, span_refs: usize, trace_refs: usize) -> usize { + match scope { + ProjectLogRefScope::Trace => trace_refs, + ProjectLogRefScope::Span => span_refs, + } +} + +fn discovery_page_limit(scope: ProjectLogRefScope, remaining: usize, page_size: usize) -> usize { + match scope { + ProjectLogRefScope::Trace => page_size.min(1000), + ProjectLogRefScope::Span => remaining.min(page_size).min(1000), + } +} 
+ +async fn execute_discovery_btql( + client: &ApiClient, + ctx: &LoginContext, + query: &Value, +) -> Result { + execute_btql_json_query(client, ctx, query, "bt_sync_discovery").await +} + +fn build_project_log_ref_query( + project_id: &str, + filter: Option<&Value>, + page_size: usize, + cursor: Option<&str>, + scope: ProjectLogRefScope, +) -> Value { + let select = match scope { + ProjectLogRefScope::Trace => vec![ + btql_select_field("root_span_id"), + btql_select_field("id"), + btql_select_field("is_root"), + btql_select_field("created"), + btql_select_field("_xact_id"), + ], + ProjectLogRefScope::Span => { + vec![ + btql_select_field("root_span_id"), + btql_select_field("id"), + btql_select_field("created"), + btql_select_field("_xact_id"), + ] + } + }; + + let mut query = json!({ + "select": select, + "from": { + "op": "function", + "name": { "op": "ident", "name": ["project_logs"] }, + "args": [{ "op": "literal", "value": project_id }], + "shape": "spans" + }, + "limit": page_size, + "sort": [{ + "expr": { "op": "ident", "name": ["_pagination_key"] }, + "dir": "desc" + }] + }); + + if let Some(filter_expr) = filter { + query["filter"] = filter_expr.clone(); + } + if let Some(c) = cursor { + query["cursor"] = Value::String(c.to_string()); + } + query +} + +fn project_log_ref_from_row( + project_id: &str, + row: &Map, + scope: ProjectLogRefScope, +) -> Option { + let root_span_id = row_string(row, "root_span_id")?; + match scope { + ProjectLogRefScope::Trace => Some(ProjectLogRef { + root_span_id, + id: None, + origin: project_log_origin_from_row(project_id, row), + origin_is_root: row_bool(row, "is_root"), + origin_created: row_string(row, "created"), + }), + ProjectLogRefScope::Span => Some(ProjectLogRef { + root_span_id, + id: Some(row_string(row, "id")?), + origin: project_log_origin_from_row(project_id, row), + origin_is_root: row_bool(row, "is_root"), + origin_created: row_string(row, "created"), + }), + } +} + +fn btql_select_field(field: &str) -> 
Value { + json!({ + "alias": field, + "expr": { "op": "ident", "name": [field] } + }) +} + +fn project_log_ref_key(reference: &ProjectLogRef) -> (String, Option) { + (reference.root_span_id.clone(), reference.id.clone()) +} + +fn row_string(row: &Map, key: &str) -> Option { + row.get(key) + .and_then(Value::as_str) + .map(ToString::to_string) +} + +fn row_bool(row: &Map, key: &str) -> bool { + row.get(key).and_then(Value::as_bool).unwrap_or(false) +} + +fn better_trace_origin_ref(current: &ProjectLogRef, candidate: &ProjectLogRef) -> bool { + match (current.origin_is_root, candidate.origin_is_root) { + (false, true) => return true, + (true, false) => return false, + _ => {} + } + + match (¤t.origin_created, &candidate.origin_created) { + (Some(current_created), Some(candidate_created)) => candidate_created < current_created, + (None, Some(_)) => true, + _ => false, + } +} + +fn project_log_origin_from_row(project_id: &str, row: &Map) -> Option { + let row_id = row_string(row, "id")?; + Some(object_origin( + "project_logs", + project_id, + &row_id, + row.get("created").and_then(Value::as_str), + row.get("_xact_id").and_then(Value::as_str), + )) +} + +fn object_origin( + object_type: &str, + object_id: &str, + row_id: &str, + created: Option<&str>, + xact_id: Option<&str>, +) -> Value { + let mut origin = Map::from_iter([ + ( + "object_type".to_string(), + Value::String(object_type.to_string()), + ), + ( + "object_id".to_string(), + Value::String(object_id.to_string()), + ), + ("id".to_string(), Value::String(row_id.to_string())), + ]); + if let Some(created) = created { + origin.insert("created".to_string(), Value::String(created.to_string())); + } + if let Some(xact_id) = xact_id { + origin.insert("_xact_id".to_string(), Value::String(xact_id.to_string())); + } + Value::Object(origin) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn project_log_ref_from_row_uses_trace_scope() { + let row = Map::from_iter([ + ("root_span_id".to_string(), 
json!("root-1")), + ("id".to_string(), json!("span-1")), + ]); + + assert_eq!( + project_log_ref_from_row("project-1", &row, ProjectLogRefScope::Trace), + Some(ProjectLogRef { + root_span_id: "root-1".to_string(), + id: None, + origin: Some(json!({ + "object_type": "project_logs", + "object_id": "project-1", + "id": "span-1" + })), + origin_is_root: false, + origin_created: None, + }) + ); + } + + #[test] + fn project_log_ref_from_row_uses_span_scope() { + let row = Map::from_iter([ + ("root_span_id".to_string(), json!("root-1")), + ("id".to_string(), json!("span-1")), + ]); + + assert_eq!( + project_log_ref_from_row("project-1", &row, ProjectLogRefScope::Span), + Some(ProjectLogRef { + root_span_id: "root-1".to_string(), + id: Some("span-1".to_string()), + origin: Some(json!({ + "object_type": "project_logs", + "object_id": "project-1", + "id": "span-1" + })), + origin_is_root: false, + origin_created: None, + }) + ); + } + + #[test] + fn span_scope_page_limit_uses_remaining_target() { + assert_eq!(discovery_page_limit(ProjectLogRefScope::Span, 3, 1000), 3); + } + + #[test] + fn trace_scope_page_limit_keeps_full_page_for_dedupe() { + assert_eq!( + discovery_page_limit(ProjectLogRefScope::Trace, 3, 1000), + 1000 + ); + } + + #[test] + fn project_log_origin_from_row_includes_optional_position_fields() { + let row = Map::from_iter([ + ("id".to_string(), json!("row-1")), + ("created".to_string(), json!("2026-01-01T00:00:00Z")), + ("_xact_id".to_string(), json!("100")), + ]); + + assert_eq!( + project_log_origin_from_row("project-1", &row), + Some(json!({ + "object_type": "project_logs", + "object_id": "project-1", + "id": "row-1", + "created": "2026-01-01T00:00:00Z", + "_xact_id": "100" + })) + ); + } + + #[test] + fn object_origin_supports_arbitrary_source_objects() { + assert_eq!( + object_origin("dataset", "dataset-1", "row-1", None, None), + json!({ + "object_type": "dataset", + "object_id": "dataset-1", + "id": "row-1" + }) + ); + } + + #[test] + fn 
trace_origin_ref_prefers_is_root_over_earliest_created() { + let current = ProjectLogRef { + root_span_id: "root-1".to_string(), + id: None, + origin: Some(json!({ "id": "earliest" })), + origin_is_root: false, + origin_created: Some("2026-01-01T00:00:00Z".to_string()), + }; + let candidate = ProjectLogRef { + root_span_id: "root-1".to_string(), + id: None, + origin: Some(json!({ "id": "root" })), + origin_is_root: true, + origin_created: Some("2026-01-02T00:00:00Z".to_string()), + }; + + assert!(better_trace_origin_ref(¤t, &candidate)); + } + + #[test] + fn trace_origin_ref_uses_earliest_created_without_is_root() { + let current = ProjectLogRef { + root_span_id: "root-1".to_string(), + id: None, + origin: Some(json!({ "id": "later" })), + origin_is_root: false, + origin_created: Some("2026-01-02T00:00:00Z".to_string()), + }; + let candidate = ProjectLogRef { + root_span_id: "root-1".to_string(), + id: None, + origin: Some(json!({ "id": "earlier" })), + origin_is_root: false, + origin_created: Some("2026-01-01T00:00:00Z".to_string()), + }; + + assert!(better_trace_origin_ref(¤t, &candidate)); + } +}