Skip to content

Commit 2ccd15c

Browse files
committed
feat(builders): add ConsistentPartitionWindowRel builder
Add builder, type inference, display, and comprehensive tests for ConsistentPartitionWindowRel, enabling window function plans with shared partitioning and ordering.
1 parent 54bcf72 commit 2ccd15c

File tree

4 files changed

+498
-0
lines changed

4 files changed

+498
-0
lines changed

src/substrait/builders/plan.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -498,3 +498,97 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan:
498498
)
499499

500500
return resolve
501+
502+
503+
def consistent_partition_window(
504+
plan: PlanOrUnbound,
505+
window_functions: Iterable[ExtendedExpressionOrUnbound],
506+
partition_expressions: Iterable[ExtendedExpressionOrUnbound] = (),
507+
sorts: Iterable[
508+
Union[
509+
ExtendedExpressionOrUnbound,
510+
tuple[ExtendedExpressionOrUnbound, stalg.SortField.SortDirection.ValueType],
511+
]
512+
] = (),
513+
extension: Optional[AdvancedExtension] = None,
514+
) -> UnboundPlan:
515+
def resolve(registry: ExtensionRegistry) -> stp.Plan:
516+
bound_plan = plan if isinstance(plan, stp.Plan) else plan(registry)
517+
ns = infer_plan_schema(bound_plan)
518+
519+
bound_partitions = [
520+
resolve_expression(e, ns, registry) for e in partition_expressions
521+
]
522+
523+
bound_sorts = [
524+
(e, stalg.SortField.SORT_DIRECTION_ASC_NULLS_LAST)
525+
if not isinstance(e, tuple)
526+
else e
527+
for e in sorts
528+
]
529+
bound_sorts = [
530+
(resolve_expression(e[0], ns, registry), e[1]) for e in bound_sorts
531+
]
532+
533+
bound_window_fns = [
534+
resolve_expression(e, ns, registry) for e in window_functions
535+
]
536+
537+
window_rel_functions = []
538+
for wf_ee in bound_window_fns:
539+
wf_expr = wf_ee.referred_expr[0].expression.window_function
540+
window_rel_functions.append(
541+
stalg.ConsistentPartitionWindowRel.WindowRelFunction(
542+
function_reference=wf_expr.function_reference,
543+
arguments=list(wf_expr.arguments),
544+
options=list(wf_expr.options),
545+
output_type=wf_expr.output_type,
546+
phase=wf_expr.phase,
547+
invocation=wf_expr.invocation,
548+
lower_bound=wf_expr.lower_bound
549+
if wf_expr.HasField("lower_bound")
550+
else None,
551+
upper_bound=wf_expr.upper_bound
552+
if wf_expr.HasField("upper_bound")
553+
else None,
554+
bounds_type=wf_expr.bounds_type,
555+
)
556+
)
557+
558+
names = list(bound_plan.relations[-1].root.names) + [
559+
wf_ee.referred_expr[0].output_names[0]
560+
if wf_ee.referred_expr[0].output_names
561+
else f"window_{i}"
562+
for i, wf_ee in enumerate(bound_window_fns)
563+
]
564+
565+
rel = stalg.Rel(
566+
window=stalg.ConsistentPartitionWindowRel(
567+
input=bound_plan.relations[-1].root.input,
568+
window_functions=window_rel_functions,
569+
partition_expressions=[
570+
e.referred_expr[0].expression for e in bound_partitions
571+
],
572+
sorts=[
573+
stalg.SortField(
574+
expr=e[0].referred_expr[0].expression,
575+
direction=e[1],
576+
)
577+
for e in bound_sorts
578+
],
579+
advanced_extension=extension,
580+
)
581+
)
582+
583+
return stp.Plan(
584+
version=default_version,
585+
relations=[stp.PlanRel(root=stalg.RelRoot(input=rel, names=names))],
586+
**_merge_extensions(
587+
bound_plan,
588+
*bound_partitions,
589+
*[e[0] for e in bound_sorts],
590+
*bound_window_fns,
591+
),
592+
)
593+
594+
return resolve

src/substrait/type_inference.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,14 @@ def infer_rel_schema(rel: stalg.Rel) -> stt.Type.Struct:
342342
raise Exception(f"Unhandled join_type {rel.join.type}")
343343

344344
(common, struct) = (rel.join.common, raw_schema)
345+
elif rel_type == "window":
346+
parent_schema = infer_rel_schema(rel.window.input)
347+
window_output_types = [wf.output_type for wf in rel.window.window_functions]
348+
raw_schema = stt.Type.Struct(
349+
types=list(parent_schema.types) + window_output_types,
350+
nullability=parent_schema.nullability,
351+
)
352+
(common, struct) = (rel.window.common, raw_schema)
345353
else:
346354
raise Exception(f"Unhandled rel_type {rel_type}")
347355

src/substrait/utils/display.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,8 @@ def _stream_rel(self, rel: stalg.Rel, stream, depth: int):
171171
self._stream_extension_single_rel(rel.extension_single, stream, depth)
172172
elif rel.HasField("extension_multi"):
173173
self._stream_extension_multi_rel(rel.extension_multi, stream, depth)
174+
elif rel.HasField("window"):
175+
self._stream_window_rel(rel.window, stream, depth)
174176
else:
175177
stream.write(f"{indent}<unknown_relation>\n")
176178

@@ -401,6 +403,43 @@ def _stream_extension_multi_rel(
401403
f"{self._get_indent_with_arrow(depth + 2)}<unpackable_detail>\n"
402404
)
403405

406+
def _stream_window_rel(
407+
self, window: stalg.ConsistentPartitionWindowRel, stream, depth: int
408+
):
409+
"""Print a consistent partition window relation concisely"""
410+
indent = " " * (depth * self.indent_size)
411+
412+
stream.write(
413+
f"{indent}{self._color('window', Colors.MAGENTA)}: "
414+
f"{self._color(str(len(window.window_functions)), Colors.YELLOW)} functions\n"
415+
)
416+
stream.write(
417+
f"{self._get_indent_with_arrow(depth + 1)}{self._color('input:', Colors.BLUE)}\n"
418+
)
419+
self._stream_rel(window.input, stream, depth + 1)
420+
421+
if window.partition_expressions:
422+
stream.write(
423+
f"{self._get_indent_with_arrow(depth + 1)}"
424+
f"{self._color('partitions:', Colors.BLUE)} "
425+
f"{self._color(str(len(window.partition_expressions)), Colors.YELLOW)}\n"
426+
)
427+
428+
if window.sorts:
429+
stream.write(
430+
f"{self._get_indent_with_arrow(depth + 1)}"
431+
f"{self._color('sorts:', Colors.BLUE)} "
432+
f"{self._color(str(len(window.sorts)), Colors.YELLOW)}\n"
433+
)
434+
435+
for i, wf in enumerate(window.window_functions):
436+
stream.write(
437+
f"{self._get_indent_with_arrow(depth + 1)}"
438+
f"{self._color('window_fn', Colors.BLUE)}"
439+
f"[{self._color(str(i), Colors.CYAN)}]: "
440+
f"func_ref={wf.function_reference}\n"
441+
)
442+
404443
def _stream_expression(self, expression: stalg.Expression, stream, depth: int):
405444
"""Print an expression concisely"""
406445
indent = " " * (depth * self.indent_size)

0 commit comments

Comments
 (0)