
Commit 1df6db2

fix: Inconsistent schemas when converting to pyarrow (#1315)
* Fix inconsistent schemas when converting to pyarrow
* Add extra tests
* Change deprecated type

Co-authored-by: Tim Saucer <[email protected]>
1 parent 3a4ae6d commit 1df6db2
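For context: the inconsistency stems from handing pyarrow.Table.from_batches an explicit schema that disagrees with the batches' own schemas on nullability. Below is a minimal standalone sketch of that pyarrow behavior; it is an illustration rather than code from this repository, and the batch_schema/plan_schema names are invented here.

import pyarrow as pa

# A batch whose column is marked non-nullable, as a parquet scan can produce.
batch_schema = pa.schema([pa.field("a", pa.int32(), nullable=False)])
batch = pa.record_batch([pa.array([1, 2, 3], type=pa.int32())], schema=batch_schema)

# With no explicit schema, from_batches takes the schema from the batches,
# so nullable=False is preserved.
print(pa.Table.from_batches([batch]).schema)  # a: int32 not null

# A logical-plan schema that marks the same column nullable is not equal to
# the batch schema, and from_batches rejects the mismatch rather than coercing.
plan_schema = pa.schema([pa.field("a", pa.int32(), nullable=True)])
try:
    pa.Table.from_batches([batch], schema=plan_schema)
except pa.ArrowInvalid as exc:
    print(f"schema mismatch: {exc}")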


2 files changed: +56 -2 lines changed

python/tests/test_dataframe.py

Lines changed: 47 additions & 0 deletions
@@ -1898,6 +1898,53 @@ def test_to_arrow_table(df):
     assert set(pyarrow_table.column_names) == {"a", "b", "c"}
 
 
+def test_parquet_non_null_column_to_pyarrow(ctx, tmp_path):
+    path = tmp_path.joinpath("t.parquet")
+
+    ctx.sql("create table t_(a int not null)").collect()
+    ctx.sql("insert into t_ values (1), (2), (3)").collect()
+    ctx.sql(f"copy (select * from t_) to '{path}'").collect()
+
+    ctx.register_parquet("t", path)
+    pyarrow_table = ctx.sql("select max(a) as m from t").to_arrow_table()
+    assert pyarrow_table.to_pydict() == {"m": [3]}
+
+
+def test_parquet_empty_batch_to_pyarrow(ctx, tmp_path):
+    path = tmp_path.joinpath("t.parquet")
+
+    ctx.sql("create table t_(a int not null)").collect()
+    ctx.sql("insert into t_ values (1), (2), (3)").collect()
+    ctx.sql(f"copy (select * from t_) to '{path}'").collect()
+
+    ctx.register_parquet("t", path)
+    pyarrow_table = ctx.sql("select * from t limit 0").to_arrow_table()
+    assert pyarrow_table.schema == pa.schema(
+        [
+            pa.field("a", pa.int32(), nullable=False),
+        ]
+    )
+
+
+def test_parquet_null_aggregation_to_pyarrow(ctx, tmp_path):
+    path = tmp_path.joinpath("t.parquet")
+
+    ctx.sql("create table t_(a int not null)").collect()
+    ctx.sql("insert into t_ values (1), (2), (3)").collect()
+    ctx.sql(f"copy (select * from t_) to '{path}'").collect()
+
+    ctx.register_parquet("t", path)
+    pyarrow_table = ctx.sql(
+        "select max(a) as m from (select * from t where a < 0)"
+    ).to_arrow_table()
+    assert pyarrow_table.to_pydict() == {"m": [None]}
+    assert pyarrow_table.schema == pa.schema(
+        [
+            pa.field("m", pa.int32(), nullable=True),
+        ]
+    )
+
+
 def test_execute_stream(df):
     stream = df.execute_stream()
     assert all(batch is not None for batch in stream)

src/dataframe.rs

Lines changed: 9 additions & 2 deletions
@@ -1044,11 +1044,18 @@ impl PyDataFrame {
     /// Collect the batches and pass to Arrow Table
     fn to_arrow_table(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
         let batches = self.collect(py)?.into_pyobject(py)?;
-        let schema = self.schema().into_pyobject(py)?;
+
+        // only use the DataFrame's schema if there are no batches, otherwise let the schema be
+        // determined from the batches (avoids some inconsistencies with nullable columns)
+        let args = if batches.len()? == 0 {
+            let schema = self.schema().into_pyobject(py)?;
+            PyTuple::new(py, &[batches, schema])?
+        } else {
+            PyTuple::new(py, &[batches])?
+        };
 
         // Instantiate pyarrow Table object and use its from_batches method
         let table_class = py.import("pyarrow")?.getattr("Table")?;
-        let args = PyTuple::new(py, &[batches, schema])?;
         let table: Py<PyAny> = table_class.call_method1("from_batches", args)?.into();
         Ok(table)
     }
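In Python terms, the patched method now makes the following choice. This is a hedged sketch; the function name and parameters are illustrative, not the extension's API.

import pyarrow as pa

def to_arrow_table(batches, dataframe_schema):
    # Empty result: from_batches cannot infer a schema from zero batches,
    # so the DataFrame's logical schema is the only option.
    if len(batches) == 0:
        return pa.Table.from_batches(batches, schema=dataframe_schema)
    # At least one batch: let pyarrow derive the schema from the batches,
    # sidestepping nullability disagreements with the logical plan's schema.
    return pa.Table.from_batches(batches)

This mirrors the empty-batch test above: with "limit 0" the logical schema, including nullable=False, is what the resulting table reports, while non-empty results take their schema from the batches themselves.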
