Skip to content

Commit b7fbdfb

Browse files
committed
Port JS Parameters implementation to Python
1 parent 7da14b1 commit b7fbdfb

9 files changed

Lines changed: 906 additions & 20 deletions

File tree

py/src/braintrust/cli/push.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,13 @@ def _collect_prompt_function_defs(
268268
functions.append(p.to_function_definition(if_exists, project_ids))
269269

270270

271+
def _collect_parameters_function_defs(
272+
project_ids: ProjectIdCache, functions: list[dict[str, Any]], if_exists: IfExists
273+
) -> None:
274+
for p in global_.parameters:
275+
functions.append(p.to_function_definition(if_exists, project_ids))
276+
277+
271278
def run(args):
272279
"""Runs the braintrust push subcommand."""
273280
login(
@@ -306,6 +313,8 @@ def run(args):
306313
_collect_function_function_defs(project_ids, functions, bundle_id, args.if_exists)
307314
if len(global_.prompts) > 0:
308315
_collect_prompt_function_defs(project_ids, functions, args.if_exists)
316+
if len(global_.parameters) > 0:
317+
_collect_parameters_function_defs(project_ids, functions, args.if_exists)
309318

310319
if len(functions) > 0:
311320
api_conn().post_json("insert-functions", {"functions": functions})

py/src/braintrust/devserver/server.py

Lines changed: 50 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626

2727
from ..framework import EvalAsync, EvalScorer, Evaluator, ExperimentSummary, SSEProgressEvent
2828
from ..generated_types import FunctionId
29-
from ..logger import BraintrustState, bt_iscoroutinefunction
29+
from ..logger import BraintrustState, RemoteEvalParameters, bt_iscoroutinefunction
3030
from ..parameters import parameters_to_json_schema, validate_parameters
3131
from ..span_identifier_v4 import parse_parent
3232
from .auth import AuthorizationMiddleware
@@ -79,6 +79,40 @@ async def index(request: Request) -> PlainTextResponse:
7979
return PlainTextResponse("Hello, world!")
8080

8181

82+
def _ensure_input_field(row: Any) -> Any:
83+
if isinstance(row, dict) and "input" not in row:
84+
return {**row, "input": None}
85+
return row
86+
87+
88+
def _serialize_parameters_container(parameters: Any) -> dict[str, Any]:
89+
if parameters is None:
90+
return {}
91+
92+
if RemoteEvalParameters.is_parameters(parameters):
93+
return {
94+
"type": "braintrust.parameters",
95+
"schema": dict(parameters.schema),
96+
"source": {
97+
"parametersId": parameters.id,
98+
"slug": parameters.slug,
99+
"name": parameters.name,
100+
"projectId": parameters.project_id,
101+
"version": parameters.version,
102+
},
103+
}
104+
105+
schema = parameters_to_json_schema(parameters)
106+
if schema:
107+
return {
108+
"type": "braintrust.staticParameters",
109+
"schema": schema,
110+
"source": None,
111+
}
112+
113+
return {}
114+
115+
82116
async def list_evaluators(request: Request) -> JSONResponse:
83117
# Get the authenticated context
84118
ctx = getattr(request.state, "ctx", None)
@@ -93,7 +127,7 @@ async def list_evaluators(request: Request) -> JSONResponse:
93127
evaluator_list = {}
94128
for name, evaluator in _all_evaluators.items():
95129
evaluator_list[name] = {
96-
"parameters": parameters_to_json_schema(evaluator.parameters) if evaluator.parameters else {},
130+
"parameters": _serialize_parameters_container(evaluator.parameters),
97131
"scores": [{"name": getattr(score, "name", f"score_{i}")} for i, score in enumerate(evaluator.scores)],
98132
}
99133

@@ -130,12 +164,20 @@ async def run_eval(request: Request) -> JSONResponse | StreamingResponse:
130164
if not evaluator:
131165
return JSONResponse({"error": f"Evaluator '{eval_data['name']}' not found"}, status_code=404)
132166

133-
# Get the dataset if data is provided
134-
try:
135-
dataset = await get_dataset(state, eval_data["data"])
136-
except Exception as e:
137-
print(f"Error loading dataset: {e}", file=sys.stderr)
138-
return JSONResponse({"error": f"Failed to load dataset: {str(e)}"}, status_code=400)
167+
# Get the dataset if data is provided, otherwise fall back to the evaluator's own data
168+
dataset = None
169+
raw_data = eval_data.get("data")
170+
if raw_data is not None:
171+
try:
172+
dataset = await get_dataset(state, raw_data)
173+
except Exception as e:
174+
print(f"Error loading dataset from request, falling back to evaluator data: {e}", file=sys.stderr)
175+
176+
if dataset is None:
177+
dataset = evaluator.data
178+
179+
if isinstance(dataset, list):
180+
dataset = [_ensure_input_field(row) for row in dataset]
139181

140182
# Validate parameters if provided
141183
validated_parameters = None

py/src/braintrust/framework.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
Experiment,
3535
ExperimentSummary,
3636
Metadata,
37+
RemoteEvalParameters,
3738
ScoreSummary,
3839
Span,
3940
_ExperimentDatasetEvent,
@@ -42,7 +43,7 @@
4243
stringify_exception,
4344
)
4445
from .logger import init as _init_experiment
45-
from .parameters import EvalParameters
46+
from .parameters import EvalParameters, validate_parameters
4647
from .resource_manager import ResourceManager
4748
from .score import Score, is_score, is_scorer
4849
from .serializable_data_class import SerializableDataClass
@@ -438,7 +439,7 @@ class Evaluator(Generic[Input, Output]):
438439
Whether to summarize the scores of the experiment after it has run.
439440
"""
440441

441-
parameters: EvalParameters | None = None
442+
parameters: EvalParameters | RemoteEvalParameters | None = None
442443
"""
443444
A set of parameters that will be passed to the evaluator.
444445
Can be used to define prompts or other configurable values.
@@ -674,7 +675,7 @@ def _EvalCommon(
674675
summarize_scores: bool,
675676
no_send_logs: bool,
676677
error_score_handler: ErrorScoreHandler | None = None,
677-
parameters: EvalParameters | None = None,
678+
parameters: EvalParameters | RemoteEvalParameters | None = None,
678679
on_start: Callable[[ExperimentSummary], None] | None = None,
679680
stream: Callable[[SSEProgressEvent], None] | None = None,
680681
parent: str | None = None,
@@ -803,7 +804,7 @@ async def EvalAsync(
803804
description: str | None = None,
804805
summarize_scores: bool = True,
805806
no_send_logs: bool = False,
806-
parameters: EvalParameters | None = None,
807+
parameters: EvalParameters | RemoteEvalParameters | None = None,
807808
on_start: Callable[[ExperimentSummary], None] | None = None,
808809
stream: Callable[[SSEProgressEvent], None] | None = None,
809810
parent: str | None = None,
@@ -930,7 +931,7 @@ def Eval(
930931
description: str | None = None,
931932
summarize_scores: bool = True,
932933
no_send_logs: bool = False,
933-
parameters: EvalParameters | None = None,
934+
parameters: EvalParameters | RemoteEvalParameters | None = None,
934935
on_start: Callable[[ExperimentSummary], None] | None = None,
935936
stream: Callable[[SSEProgressEvent], None] | None = None,
936937
parent: str | None = None,
@@ -1390,6 +1391,17 @@ def get_other_fields(s):
13901391
scorer_names = [_scorer_name(scorer, i) for i, scorer in enumerate(scorers)]
13911392
unhandled_scores = scorer_names
13921393

1394+
resolved_parameters: dict[str, Any] | None = None
1395+
if evaluator.parameters is not None:
1396+
if RemoteEvalParameters.is_parameters(evaluator.parameters):
1397+
resolved_parameters = validate_parameters({}, evaluator.parameters)
1398+
elif isinstance(evaluator.parameters, dict):
1399+
resolved_parameters = validate_parameters({}, evaluator.parameters)
1400+
else:
1401+
raise ValueError(
1402+
"Invalid evaluator.parameters. Expected an EvalParameters schema or RemoteEvalParameters."
1403+
)
1404+
13931405
async def run_evaluator_task(datum, trial_index=0):
13941406
if isinstance(datum, dict):
13951407
datum = EvalCase.from_dict(datum)
@@ -1449,7 +1461,7 @@ def report_progress(event: TaskProgressEvent):
14491461
trial_index=trial_index,
14501462
tags=tags,
14511463
report_progress=report_progress,
1452-
parameters=evaluator.parameters,
1464+
parameters=resolved_parameters,
14531465
)
14541466

14551467
# Check if the task takes a hooks argument

py/src/braintrust/framework2.py

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
SavedFunctionId,
1717
ToolFunctionDefinition,
1818
)
19+
from .parameters import EvalParameters, _pydantic_to_json_schema
1920
from .util import eprint
2021

2122

@@ -34,6 +35,7 @@ class _GlobalState:
3435
def __init__(self):
3536
self.functions: list[CodeFunction] = []
3637
self.prompts: list[CodePrompt] = []
38+
self.parameters: list[CodeParameters] = []
3739

3840

3941
global_ = _GlobalState()
@@ -287,6 +289,161 @@ def create(
287289
return p
288290

289291

292+
def _maybe_serialize_prompt_default(default: Any) -> Any:
293+
as_dict = getattr(default, "as_dict", None)
294+
if callable(as_dict):
295+
return as_dict()
296+
return default
297+
298+
299+
def _pydantic_instance_to_plain(value: Any) -> Any:
300+
if hasattr(value, "model_dump"):
301+
return value.model_dump()
302+
if hasattr(value, "dict"):
303+
return value.dict()
304+
return value
305+
306+
307+
def _is_single_field_value_model(model: Any) -> bool:
308+
fields = getattr(model, "__fields__", None) or getattr(model, "model_fields", {})
309+
return isinstance(fields, dict) and len(fields) == 1 and "value" in fields
310+
311+
312+
def _maybe_set_default_from_pydantic_model(model: Any, schema_obj: dict[str, Any]) -> dict[str, Any]:
313+
if "default" in schema_obj:
314+
return schema_obj
315+
try:
316+
instance = model()
317+
except Exception:
318+
return schema_obj
319+
320+
if _is_single_field_value_model(model) and hasattr(instance, "value"):
321+
return {**schema_obj, "default": _pydantic_instance_to_plain(getattr(instance, "value"))}
322+
323+
return {**schema_obj, "default": _pydantic_instance_to_plain(instance)}
324+
325+
326+
def serialize_eval_parameters_to_parameters_schema(parameters: EvalParameters) -> dict[str, Any]:
327+
properties: dict[str, Any] = {}
328+
required: list[str] = []
329+
330+
for name, schema in parameters.items():
331+
if isinstance(schema, dict) and schema.get("type") == "prompt":
332+
prompt_schema: dict[str, Any] = {"type": "object", "x-bt-type": "prompt"}
333+
334+
description = schema.get("description")
335+
if description is not None:
336+
prompt_schema["description"] = description
337+
338+
default_value = schema.get("default")
339+
if default_value is not None:
340+
prompt_schema["default"] = _maybe_serialize_prompt_default(default_value)
341+
else:
342+
required.append(name)
343+
344+
properties[name] = prompt_schema
345+
continue
346+
347+
if schema is None:
348+
raise ValueError(f"Parameter '{name}' has no schema")
349+
350+
if not (hasattr(schema, "model_json_schema") or hasattr(schema, "schema")):
351+
raise ValueError(
352+
f"Invalid schema for parameter '{name}'. Expected a pydantic model (v1 or v2) or a prompt parameter."
353+
)
354+
355+
schema_obj = _pydantic_to_json_schema(schema)
356+
if _is_single_field_value_model(schema):
357+
value_schema = schema_obj.get("properties", {}).get("value")
358+
if not isinstance(value_schema, dict):
359+
raise ValueError(f"Invalid pydantic schema for parameter '{name}': missing properties.value")
360+
parameter_schema = _maybe_set_default_from_pydantic_model(schema, value_schema)
361+
else:
362+
parameter_schema = _maybe_set_default_from_pydantic_model(schema, schema_obj)
363+
364+
properties[name] = parameter_schema
365+
if "default" not in parameter_schema:
366+
required.append(name)
367+
368+
out: dict[str, Any] = {"type": "object", "properties": properties, "additionalProperties": True}
369+
if required:
370+
out["required"] = required
371+
return out
372+
373+
374+
def get_default_data_from_parameters_schema(schema: dict[str, Any]) -> dict[str, Any]:
375+
properties = schema.get("properties")
376+
if not isinstance(properties, dict):
377+
return {}
378+
379+
return {k: v["default"] for k, v in properties.items() if isinstance(v, dict) and "default" in v}
380+
381+
382+
@dataclasses.dataclass
383+
class CodeParameters:
384+
"""Parameters defined in code, with metadata."""
385+
386+
project: "Project"
387+
name: str
388+
slug: str
389+
description: str | None
390+
schema: EvalParameters
391+
if_exists: IfExists | None
392+
metadata: dict[str, Any] | None = None
393+
394+
def to_function_definition(self, if_exists: IfExists | None, project_ids: ProjectIdCache) -> dict[str, Any]:
395+
schema = serialize_eval_parameters_to_parameters_schema(self.schema)
396+
j: dict[str, Any] = {
397+
"project_id": project_ids.get(self.project),
398+
"name": self.name,
399+
"slug": self.slug,
400+
"function_type": "parameters",
401+
"function_data": {
402+
"type": "parameters",
403+
"data": get_default_data_from_parameters_schema(schema),
404+
"__schema": schema,
405+
},
406+
"if_exists": self.if_exists if self.if_exists is not None else if_exists,
407+
}
408+
if self.description is not None:
409+
j["description"] = self.description
410+
if self.metadata is not None:
411+
j["metadata"] = self.metadata
412+
return j
413+
414+
415+
class ParametersBuilder:
416+
"""Builder to create parameters in Braintrust."""
417+
418+
def __init__(self, project: "Project"):
419+
self.project = project
420+
421+
def create(
422+
self,
423+
*,
424+
name: str,
425+
slug: str | None = None,
426+
description: str | None = None,
427+
schema: EvalParameters,
428+
if_exists: IfExists | None = None,
429+
metadata: dict[str, Any] | None = None,
430+
) -> EvalParameters:
431+
if slug is None or len(slug) == 0:
432+
slug = slugify.slugify(name)
433+
434+
parameters = CodeParameters(
435+
project=self.project,
436+
name=name,
437+
slug=slug,
438+
description=description,
439+
schema=schema,
440+
if_exists=if_exists,
441+
metadata=metadata,
442+
)
443+
self.project.add_parameters(parameters)
444+
return schema
445+
446+
290447
class ScorerBuilder:
291448
"""Builder to create a scorer in Braintrust."""
292449

@@ -461,10 +618,12 @@ def __init__(self, name: str):
461618
self.name = name
462619
self.tools = ToolBuilder(self)
463620
self.prompts = PromptBuilder(self)
621+
self.parameters = ParametersBuilder(self)
464622
self.scorers = ScorerBuilder(self)
465623

466624
self._publishable_code_functions: list[CodeFunction] = []
467625
self._publishable_prompts: list[CodePrompt] = []
626+
self._publishable_parameters: list[CodeParameters] = []
468627

469628
def add_code_function(self, fn: CodeFunction):
470629
self._publishable_code_functions.append(fn)
@@ -476,6 +635,11 @@ def add_prompt(self, prompt: CodePrompt):
476635
if _is_lazy_load():
477636
global_.prompts.append(prompt)
478637

638+
def add_parameters(self, parameters: CodeParameters):
639+
self._publishable_parameters.append(parameters)
640+
if _is_lazy_load():
641+
global_.parameters.append(parameters)
642+
479643
def publish(self):
480644
if _is_lazy_load():
481645
eprint(f"{bcolors.WARNING}publish() is a no-op when running `braintrust push`.{bcolors.ENDC}")

0 commit comments

Comments
 (0)