Skip to content

Commit ec37415

Browse files
committed
feat: sql to substrait
1 parent b0fa37f commit ec37415

File tree

4 files changed

+1046
-1
lines changed

4 files changed

+1046
-1
lines changed

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@ write_to = "src/substrait/_version.py"
1414
[project.optional-dependencies]
1515
extensions = ["antlr4-python3-runtime", "pyyaml"]
1616
gen_proto = ["protobuf == 3.20.1", "protoletariat >= 2.0.0"]
17-
test = ["pytest >= 7.0.0", "antlr4-python3-runtime", "pyyaml"]
17+
sql = ["sqloxide", "deepdiff"]
18+
test = ["pytest >= 7.0.0", "antlr4-python3-runtime", "pyyaml", "sqloxide", "deepdiff", "duckdb<=1.2.2", "datafusion"]
1819

1920
[tool.pytest.ini_options]
2021
pythonpath = "src"
Lines changed: 317 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,317 @@
1+
import random
2+
import string
3+
from sqloxide import parse_sql
4+
from substrait.builders.extended_expression import (
5+
UnboundExtendedExpression,
6+
column,
7+
scalar_function,
8+
literal,
9+
aggregate_function,
10+
window_function,
11+
)
12+
from substrait.builders.plan import (
13+
read_named_table,
14+
project,
15+
filter,
16+
sort,
17+
set,
18+
join,
19+
aggregate,
20+
)
21+
from substrait.gen.proto import type_pb2 as stt
22+
from substrait.gen.proto import algebra_pb2 as stalg
23+
from substrait.extension_registry import ExtensionRegistry
24+
from typing import Callable
25+
from deepdiff import DeepDiff
26+
27+
# Callback that receives a table name and returns the Substrait NamedStruct
# used as that table's schema when building the read relation.
SchemaResolver = Callable[[str], stt.NamedStruct]
28+
29+
# Parser binary-operator name -> (extension file, Substrait scalar function).
# Every comparison/arithmetic operator has its counterpart listed so that,
# for example, `<=` is supported alongside `>=`.
function_mapping = {
    "Plus": ("functions_arithmetic.yaml", "add"),
    "Minus": ("functions_arithmetic.yaml", "subtract"),
    "Multiply": ("functions_arithmetic.yaml", "multiply"),
    "Divide": ("functions_arithmetic.yaml", "divide"),
    "Gt": ("functions_comparison.yaml", "gt"),
    "GtEq": ("functions_comparison.yaml", "gte"),
    "Lt": ("functions_comparison.yaml", "lt"),
    "LtEq": ("functions_comparison.yaml", "lte"),
    "Eq": ("functions_comparison.yaml", "equal"),
    "NotEq": ("functions_comparison.yaml", "not_equal"),
}

# SQL aggregate function name -> (extension file, Substrait aggregate function).
aggregate_function_mapping = {
    "SUM": ("functions_arithmetic.yaml", "sum"),
}

# SQL window function name -> (extension file, Substrait window function).
window_function_mapping = {
    "row_number": ("functions_arithmetic.yaml", "row_number"),
}
45+
46+
47+
def compare_dicts(dict1, dict2):
    """Return True when DeepDiff reports no difference between the inputs.

    Any path matching the regular expression ``span`` is left out of the
    comparison, so two fragments that differ only under such keys
    (presumably parser source-position metadata) are treated as equal.
    """
    differences = DeepDiff(
        dict1,
        dict2,
        exclude_regex_paths=["span"],
    )
    return not len(differences)
50+
51+
52+
def translate_expression(
    ast: dict,
    schema_resolver: SchemaResolver,
    registry: ExtensionRegistry,
    measures: list[tuple],
    groupings: list[dict],
    alias: str = None,
) -> UnboundExtendedExpression:
    """Translate one parsed SQL expression node into an unbound expression.

    Args:
        ast: Single-key dict ``{node_type: payload}`` from the SQL parser.
        schema_resolver: Only forwarded to recursive calls.
        registry: Only forwarded to recursive calls.
        measures: Accumulator of ``(aggregate_expression, function_ast, name)``
            tuples. Aggregate calls found in ``ast`` are appended here and
            replaced by a column reference to ``name``. Pass ``None`` where
            aggregates are not allowed (callers do so for WHERE, JOIN ... ON
            and ORDER BY).
        groupings: GROUP BY expression ASTs. When ``ast`` equals one of them
            (per ``compare_dicts``) a reference to that grouping's position is
            returned instead of translating the expression again. May be
            ``None`` or empty.
        alias: Optional alias attached to the resulting expression.

    Returns:
        The unbound expression produced by the builder functions.

    Raises:
        Exception: For unknown node types, operators, functions or literals,
            and for aggregate functions used while ``measures`` is ``None``.
    """
    assert len(ast) == 1
    op = next(iter(ast))

    if groupings:
        # We are translating a projection that follows a grouping: an
        # expression identical to a grouping key becomes a reference to that
        # key's position rather than being recalculated.
        for index, grouping_ast in enumerate(groupings):
            if compare_dicts(ast, grouping_ast):
                return column(index, alias=alias)

    ast = ast[op]

    def recurse(node, node_alias=None):
        # Translate a child node with the same resolver, registry and
        # accumulators as the current call.
        return translate_expression(
            node,
            schema_resolver=schema_resolver,
            registry=registry,
            measures=measures,
            groupings=groupings,
            alias=node_alias,
        )

    if op == "Identifier":
        return column(ast["value"], alias=alias)

    if op in ("UnnamedExpr", "expr", "Unnamed", "Expr"):
        # Pure wrapper nodes: unwrap, keeping any alias given by the caller.
        return recurse(ast, alias)

    if op == "ExprWithAlias":
        return recurse(ast["expr"], ast["alias"]["value"])

    if op == "BinaryOp":
        operands = [recurse(ast["left"]), recurse(ast["right"])]
        operator = ast["op"]
        if not isinstance(operator, str) or operator not in function_mapping:
            raise Exception(f"Unknown binary operator {operator}")
        uri, function = function_mapping[operator]
        return scalar_function(uri, function, expressions=operands, alias=alias)

    if op == "Value":
        value = ast["value"]
        if not isinstance(value, dict) or "Number" not in value:
            raise Exception(f"Unsupported literal {value}")
        # TODO infer type: every numeric literal is currently read as an i64.
        return literal(
            int(value["Number"][0]), stt.Type(i64=stt.Type.I64()), alias=alias
        )

    if op == "Function":
        arguments = [recurse(arg) for arg in ast["args"]["List"]["args"]]
        name = ast["name"][0]["Identifier"]["value"]

        if name in function_mapping:
            uri, function = function_mapping[name]
            # Same calling convention as the BinaryOp branch.
            return scalar_function(
                uri, function, expressions=arguments, alias=alias
            )

        # SQL function names are case-insensitive; the aggregate table is
        # keyed in upper case and the window table in lower case.
        aggregate_key = name.upper()
        window_key = name.lower()

        if aggregate_key in aggregate_function_mapping:
            if measures is None:
                raise Exception(
                    f"Aggregate function {name} is not supported in this context"
                )

            # Measures are computed in a separate aggregate relation, so the
            # call is extracted into `measures` and replaced by a reference
            # to a generated column name. Identical calls share one measure.
            for measure in measures:
                if compare_dicts(ast, measure[1]):
                    return column(measure[2], alias=alias)

            uri, function = aggregate_function_mapping[aggregate_key]
            # Deterministic and unique within this measures list.
            measure_name = f"__measure_{len(measures)}"
            aggr = aggregate_function(uri, function, arguments, alias=measure_name)
            measures.append((aggr, ast, measure_name))
            return column(measure_name, alias=alias)

        if window_key in window_function_mapping:
            uri, function = window_function_mapping[window_key]
            partitions = [
                recurse(partition)
                for partition in ast["over"]["WindowSpec"]["partition_by"]
            ]
            return window_function(
                uri, function, arguments, partitions=partitions, alias=alias
            )

        raise Exception(f"Unknown function {name}")

    raise Exception(f"Unknown op {op}")
169+
170+
171+
# Parser join-operator name -> Substrait join type. Built once instead of on
# every join of every SELECT.
_JOIN_TYPE_MAPPING = {
    "Join": stalg.JoinRel.JOIN_TYPE_INNER,  # bare `JOIN` without INNER
    "Inner": stalg.JoinRel.JOIN_TYPE_INNER,
    "Left": stalg.JoinRel.JOIN_TYPE_LEFT,
    "LeftOuter": stalg.JoinRel.JOIN_TYPE_LEFT,
    "Right": stalg.JoinRel.JOIN_TYPE_RIGHT,
    "RightOuter": stalg.JoinRel.JOIN_TYPE_RIGHT,
    "FullOuter": stalg.JoinRel.JOIN_TYPE_OUTER,
}


def _translate_query(ast: dict, schema_resolver, registry):
    """Translate a ``Query`` payload: its body plus an optional ORDER BY."""
    relation = translate(
        ast["body"], schema_resolver=schema_resolver, registry=registry
    )

    if ast["order_by"]:
        # NOTE(review): only the sort expressions are read here; any
        # direction / nulls-ordering options on the ORDER BY items are not
        # forwarded to `sort` - confirm against the `sort` builder API.
        sort_expressions = [
            translate_expression(
                item["expr"],
                schema_resolver=schema_resolver,
                registry=registry,
                measures=None,
                groupings=None,
            )
            for item in ast["order_by"]["kind"]["Expressions"]
        ]
        relation = sort(relation, sort_expressions)(registry)

    return relation


def _translate_select(ast: dict, schema_resolver, registry):
    """Translate a ``Select`` payload.

    Relations are stacked in this order: read (+ joins), WHERE filter,
    aggregate, HAVING filter, final projection.
    """
    # NOTE(review): only the first FROM item is used.
    source = ast["from"][0]
    relation = translate(
        source["relation"], schema_resolver=schema_resolver, registry=registry
    )

    for join_ast in source["joins"] or []:
        join_operator = join_ast["join_operator"]
        if not isinstance(join_operator, dict):
            raise Exception(f"Unsupported join {join_operator}")

        join_type = next(iter(join_operator))
        if join_type not in _JOIN_TYPE_MAPPING:
            raise Exception(f"Unsupported join type {join_type}")

        constraint = join_operator[join_type]
        if not isinstance(constraint, dict) or "On" not in constraint:
            raise Exception(f"Unsupported join constraint {constraint}")

        right = translate(
            join_ast["relation"],
            schema_resolver=schema_resolver,
            registry=registry,
        )
        condition = translate_expression(
            constraint["On"],
            schema_resolver=schema_resolver,
            registry=registry,
            measures=None,
            groupings=None,
        )
        relation = join(
            relation, right, condition, _JOIN_TYPE_MAPPING[join_type]
        )(registry)

    if ast.get("selection"):
        where_expression = translate_expression(
            ast["selection"],
            schema_resolver=schema_resolver,
            registry=registry,
            measures=None,
            groupings=None,
        )
        relation = filter(relation, where_expression)(registry)

    if ast["group_by"] and ast["group_by"]["Expressions"][0]:
        groupings = ast["group_by"]["Expressions"][0]
        grouping_expressions = [
            translate_expression(
                grouping,
                schema_resolver=schema_resolver,
                registry=registry,
                measures=None,
                groupings=None,
            )
            for grouping in groupings
        ]
    else:
        groupings = []
        grouping_expressions = []

    # Filled in while translating the projection and HAVING: every aggregate
    # call found there is moved into the aggregate relation below.
    measures = []

    projection = [
        translate_expression(
            item,
            schema_resolver=schema_resolver,
            registry=registry,
            measures=measures,
            groupings=groupings,
        )
        for item in ast["projection"]
    ]

    having_predicate = None
    if ast["having"]:
        having_predicate = translate_expression(
            ast["having"],
            schema_resolver=schema_resolver,
            registry=registry,
            measures=measures,
            groupings=[],
        )

    if measures or groupings:
        relation = aggregate(
            relation, grouping_expressions, [measure[0] for measure in measures]
        )(registry)

    if having_predicate:
        relation = filter(relation, having_predicate)(registry)

    return project(relation, expressions=projection)(registry)


def _translate_set_operation(ast: dict, schema_resolver, registry):
    """Translate a two-input ``SetOperation`` payload (UNION only)."""
    # TODO more than 2 inputs to a set operation
    if ast["op"] != "Union":
        raise Exception(f"Unsupported set operation {ast['op']}")

    left = translate(
        ast["left"], schema_resolver=schema_resolver, registry=registry
    )
    right = translate(
        ast["right"], schema_resolver=schema_resolver, registry=registry
    )
    set_op = (
        stalg.SetRel.SET_OP_UNION_ALL
        if ast["set_quantifier"] == "All"
        else stalg.SetRel.SET_OP_UNION_DISTINCT
    )
    return set([left, right], set_op)(registry)


def translate(ast: dict, schema_resolver: SchemaResolver, registry: ExtensionRegistry):
    """Translate a parsed SQL relation node into a Substrait plan.

    Args:
        ast: Single-key dict ``{node_type: payload}``; handled node types are
            ``Query``, ``Select``, ``Table`` and ``SetOperation``.
        schema_resolver: Called with a table name to obtain the schema passed
            to ``read_named_table``.
        registry: Extension registry each plan builder is resolved against.

    Returns:
        The plan object produced by the plan builders.

    Raises:
        Exception: For unknown node types and for unsupported joins, join
            constraints or set operations.
    """
    assert len(ast) == 1
    op = next(iter(ast))
    node = ast[op]

    if op == "Query":
        return _translate_query(node, schema_resolver, registry)
    if op == "Select":
        return _translate_select(node, schema_resolver, registry)
    if op == "Table":
        # Only the first part of the table name is used.
        name = node["name"][0]["Identifier"]["value"]
        return read_named_table(name, schema_resolver(name))
    if op == "SetOperation":
        return _translate_set_operation(node, schema_resolver, registry)

    raise Exception(f"Unknown op {op}")
312+
313+
314+
def convert(query: str, dialect: str, schema_resolver: SchemaResolver):
    """Parse ``query`` with the given SQL dialect and translate it to a plan.

    Only the first statement returned by the parser is translated. The
    translation uses a fresh ``ExtensionRegistry`` created with
    ``load_default_extensions=True``.

    Args:
        query: SQL text to convert.
        dialect: Dialect name handed to ``parse_sql``.
        schema_resolver: Callback forwarded to ``translate``.

    Returns:
        Whatever ``translate`` returns for the first parsed statement.
    """
    statements = parse_sql(sql=query, dialect=dialect)
    first_statement = statements[0]
    return translate(
        first_statement,
        schema_resolver=schema_resolver,
        registry=ExtensionRegistry(load_default_extensions=True),
    )

0 commit comments

Comments
 (0)