Skip to content

Commit 937d39c

Browse files
committed
Update ffi examples
1 parent cb7a755 commit 937d39c

File tree

3 files changed

+19
-19
lines changed

3 files changed

+19
-19
lines changed

examples/datafusion-ffi-example/python/tests/_test_aggregate_udf.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -45,16 +45,16 @@ def test_ffi_aggregate_register():
4545

4646
result = ctx.sql("select my_custom_sum(a) from test_table group by b").collect()
4747

48-
assert result
48+
assert len(result) == 2
4949
assert result[0].num_columns == 1
5050

51-
# Normalizing table registration in _normalize_table_provider feeds the Rust layer
52-
# an actual TableProvider, so collect() emits the grouped rows in a single record batch
53-
# instead of two separate batches.
54-
aggregates = pa.concat_arrays([batch.column(0) for batch in result])
51+
result = [r.column(0) for r in result]
52+
expected = [
53+
pa.array([3], type=pa.int64()),
54+
pa.array([3], type=pa.int64()),
55+
]
5556

56-
assert len(aggregates) == 2
57-
assert aggregates.to_pylist() == [3, 3]
57+
assert result == expected
5858

5959

6060
def test_ffi_aggregate_call_directly():
@@ -65,13 +65,13 @@ def test_ffi_aggregate_call_directly():
6565
ctx.table("test_table").aggregate([col("b")], [my_udaf(col("a"))]).collect()
6666
)
6767

68-
# Normalizing table registration in _normalize_table_provider feeds the Rust layer
69-
# an actual TableProvider, so collect() emits the grouped rows in a single record batch
70-
# instead of two separate batches.
71-
assert result
68+
assert len(result) == 2
7269
assert result[0].num_columns == 2
7370

74-
aggregates = pa.concat_arrays([batch.column(1) for batch in result])
71+
result = [r.column(1) for r in result]
72+
expected = [
73+
pa.array([3], type=pa.int64()),
74+
pa.array([3], type=pa.int64()),
75+
]
7576

76-
assert len(aggregates) == 2
77-
assert aggregates.to_pylist() == [3, 3]
77+
assert result == expected

examples/datafusion-ffi-example/python/tests/_test_catalog_provider.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,9 @@ def test_catalog_provider():
3636

3737
my_catalog_schemas = my_catalog.names()
3838
assert expected_schema_name in my_catalog_schemas
39-
my_database = my_catalog.database(expected_schema_name)
40-
assert expected_table_name in my_database.names()
41-
my_table = my_database.table(expected_table_name)
39+
my_schema = my_catalog.schema(expected_schema_name)
40+
assert expected_table_name in my_schema.names()
41+
my_table = my_schema.table(expected_table_name)
4242
assert expected_table_columns == my_table.schema.names
4343

4444
result = ctx.table(

examples/datafusion-ffi-example/python/tests/_test_table_provider.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,14 @@
1818
from __future__ import annotations
1919

2020
import pyarrow as pa
21-
from datafusion import SessionContext, Table
21+
from datafusion import SessionContext
2222
from datafusion_ffi_example import MyTableProvider
2323

2424

2525
def test_table_loading():
2626
ctx = SessionContext()
2727
table = MyTableProvider(3, 2, 4)
28-
ctx.register_table("t", Table.from_capsule(table.__datafusion_table_provider__()))
28+
ctx.register_table("t", table)
2929
result = ctx.table("t").collect()
3030

3131
assert len(result) == 4

0 commit comments

Comments
 (0)