Skip to content

Commit 7ea05b8

Browse files
committed
feat: add literal and scalar function expression builders
1 parent 444da04 commit 7ea05b8

File tree

6 files changed

+311
-7
lines changed

6 files changed

+311
-7
lines changed

src/substrait/extended_expression.py

Lines changed: 108 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,53 @@
22
import substrait.gen.proto.algebra_pb2 as stalg
33
import substrait.gen.proto.type_pb2 as stp
44
import substrait.gen.proto.extended_expression_pb2 as stee
5-
from substrait.utils import type_num_names
5+
import substrait.gen.proto.extensions.extensions_pb2 as ste
6+
from substrait.function_registry import FunctionRegistry
7+
from substrait.utils import type_num_names, merge_extension_uris, merge_extension_declarations
8+
from substrait.type_inference import infer_extended_expression_schema
9+
from typing import Callable, Any
610

11+
UnboundExpression = Callable[[stp.NamedStruct, FunctionRegistry], stee.ExtendedExpression]
12+
13+
def literal(value: Any, type: stp.Type, alias: str = None) -> UnboundExpression:
14+
def resolve(base_schema: stp.NamedStruct, registry: FunctionRegistry) -> stee.ExtendedExpression:
15+
kind = type.WhichOneof('kind')
16+
17+
if kind == "bool":
18+
literal = stalg.Expression.Literal(boolean=value, nullable=type.bool.nullability == stp.Type.NULLABILITY_NULLABLE)
19+
elif kind == "i8":
20+
literal = stalg.Expression.Literal(i8=value, nullable=type.i8.nullability == stp.Type.NULLABILITY_NULLABLE)
21+
elif kind == "i16":
22+
literal = stalg.Expression.Literal(i16=value, nullable=type.i16.nullability == stp.Type.NULLABILITY_NULLABLE)
23+
elif kind == "i32":
24+
literal = stalg.Expression.Literal(i32=value, nullable=type.i32.nullability == stp.Type.NULLABILITY_NULLABLE)
25+
elif kind == "i64":
26+
literal = stalg.Expression.Literal(i64=value, nullable=type.i64.nullability == stp.Type.NULLABILITY_NULLABLE)
27+
elif kind == "fp32":
28+
literal = stalg.Expression.Literal(fp32=value, nullable=type.fp32.nullability == stp.Type.NULLABILITY_NULLABLE)
29+
elif kind == "fp64":
30+
literal = stalg.Expression.Literal(fp64=value, nullable=type.fp64.nullability == stp.Type.NULLABILITY_NULLABLE)
31+
elif kind == "string":
32+
literal = stalg.Expression.Literal(string=value, nullable=type.string.nullability == stp.Type.NULLABILITY_NULLABLE)
33+
else:
34+
raise Exception(f"Unknown literal type - {type}")
35+
36+
return stee.ExtendedExpression(
37+
referred_expr=[
38+
stee.ExpressionReference(
39+
expression=stalg.Expression(
40+
literal=literal
41+
),
42+
output_names=[alias if alias else f'literal_{kind}'],
43+
)
44+
],
45+
base_schema=base_schema,
46+
)
47+
48+
return resolve
749

850
def column(name: str):
9-
def resolve(base_schema: stp.NamedStruct) -> stee.ExtendedExpression:
51+
def resolve(base_schema: stp.NamedStruct, registry: FunctionRegistry) -> stee.ExtendedExpression:
1052
column_index = list(base_schema.names).index(name)
1153
lengths = [type_num_names(t) for t in base_schema.struct.types]
1254
flat_indices = [0] + list(itertools.accumulate(lengths))[:-1]
@@ -39,3 +81,67 @@ def resolve(base_schema: stp.NamedStruct) -> stee.ExtendedExpression:
3981
)
4082

4183
return resolve
84+
85+
def scalar_function(uri: str, function: str, *expressions: UnboundExpression, alias: str = None):
86+
def resolve(base_schema: stp.NamedStruct, registry: FunctionRegistry) -> stee.ExtendedExpression:
87+
bound_expressions: list[stee.ExtendedExpression] = [e(base_schema, registry) for e in expressions]
88+
89+
expression_schemas = [infer_extended_expression_schema(b) for b in bound_expressions]
90+
91+
signature = [typ for es in expression_schemas for typ in es.types]
92+
93+
func = registry.lookup_function(uri, function, signature)
94+
95+
if not func:
96+
raise Exception('')
97+
98+
func_extension_uris = [
99+
ste.SimpleExtensionURI(
100+
extension_uri_anchor=registry.lookup_uri(uri),
101+
uri=uri
102+
)
103+
]
104+
105+
func_extensions = [
106+
ste.SimpleExtensionDeclaration(
107+
extension_function=ste.SimpleExtensionDeclaration.ExtensionFunction(
108+
extension_uri_reference=registry.lookup_uri(uri),
109+
function_anchor=func[0].anchor,
110+
name=function
111+
)
112+
)
113+
]
114+
115+
extension_uris = merge_extension_uris(
116+
func_extension_uris,
117+
*[b.extension_uris for b in bound_expressions]
118+
)
119+
120+
extensions = merge_extension_declarations(
121+
func_extensions,
122+
*[b.extensions for b in bound_expressions]
123+
)
124+
125+
return stee.ExtendedExpression(
126+
referred_expr=[
127+
stee.ExpressionReference(
128+
expression=stalg.Expression(
129+
scalar_function=stalg.Expression.ScalarFunction(
130+
function_reference=func[0].anchor,
131+
arguments=[
132+
stalg.FunctionArgument(
133+
value=e.referred_expr[0].expression
134+
) for e in bound_expressions
135+
],
136+
output_type=func[1]
137+
)
138+
),
139+
output_names=[alias if alias else 'scalar_function'],
140+
)
141+
],
142+
base_schema=base_schema,
143+
extension_uris=extension_uris,
144+
extensions=extensions
145+
)
146+
147+
return resolve

src/substrait/function_registry.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from substrait.gen.proto.parameterized_types_pb2 import ParameterizedType
21
from substrait.gen.proto.type_pb2 import Type
32
from importlib.resources import files as importlib_files
43
import itertools
@@ -227,6 +226,9 @@ def satisfies_signature(self, signature: tuple) -> Optional[str]:
227226

228227
class FunctionRegistry:
229228
def __init__(self, load_default_extensions=True) -> None:
229+
self._uri_mapping: dict = defaultdict(dict)
230+
self._uri_id_generator = itertools.count(1)
231+
230232
self._function_mapping: dict = defaultdict(dict)
231233
self._id_generator = itertools.count(1)
232234

@@ -252,6 +254,8 @@ def register_extension_yaml(
252254
self.register_extension_dict(extension_definitions, uri)
253255

254256
def register_extension_dict(self, definitions: dict, uri: str) -> None:
257+
self._uri_mapping[uri] = next(self._uri_id_generator)
258+
255259
for named_functions in definitions.values():
256260
for function in named_functions:
257261
for impl in function.get("impls", []):
@@ -285,3 +289,7 @@ def lookup_function(
285289
return (f, rtn)
286290

287291
return None
292+
293+
def lookup_uri(self, uri: str) -> Optional[int]:
294+
uri = self._uri_aliases.get(uri, uri)
295+
return self._uri_mapping.get(uri, None)

src/substrait/type_inference.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import substrait.gen.proto.algebra_pb2 as stalg
2+
import substrait.gen.proto.extended_expression_pb2 as stee
23
import substrait.gen.proto.type_pb2 as stt
34

45

@@ -220,6 +221,17 @@ def infer_expression_type(
220221
raise Exception(f"Unknown rex_type {rex_type}")
221222

222223

224+
def infer_extended_expression_schema(ee: stee.ExtendedExpression) -> stt.Type.Struct:
225+
exprs = [e for e in ee.referred_expr]
226+
227+
types = [infer_expression_type(e.expression, ee.base_schema.struct) for e in exprs]
228+
229+
return stt.Type.Struct(
230+
types=types,
231+
nullability=stt.Type.NULLABILITY_REQUIRED,
232+
)
233+
234+
223235
def infer_rel_schema(rel: stalg.Rel) -> stt.Type.Struct:
224236
rel_type = rel.WhichOneof("rel_type")
225237

src/substrait/utils.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import substrait.gen.proto.type_pb2 as stp
2-
2+
import substrait.gen.proto.extensions.extensions_pb2 as ste
3+
from typing import Iterable
34

45
def type_num_names(typ: stp.Type):
56
kind = typ.WhichOneof("kind")
@@ -12,3 +13,32 @@ def type_num_names(typ: stp.Type):
1213
return type_num_names(typ.map.key) + type_num_names(typ.map.value)
1314
else:
1415
return 1
16+
17+
def merge_extension_uris(*extension_uris: Iterable[ste.SimpleExtensionURI]):
18+
seen_uris = set()
19+
ret = []
20+
21+
for uris in extension_uris:
22+
for uri in uris:
23+
if uri.uri not in seen_uris:
24+
seen_uris.add(uri.uri)
25+
ret.append(uri)
26+
27+
return ret
28+
29+
def merge_extension_declarations(*extension_declarations: Iterable[ste.SimpleExtensionDeclaration]):
30+
seen_extension_functions = set()
31+
ret = []
32+
33+
for declarations in extension_declarations:
34+
for declaration in declarations:
35+
if declaration.WhichOneof('mapping_type') == 'extension_function':
36+
ident = (declaration.extension_function.extension_uri_reference, declaration.extension_function.name)
37+
if ident not in seen_extension_functions:
38+
seen_extension_functions.add(ident)
39+
ret.append(declaration)
40+
else:
41+
raise Exception('')
42+
43+
return ret
44+
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343

4444

4545
def test_column_no_nesting():
46-
assert column("description")(named_struct) == stee.ExtendedExpression(
46+
assert column("description")(named_struct, None) == stee.ExtendedExpression(
4747
referred_expr=[
4848
stee.ExpressionReference(
4949
expression=stalg.Expression(
@@ -64,7 +64,7 @@ def test_column_no_nesting():
6464

6565

6666
def test_column_nesting():
67-
assert column("order_total")(nested_named_struct) == stee.ExtendedExpression(
67+
assert column("order_total")(nested_named_struct, None) == stee.ExtendedExpression(
6868
referred_expr=[
6969
stee.ExpressionReference(
7070
expression=stalg.Expression(
@@ -85,7 +85,7 @@ def test_column_nesting():
8585

8686

8787
def test_column_nested_struct():
88-
assert column("shop_details")(nested_named_struct) == stee.ExtendedExpression(
88+
assert column("shop_details")(nested_named_struct, None) == stee.ExtendedExpression(
8989
referred_expr=[
9090
stee.ExpressionReference(
9191
expression=stalg.Expression(
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
import yaml
2+
3+
import substrait.gen.proto.algebra_pb2 as stalg
4+
import substrait.gen.proto.type_pb2 as stt
5+
import substrait.gen.proto.extended_expression_pb2 as stee
6+
import substrait.gen.proto.extensions.extensions_pb2 as ste
7+
from substrait.extended_expression import scalar_function, literal
8+
from substrait.function_registry import FunctionRegistry
9+
10+
struct = stt.Type.Struct(
11+
types=[
12+
stt.Type(i64=stt.Type.I64(nullability=stt.Type.NULLABILITY_REQUIRED)),
13+
stt.Type(string=stt.Type.String(nullability=stt.Type.NULLABILITY_NULLABLE)),
14+
stt.Type(fp32=stt.Type.FP32(nullability=stt.Type.NULLABILITY_NULLABLE)),
15+
]
16+
)
17+
18+
named_struct = stt.NamedStruct(
19+
names=["order_id", "description", "order_total"], struct=struct
20+
)
21+
22+
content = """%YAML 1.2
23+
---
24+
scalar_functions:
25+
- name: "test_func"
26+
description: ""
27+
impls:
28+
- args:
29+
- value: i8
30+
variadic:
31+
min: 2
32+
return: i8
33+
- name: "is_positive"
34+
description: ""
35+
impls:
36+
- args:
37+
- value: i8
38+
return: boolean
39+
"""
40+
41+
42+
registry = FunctionRegistry(load_default_extensions=False)
43+
registry.register_extension_dict(yaml.safe_load(content), uri="test_uri")
44+
45+
def test_sclar_add():
46+
e = scalar_function('test_uri', 'test_func',
47+
literal(10, type=stt.Type(i8=stt.Type.I8(nullability=stt.Type.NULLABILITY_REQUIRED))),
48+
literal(20, type=stt.Type(i8=stt.Type.I8(nullability=stt.Type.NULLABILITY_REQUIRED))),
49+
alias='sum',
50+
)(named_struct, registry)
51+
52+
expected = stee.ExtendedExpression(
53+
extension_uris=[
54+
ste.SimpleExtensionURI(
55+
extension_uri_anchor=1,
56+
uri='test_uri'
57+
)
58+
],
59+
extensions=[
60+
ste.SimpleExtensionDeclaration(
61+
extension_function=ste.SimpleExtensionDeclaration.ExtensionFunction(
62+
extension_uri_reference=1,
63+
function_anchor=1,
64+
name='test_func'
65+
)
66+
)
67+
],
68+
referred_expr=[
69+
stee.ExpressionReference(
70+
expression=stalg.Expression(
71+
scalar_function=stalg.Expression.ScalarFunction(
72+
function_reference=1,
73+
arguments=[
74+
stalg.FunctionArgument(value=stalg.Expression(literal=stalg.Expression.Literal(i8=10, nullable=False))),
75+
stalg.FunctionArgument(value=stalg.Expression(literal=stalg.Expression.Literal(i8=20, nullable=False)))
76+
],
77+
output_type=stt.Type(i8=stt.Type.I8(nullability=stt.Type.NULLABILITY_REQUIRED))
78+
)
79+
),
80+
output_names=["sum"],
81+
)
82+
],
83+
base_schema=named_struct,
84+
)
85+
86+
assert e == expected
87+
88+
89+
def test_nested_scalar_calls():
90+
e = scalar_function('test_uri', 'is_positive',
91+
scalar_function('test_uri', 'test_func',
92+
literal(10, type=stt.Type(i8=stt.Type.I8(nullability=stt.Type.NULLABILITY_REQUIRED))),
93+
literal(20, type=stt.Type(i8=stt.Type.I8(nullability=stt.Type.NULLABILITY_REQUIRED)))),
94+
alias='positive'
95+
)(named_struct, registry)
96+
97+
expected = stee.ExtendedExpression(
98+
extension_uris=[
99+
ste.SimpleExtensionURI(
100+
extension_uri_anchor=1,
101+
uri='test_uri'
102+
)
103+
],
104+
extensions=[
105+
ste.SimpleExtensionDeclaration(
106+
extension_function=ste.SimpleExtensionDeclaration.ExtensionFunction(
107+
extension_uri_reference=1,
108+
function_anchor=2,
109+
name='is_positive'
110+
)
111+
),
112+
ste.SimpleExtensionDeclaration(
113+
extension_function=ste.SimpleExtensionDeclaration.ExtensionFunction(
114+
extension_uri_reference=1,
115+
function_anchor=1,
116+
name='test_func'
117+
)
118+
)
119+
],
120+
referred_expr=[
121+
stee.ExpressionReference(
122+
expression=stalg.Expression(
123+
scalar_function=stalg.Expression.ScalarFunction(
124+
function_reference=2,
125+
arguments=[
126+
stalg.FunctionArgument(
127+
value=stalg.Expression(
128+
scalar_function=stalg.Expression.ScalarFunction(
129+
function_reference=1,
130+
arguments=[
131+
stalg.FunctionArgument(value=stalg.Expression(literal=stalg.Expression.Literal(i8=10, nullable=False))),
132+
stalg.FunctionArgument(value=stalg.Expression(literal=stalg.Expression.Literal(i8=20, nullable=False)))
133+
],
134+
output_type=stt.Type(i8=stt.Type.I8(nullability=stt.Type.NULLABILITY_REQUIRED))
135+
)
136+
)
137+
)
138+
],
139+
output_type=stt.Type(bool=stt.Type.Boolean(nullability=stt.Type.NULLABILITY_REQUIRED))
140+
)
141+
),
142+
output_names=["positive"],
143+
)
144+
],
145+
base_schema=named_struct,
146+
)
147+
148+
assert e == expected

0 commit comments

Comments
 (0)