Skip to content

Commit c18ab05

Browse files
committed
POC: consistent NaN treatment for pyarrow dtypes
1 parent 3fb47c7 commit c18ab05

File tree

7 files changed

+81
-19
lines changed

7 files changed

+81
-19
lines changed

pandas/_libs/parsers.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1453,7 +1453,7 @@ def _maybe_upcast(
14531453
if isinstance(arr, IntegerArray) and arr.isna().all():
14541454
# use null instead of int64 in pyarrow
14551455
arr = arr.to_numpy(na_value=None)
1456-
arr = ArrowExtensionArray(pa.array(arr, from_pandas=True))
1456+
arr = ArrowExtensionArray(pa.array(arr))
14571457

14581458
return arr
14591459

pandas/core/arrays/arrow/array.py

Lines changed: 41 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import numpy as np
1717

1818
from pandas._libs import lib
19+
from pandas._libs.missing import NA
1920
from pandas._libs.tslibs import (
2021
Timedelta,
2122
Timestamp,
@@ -351,7 +352,7 @@ def _from_sequence_of_strings(
351352
# duration to string casting behavior
352353
mask = isna(scalars)
353354
if not isinstance(strings, (pa.Array, pa.ChunkedArray)):
354-
strings = pa.array(strings, type=pa.string(), from_pandas=True)
355+
strings = pa.array(strings, type=pa.string())
355356
strings = pc.if_else(mask, None, strings)
356357
try:
357358
scalars = strings.cast(pa.int64())
@@ -372,7 +373,7 @@ def _from_sequence_of_strings(
372373
if isinstance(strings, (pa.Array, pa.ChunkedArray)):
373374
scalars = strings
374375
else:
375-
scalars = pa.array(strings, type=pa.string(), from_pandas=True)
376+
scalars = pa.array(strings, type=pa.string())
376377
scalars = pc.if_else(pc.equal(scalars, "1.0"), "1", scalars)
377378
scalars = pc.if_else(pc.equal(scalars, "0.0"), "0", scalars)
378379
scalars = scalars.cast(pa.bool_())
@@ -384,6 +385,13 @@ def _from_sequence_of_strings(
384385
from pandas.core.tools.numeric import to_numeric
385386

386387
scalars = to_numeric(strings, errors="raise")
388+
if not pa.types.is_decimal(pa_type):
389+
# TODO: figure out why doing this cast breaks with decimal dtype
390+
# in test_from_sequence_of_strings_pa_array
391+
mask = strings.is_null()
392+
scalars = pa.array(scalars, mask=np.array(mask), type=pa_type)
393+
# TODO: could we just do strings.cast(pa_type)?
394+
387395
else:
388396
raise NotImplementedError(
389397
f"Converting strings to {pa_type} is not implemented."
@@ -426,7 +434,7 @@ def _box_pa_scalar(cls, value, pa_type: pa.DataType | None = None) -> pa.Scalar:
426434
"""
427435
if isinstance(value, pa.Scalar):
428436
pa_scalar = value
429-
elif isna(value):
437+
elif isna(value) and not lib.is_float(value):
430438
pa_scalar = pa.scalar(None, type=pa_type)
431439
else:
432440
# Workaround https://github.com/apache/arrow/issues/37291
@@ -443,7 +451,7 @@ def _box_pa_scalar(cls, value, pa_type: pa.DataType | None = None) -> pa.Scalar:
443451
value = value.as_unit(pa_type.unit)
444452
value = value._value
445453

446-
pa_scalar = pa.scalar(value, type=pa_type, from_pandas=True)
454+
pa_scalar = pa.scalar(value, type=pa_type)
447455

448456
if pa_type is not None and pa_scalar.type != pa_type:
449457
pa_scalar = pa_scalar.cast(pa_type)
@@ -475,6 +483,13 @@ def _box_pa_array(
475483
if copy:
476484
value = value.copy()
477485
pa_array = value.__arrow_array__()
486+
487+
elif hasattr(value, "__arrow_array__"):
488+
# e.g. StringArray
489+
if copy:
490+
value = value.copy()
491+
pa_array = value.__arrow_array__()
492+
478493
else:
479494
if (
480495
isinstance(value, np.ndarray)
@@ -528,19 +543,32 @@ def _box_pa_array(
528543
pa_array = pa.array(dta._ndarray, type=pa_type, mask=dta_mask)
529544
return pa_array
530545

546+
mask = None
547+
if getattr(value, "dtype", None) is None or value.dtype.kind not in "mfM":
548+
# similar to isna(value) but exclude NaN
549+
# TODO: cythonize!
550+
mask = np.array([x is NA or x is None for x in value], dtype=bool)
551+
552+
from_pandas = False
553+
if pa.types.is_integer(pa_type):
554+
# If user specifically asks to cast a numpy float array with NaNs
555+
# to pyarrow integer, we'll treat those NaNs as NA
556+
from_pandas = True
531557
try:
532-
pa_array = pa.array(value, type=pa_type, from_pandas=True)
558+
pa_array = pa.array(
559+
value, type=pa_type, mask=mask, from_pandas=from_pandas
560+
)
533561
except (pa.ArrowInvalid, pa.ArrowTypeError):
534562
# GH50430: let pyarrow infer type, then cast
535-
pa_array = pa.array(value, from_pandas=True)
563+
pa_array = pa.array(value, mask=mask, from_pandas=from_pandas)
536564

537565
if pa_type is None and pa.types.is_duration(pa_array.type):
538566
# Workaround https://github.com/apache/arrow/issues/37291
539567
from pandas.core.tools.timedeltas import to_timedelta
540568

541569
value = to_timedelta(value)
542570
value = value.to_numpy()
543-
pa_array = pa.array(value, type=pa_type, from_pandas=True)
571+
pa_array = pa.array(value, type=pa_type)
544572

545573
if pa.types.is_duration(pa_array.type) and pa_array.null_count > 0:
546574
# GH52843: upstream bug for duration types when originally
@@ -1187,7 +1215,7 @@ def isin(self, values: ArrayLike) -> npt.NDArray[np.bool_]:
11871215
if not len(values):
11881216
return np.zeros(len(self), dtype=bool)
11891217

1190-
result = pc.is_in(self._pa_array, value_set=pa.array(values, from_pandas=True))
1218+
result = pc.is_in(self._pa_array, value_set=pa.array(values))
11911219
# pyarrow 2.0.0 returned nulls, so we explicitly specify dtype to convert nulls
11921220
# to False
11931221
return np.array(result, dtype=np.bool_)
@@ -1994,7 +2022,7 @@ def __setitem__(self, key, value) -> None:
19942022
raise ValueError("Length of indexer and values mismatch")
19952023
chunks = [
19962024
*self._pa_array[:key].chunks,
1997-
pa.array([value], type=self._pa_array.type, from_pandas=True),
2025+
pa.array([value], type=self._pa_array.type),
19982026
*self._pa_array[key + 1 :].chunks,
19992027
]
20002028
data = pa.chunked_array(chunks).combine_chunks()
@@ -2048,7 +2076,7 @@ def _rank_calc(
20482076
pa_type = pa.float64()
20492077
else:
20502078
pa_type = pa.uint64()
2051-
result = pa.array(ranked, type=pa_type, from_pandas=True)
2079+
result = pa.array(ranked, type=pa_type)
20522080
return result
20532081

20542082
data = self._pa_array.combine_chunks()
@@ -2300,7 +2328,7 @@ def _to_numpy_and_type(value) -> tuple[np.ndarray, pa.DataType | None]:
23002328
right, right_type = _to_numpy_and_type(right)
23012329
pa_type = left_type or right_type
23022330
result = np.where(cond, left, right)
2303-
return pa.array(result, type=pa_type, from_pandas=True)
2331+
return pa.array(result, type=pa_type)
23042332

23052333
@classmethod
23062334
def _replace_with_mask(
@@ -2343,7 +2371,7 @@ def _replace_with_mask(
23432371
replacements = replacements.as_py()
23442372
result = np.array(values, dtype=object)
23452373
result[mask] = replacements
2346-
return pa.array(result, type=values.type, from_pandas=True)
2374+
return pa.array(result, type=values.type)
23472375

23482376
# ------------------------------------------------------------------
23492377
# GroupBy Methods
@@ -2422,7 +2450,7 @@ def _groupby_op(
24222450
return type(self)(pa_result)
24232451
else:
24242452
# DatetimeArray, TimedeltaArray
2425-
pa_result = pa.array(result, from_pandas=True)
2453+
pa_result = pa.array(result)
24262454
return type(self)(pa_result)
24272455

24282456
def _apply_elementwise(self, func: Callable) -> list[list[Any]]:

pandas/core/arrays/string_.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -481,6 +481,12 @@ def _str_map_str_or_object(
481481
if self.dtype.storage == "pyarrow":
482482
import pyarrow as pa
483483

484+
# TODO: shouldn't this already be caught by passed mask?
485+
# it isn't in test_extract_expand_capture_groups_index
486+
# mask = mask | np.array(
487+
# [x is libmissing.NA for x in result], dtype=bool
488+
# )
489+
484490
result = pa.array(
485491
result, mask=mask, type=pa.large_string(), from_pandas=True
486492
)
@@ -733,7 +739,7 @@ def __arrow_array__(self, type=None):
733739

734740
values = self._ndarray.copy()
735741
values[self.isna()] = None
736-
return pa.array(values, type=type, from_pandas=True)
742+
return pa.array(values, type=type)
737743

738744
def _values_for_factorize(self) -> tuple[np.ndarray, libmissing.NAType | float]: # type: ignore[override]
739745
arr = self._ndarray

pandas/core/generic.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9873,7 +9873,7 @@ def where(
98739873
def where(
98749874
self,
98759875
cond,
9876-
other=np.nan,
9876+
other=lib.no_default,
98779877
*,
98789878
inplace: bool = False,
98799879
axis: Axis | None = None,
@@ -10031,6 +10031,23 @@ def where(
1003110031
stacklevel=2,
1003210032
)
1003310033

10034+
if other is lib.no_default:
10035+
if self.ndim == 1:
10036+
if isinstance(self.dtype, ExtensionDtype):
10037+
other = self.dtype.na_value
10038+
else:
10039+
other = np.nan
10040+
else:
10041+
if self._mgr.nblocks == 1 and isinstance(
10042+
self._mgr.blocks[0].values.dtype, ExtensionDtype
10043+
):
10044+
# FIXME: checking this is kludgy!
10045+
other = self._mgr.blocks[0].values.dtype.na_value
10046+
else:
10047+
# FIXME: the same problem we had with Series will now
10048+
# show up column-by-column!
10049+
other = np.nan
10050+
1003410051
other = common.apply_if_callable(other, self)
1003510052
return self._where(cond, other, inplace=inplace, axis=axis, level=level)
1003610053

pandas/tests/extension/test_arrow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -721,7 +721,7 @@ def test_EA_types(self, engine, data, dtype_backend, request):
721721
pytest.mark.xfail(reason="CSV parsers don't correctly handle binary")
722722
)
723723
df = pd.DataFrame({"with_dtype": pd.Series(data, dtype=str(data.dtype))})
724-
csv_output = df.to_csv(index=False, na_rep=np.nan)
724+
csv_output = df.to_csv(index=False, na_rep=np.nan) # should be NA?
725725
if pa.types.is_binary(pa_dtype):
726726
csv_output = BytesIO(csv_output)
727727
else:

pandas/tests/groupby/test_reductions.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -381,8 +381,10 @@ def test_first_last_skipna(any_real_nullable_dtype, sort, skipna, how):
381381
df = DataFrame(
382382
{
383383
"a": [2, 1, 1, 2, 3, 3],
384-
"b": [na_value, 3.0, na_value, 4.0, np.nan, np.nan],
385-
"c": [na_value, 3.0, na_value, 4.0, np.nan, np.nan],
384+
# TODO: test that has mixed na_value and NaN either working for
385+
# float or raising for int?
386+
"b": [na_value, 3.0, na_value, 4.0, na_value, na_value],
387+
"c": [na_value, 3.0, na_value, 4.0, na_value, na_value],
386388
},
387389
dtype=any_real_nullable_dtype,
388390
)

pandas/tests/series/methods/test_rank.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,13 @@ def test_rank_tie_methods(self, ser, results, dtype, using_infer_string):
276276

277277
ser = ser if dtype is None else ser.astype(dtype)
278278
result = ser.rank(method=method)
279+
if dtype == "float64[pyarrow]":
280+
# the NaNs are not treated as NA
281+
exp = exp.copy()
282+
if method == "average":
283+
exp[np.isnan(ser)] = 9.5
284+
elif method == "dense":
285+
exp[np.isnan(ser)] = 6
279286
tm.assert_series_equal(result, Series(exp, dtype=expected_dtype(dtype, method)))
280287

281288
@pytest.mark.parametrize("na_option", ["top", "bottom", "keep"])
@@ -321,6 +328,8 @@ def test_rank_tie_methods_on_infs_nans(
321328
order = [ranks[1], ranks[0], ranks[2]]
322329
elif na_option == "bottom":
323330
order = [ranks[0], ranks[2], ranks[1]]
331+
elif dtype == "float64[pyarrow]":
332+
order = [ranks[0], [NA] * chunk, ranks[1]]
324333
else:
325334
order = [ranks[0], [np.nan] * chunk, ranks[1]]
326335
expected = order if ascending else order[::-1]

0 commit comments

Comments
 (0)