Skip to content

Commit 3f789f8

Browse files
committed
feat: add inference for set, cross and aggregate rels
1 parent 277df48 commit 3f789f8

File tree

2 files changed

+68
-0
lines changed

2 files changed

+68
-0
lines changed

src/substrait/type_inference.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,17 @@ def infer_rel_schema(rel: stalg.Rel) -> stt.Type.Struct:
152152
(common, struct) = (rel.filter.common, infer_rel_schema(rel.filter.input))
153153
elif rel_type == "fetch":
154154
(common, struct) = (rel.fetch.common, infer_rel_schema(rel.fetch.input))
155+
elif rel_type == "aggregate":
156+
parent_schema = infer_rel_schema(rel.aggregate.input)
157+
grouping_types = [infer_expression_type(g, parent_schema) for g in rel.aggregate.grouping_expressions]
158+
measure_types = [m.measure.output_type for m in rel.aggregate.measures]
159+
160+
raw_schema = stt.Type.Struct(
161+
types=grouping_types + measure_types,
162+
nullability=parent_schema.nullability,
163+
)
164+
165+
(common, struct) = (rel.aggregate.common, raw_schema)
155166
elif rel_type == "sort":
156167
(common, struct) = (rel.sort.common, infer_rel_schema(rel.sort.input))
157168
elif rel_type == "project":
@@ -165,6 +176,18 @@ def infer_rel_schema(rel: stalg.Rel) -> stt.Type.Struct:
165176
)
166177

167178
(common, struct) = (rel.project.common, raw_schema)
179+
elif rel_type == "set":
180+
(common, struct) = (rel.fetch.common, infer_rel_schema(rel.set.inputs[0]))
181+
elif rel_type == "cross":
182+
left_schema = infer_rel_schema(rel.cross.left)
183+
right_schema = infer_rel_schema(rel.cross.right)
184+
185+
raw_schema = stt.Type.Struct(
186+
types=left_schema.types + right_schema.types,
187+
nullability=stt.Type.Nullability.NULLABILITY_REQUIRED,
188+
)
189+
190+
(common, struct) = (rel.cross.common, raw_schema)
168191
else:
169192
raise Exception(f"Unhandled rel_type {rel_type}")
170193

tests/test_type_inference.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,3 +97,48 @@ def test_inference_project_scalar_function():
9797
)
9898

9999
assert infer_rel_schema(rel) == expected
100+
101+
def test_inference_aggregate():
    """Check the schema inferred for an aggregate relation.

    The relation groups ``read_rel`` by the field at index 1 and carries one
    measure whose declared output type is a required boolean. The inferred
    struct is expected to hold the grouping expression's type first and the
    measure's output type second.
    """
    # Types appearing in the expected schema.
    # NOTE(review): presumably field 1 of read_rel is a nullable string --
    # read_rel is defined elsewhere in this module; confirm there.
    nullable_string = stt.Type(
        string=stt.Type.String(nullability=stt.Type.NULLABILITY_NULLABLE)
    )
    required_bool = stt.Type(
        bool=stt.Type.Boolean(nullability=stt.Type.NULLABILITY_REQUIRED)
    )

    # Direct reference to the input field at index 1, used as the grouping key.
    group_key = stalg.Expression(
        selection=stalg.Expression.FieldReference(
            root_reference=stalg.Expression.FieldReference.RootReference(),
            direct_reference=stalg.Expression.ReferenceSegment(
                struct_field=stalg.Expression.ReferenceSegment.StructField(
                    field=1,
                ),
            ),
        )
    )

    # Single measure; only its declared output_type is asserted on below.
    bool_measure = stalg.AggregateRel.Measure(
        measure=stalg.AggregateFunction(
            function_reference=0,
            output_type=required_bool,
        )
    )

    aggregate = stalg.AggregateRel(
        input=read_rel,
        grouping_expressions=[group_key],
        groupings=[stalg.AggregateRel.Grouping(expression_references=[0])],
        measures=[bool_measure],
    )
    rel = stalg.Rel(aggregate=aggregate)

    expected = stt.Type.Struct(types=[nullable_string, required_bool])

    assert infer_rel_schema(rel) == expected

0 commit comments

Comments
 (0)