Skip to content

Commit 7bc303d

Browse files
committed
Refactor aggregate tests to simplify result assertions and improve readability
1 parent 93f0a31 commit 7bc303d

File tree

1 file changed

+8
-14
lines changed

1 file changed

+8
-14
lines changed

examples/datafusion-ffi-example/python/tests/_test_aggregate_udf.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -45,16 +45,13 @@ def test_ffi_aggregate_register():
4545

4646
result = ctx.sql("select my_custom_sum(a) from test_table group by b").collect()
4747

48-
assert len(result) == 2
48+
assert result
4949
assert result[0].num_columns == 1
5050

51-
result = [r.column(0) for r in result]
52-
expected = [
53-
pa.array([3], type=pa.int64()),
54-
pa.array([3], type=pa.int64()),
55-
]
51+
aggregates = pa.concat_arrays([batch.column(0) for batch in result])
5652

57-
assert result == expected
53+
assert len(aggregates) == 2
54+
assert aggregates.to_pylist() == [3, 3]
5855

5956

6057
def test_ffi_aggregate_call_directly():
@@ -65,13 +62,10 @@ def test_ffi_aggregate_call_directly():
6562
ctx.table("test_table").aggregate([col("b")], [my_udaf(col("a"))]).collect()
6663
)
6764

68-
assert len(result) == 2
65+
assert result
6966
assert result[0].num_columns == 2
7067

71-
result = [r.column(1) for r in result]
72-
expected = [
73-
pa.array([3], type=pa.int64()),
74-
pa.array([3], type=pa.int64()),
75-
]
68+
aggregates = pa.concat_arrays([batch.column(1) for batch in result])
7669

77-
assert result == expected
70+
assert len(aggregates) == 2
71+
assert aggregates.to_pylist() == [3, 3]

0 commit comments

Comments
 (0)