@@ -152,6 +152,17 @@ def infer_rel_schema(rel: stalg.Rel) -> stt.Type.Struct:
152152 (common , struct ) = (rel .filter .common , infer_rel_schema (rel .filter .input ))
153153 elif rel_type == "fetch" :
154154 (common , struct ) = (rel .fetch .common , infer_rel_schema (rel .fetch .input ))
155+ elif rel_type == "aggregate" :
156+ parent_schema = infer_rel_schema (rel .aggregate .input )
157+ grouping_types = [infer_expression_type (g , parent_schema ) for g in rel .aggregate .grouping_expressions ]
158+ measure_types = [m .measure .output_type for m in rel .aggregate .measures ]
159+
160+ raw_schema = stt .Type .Struct (
161+ types = grouping_types + measure_types ,
162+ nullability = parent_schema .nullability ,
163+ )
164+
165+ (common , struct ) = (rel .aggregate .common , raw_schema )
155166 elif rel_type == "sort" :
156167 (common , struct ) = (rel .sort .common , infer_rel_schema (rel .sort .input ))
157168 elif rel_type == "project" :
@@ -165,6 +176,18 @@ def infer_rel_schema(rel: stalg.Rel) -> stt.Type.Struct:
165176 )
166177
167178 (common , struct ) = (rel .project .common , raw_schema )
179+ elif rel_type == "set" :
180+ (common , struct ) = (rel .fetch .common , infer_rel_schema (rel .set .inputs [0 ]))
181+ elif rel_type == "cross" :
182+ left_schema = infer_rel_schema (rel .cross .left )
183+ right_schema = infer_rel_schema (rel .cross .right )
184+
185+ raw_schema = stt .Type .Struct (
186+ types = left_schema .types + right_schema .types ,
187+ nullability = stt .Type .Nullability .NULLABILITY_REQUIRED ,
188+ )
189+
190+ (common , struct ) = (rel .cross .common , raw_schema )
168191 else :
169192 raise Exception (f"Unhandled rel_type { rel_type } " )
170193
0 commit comments