Skip to content

Commit c439855

Browse files
committed
feat: sql to substrait - add support for limit and offset
1 parent b7051bf commit c439855

File tree

3 files changed

+73
-7
lines changed

3 files changed

+73
-7
lines changed

src/substrait/builders/plan.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@
2626

2727
def _merge_extensions(*objs):
2828
return {
29-
"extension_uris": merge_extension_uris(*[b.extension_uris for b in objs]),
30-
"extensions": merge_extension_declarations(*[b.extensions for b in objs]),
29+
"extension_uris": merge_extension_uris(*[b.extension_uris for b in objs if b]),
30+
"extensions": merge_extension_declarations(*[b.extensions for b in objs if b]),
3131
}
3232

3333

@@ -193,13 +193,15 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan:
193193
bound_plan = plan if isinstance(plan, stp.Plan) else plan(registry)
194194
ns = infer_plan_schema(bound_plan)
195195

196-
bound_offset = resolve_expression(offset, ns, registry)
196+
bound_offset = resolve_expression(offset, ns, registry) if offset else None
197197
bound_count = resolve_expression(count, ns, registry)
198198

199199
rel = stalg.Rel(
200200
fetch=stalg.FetchRel(
201201
input=bound_plan.relations[-1].root.input,
202-
offset_expr=bound_offset.referred_expr[0].expression,
202+
offset_expr=bound_offset.referred_expr[0].expression
203+
if bound_offset
204+
else None,
203205
count_expr=bound_count.referred_expr[0].expression,
204206
)
205207
)

src/substrait/sql/sql_to_substrait.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
project,
1515
filter,
1616
sort,
17+
fetch,
1718
set,
1819
join,
1920
aggregate,
@@ -35,9 +36,7 @@
3536
"Eq": ("functions_comparison.yaml", "equal"),
3637
}
3738

38-
aggregate_function_mapping = {
39-
"SUM": ("functions_arithmetic.yaml", "sum"),
40-
}
39+
aggregate_function_mapping = {"SUM": ("functions_arithmetic.yaml", "sum")}
4140

4241
window_function_mapping = {
4342
"row_number": ("functions_arithmetic.yaml", "row_number"),
@@ -190,6 +189,29 @@ def translate(ast: dict, schema_resolver: SchemaResolver, registry: ExtensionReg
190189
for e in ast["order_by"]["kind"]["Expressions"]
191190
]
192191
relation = sort(relation, expressions)(registry)
192+
193+
if ast["limit_clause"]:
194+
limit_expression = translate_expression(
195+
ast["limit_clause"]["LimitOffset"]["limit"],
196+
schema_resolver=schema_resolver,
197+
registry=registry,
198+
measures=None,
199+
groupings=None,
200+
)
201+
202+
if ast["limit_clause"]["LimitOffset"]["offset"]:
203+
offset_expression = translate_expression(
204+
ast["limit_clause"]["LimitOffset"]["offset"]["value"],
205+
schema_resolver=schema_resolver,
206+
registry=registry,
207+
measures=None,
208+
groupings=None,
209+
)
210+
else:
211+
offset_expression = None
212+
213+
relation = fetch(relation, offset_expression, limit_expression)(registry)
214+
193215
return relation
194216
elif op == "Select":
195217
relation = translate(

tests/sql/test_sql_to_substrait.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,48 @@ def test_order_by(engine: str):
250250
)
251251

252252

253+
@pytest.mark.parametrize(
254+
"engine",
255+
[
256+
pytest.param(
257+
"duckdb",
258+
marks=[
259+
pytest.mark.skipif(
260+
sys.platform.startswith("win"),
261+
reason="duckdb substrait extension not found on windows",
262+
),
263+
pytest.mark.xfail,
264+
],
265+
),
266+
"datafusion",
267+
],
268+
)
269+
def test_select_limit(engine: str):
270+
assert_query("""SELECT store_id FROM stores ORDER BY store_id LIMIT 2""", engine)
271+
272+
273+
@pytest.mark.parametrize(
274+
"engine",
275+
[
276+
pytest.param(
277+
"duckdb",
278+
marks=[
279+
pytest.mark.skipif(
280+
sys.platform.startswith("win"),
281+
reason="duckdb substrait extension not found on windows",
282+
),
283+
pytest.mark.xfail,
284+
],
285+
),
286+
"datafusion",
287+
],
288+
)
289+
def test_select_limit_offset(engine: str):
290+
assert_query(
291+
"""SELECT store_id FROM stores ORDER BY store_id LIMIT 2 OFFSET 2""", engine
292+
)
293+
294+
253295
@pytest.mark.parametrize(
254296
"engine",
255297
[

0 commit comments

Comments
 (0)