Skip to content

Commit 88c7663

Browse files
committed
POC: consistent NaN treatment for pyarrow dtypes
1 parent fddad79 commit 88c7663

File tree

7 files changed

+81
-19
lines changed

7 files changed

+81
-19
lines changed

pandas/_libs/parsers.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1456,7 +1456,7 @@ def _maybe_upcast(
14561456
if isinstance(arr, IntegerArray) and arr.isna().all():
14571457
# use null instead of int64 in pyarrow
14581458
arr = arr.to_numpy(na_value=None)
1459-
arr = ArrowExtensionArray(pa.array(arr, from_pandas=True))
1459+
arr = ArrowExtensionArray(pa.array(arr))
14601460

14611461
return arr
14621462

pandas/core/arrays/arrow/array.py

Lines changed: 41 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import numpy as np
1717

1818
from pandas._libs import lib
19+
from pandas._libs.missing import NA
1920
from pandas._libs.tslibs import (
2021
Timedelta,
2122
Timestamp,
@@ -351,7 +352,7 @@ def _from_sequence_of_strings(
351352
# duration to string casting behavior
352353
mask = isna(scalars)
353354
if not isinstance(strings, (pa.Array, pa.ChunkedArray)):
354-
strings = pa.array(strings, type=pa.string(), from_pandas=True)
355+
strings = pa.array(strings, type=pa.string())
355356
strings = pc.if_else(mask, None, strings)
356357
try:
357358
scalars = strings.cast(pa.int64())
@@ -372,7 +373,7 @@ def _from_sequence_of_strings(
372373
if isinstance(strings, (pa.Array, pa.ChunkedArray)):
373374
scalars = strings
374375
else:
375-
scalars = pa.array(strings, type=pa.string(), from_pandas=True)
376+
scalars = pa.array(strings, type=pa.string())
376377
scalars = pc.if_else(pc.equal(scalars, "1.0"), "1", scalars)
377378
scalars = pc.if_else(pc.equal(scalars, "0.0"), "0", scalars)
378379
scalars = scalars.cast(pa.bool_())
@@ -384,6 +385,13 @@ def _from_sequence_of_strings(
384385
from pandas.core.tools.numeric import to_numeric
385386

386387
scalars = to_numeric(strings, errors="raise")
388+
if not pa.types.is_decimal(pa_type):
389+
# TODO: figure out why doing this cast breaks with decimal dtype
390+
# in test_from_sequence_of_strings_pa_array
391+
mask = strings.is_null()
392+
scalars = pa.array(scalars, mask=np.array(mask), type=pa_type)
393+
# TODO: could we just do strings.cast(pa_type)?
394+
387395
else:
388396
raise NotImplementedError(
389397
f"Converting strings to {pa_type} is not implemented."
@@ -426,7 +434,7 @@ def _box_pa_scalar(cls, value, pa_type: pa.DataType | None = None) -> pa.Scalar:
426434
"""
427435
if isinstance(value, pa.Scalar):
428436
pa_scalar = value
429-
elif isna(value):
437+
elif isna(value) and not lib.is_float(value):
430438
pa_scalar = pa.scalar(None, type=pa_type)
431439
else:
432440
# Workaround https://github.com/apache/arrow/issues/37291
@@ -443,7 +451,7 @@ def _box_pa_scalar(cls, value, pa_type: pa.DataType | None = None) -> pa.Scalar:
443451
value = value.as_unit(pa_type.unit)
444452
value = value._value
445453

446-
pa_scalar = pa.scalar(value, type=pa_type, from_pandas=True)
454+
pa_scalar = pa.scalar(value, type=pa_type)
447455

448456
if pa_type is not None and pa_scalar.type != pa_type:
449457
pa_scalar = pa_scalar.cast(pa_type)
@@ -475,6 +483,13 @@ def _box_pa_array(
475483
if copy:
476484
value = value.copy()
477485
pa_array = value.__arrow_array__()
486+
487+
elif hasattr(value, "__arrow_array__"):
488+
# e.g. StringArray
489+
if copy:
490+
value = value.copy()
491+
pa_array = value.__arrow_array__()
492+
478493
else:
479494
if (
480495
isinstance(value, np.ndarray)
@@ -528,19 +543,32 @@ def _box_pa_array(
528543
pa_array = pa.array(dta._ndarray, type=pa_type, mask=mask)
529544
return pa_array
530545

546+
mask = None
547+
if getattr(value, "dtype", None) is None or value.dtype.kind not in "mfM":
548+
# similar to isna(value) but exclude NaN
549+
# TODO: cythonize!
550+
mask = np.array([x is NA or x is None for x in value], dtype=bool)
551+
552+
from_pandas = False
553+
if pa.types.is_integer(pa_type):
554+
# If user specifically asks to cast a numpy float array with NaNs
555+
# to pyarrow integer, we'll treat those NaNs as NA
556+
from_pandas = True
531557
try:
532-
pa_array = pa.array(value, type=pa_type, from_pandas=True)
558+
pa_array = pa.array(
559+
value, type=pa_type, mask=mask, from_pandas=from_pandas
560+
)
533561
except (pa.ArrowInvalid, pa.ArrowTypeError):
534562
# GH50430: let pyarrow infer type, then cast
535-
pa_array = pa.array(value, from_pandas=True)
563+
pa_array = pa.array(value, mask=mask, from_pandas=from_pandas)
536564

537565
if pa_type is None and pa.types.is_duration(pa_array.type):
538566
# Workaround https://github.com/apache/arrow/issues/37291
539567
from pandas.core.tools.timedeltas import to_timedelta
540568

541569
value = to_timedelta(value)
542570
value = value.to_numpy()
543-
pa_array = pa.array(value, type=pa_type, from_pandas=True)
571+
pa_array = pa.array(value, type=pa_type)
544572

545573
if pa.types.is_duration(pa_array.type) and pa_array.null_count > 0:
546574
# GH52843: upstream bug for duration types when originally
@@ -1187,7 +1215,7 @@ def isin(self, values: ArrayLike) -> npt.NDArray[np.bool_]:
11871215
if not len(values):
11881216
return np.zeros(len(self), dtype=bool)
11891217

1190-
result = pc.is_in(self._pa_array, value_set=pa.array(values, from_pandas=True))
1218+
result = pc.is_in(self._pa_array, value_set=pa.array(values))
11911219
# pyarrow 2.0.0 returned nulls, so we explicitly specify dtype to convert nulls
11921220
# to False
11931221
return np.array(result, dtype=np.bool_)
@@ -1992,7 +2020,7 @@ def __setitem__(self, key, value) -> None:
19922020
raise ValueError("Length of indexer and values mismatch")
19932021
chunks = [
19942022
*self._pa_array[:key].chunks,
1995-
pa.array([value], type=self._pa_array.type, from_pandas=True),
2023+
pa.array([value], type=self._pa_array.type),
19962024
*self._pa_array[key + 1 :].chunks,
19972025
]
19982026
data = pa.chunked_array(chunks).combine_chunks()
@@ -2046,7 +2074,7 @@ def _rank_calc(
20462074
pa_type = pa.float64()
20472075
else:
20482076
pa_type = pa.uint64()
2049-
result = pa.array(ranked, type=pa_type, from_pandas=True)
2077+
result = pa.array(ranked, type=pa_type)
20502078
return result
20512079

20522080
data = self._pa_array.combine_chunks()
@@ -2298,7 +2326,7 @@ def _to_numpy_and_type(value) -> tuple[np.ndarray, pa.DataType | None]:
22982326
right, right_type = _to_numpy_and_type(right)
22992327
pa_type = left_type or right_type
23002328
result = np.where(cond, left, right)
2301-
return pa.array(result, type=pa_type, from_pandas=True)
2329+
return pa.array(result, type=pa_type)
23022330

23032331
@classmethod
23042332
def _replace_with_mask(
@@ -2341,7 +2369,7 @@ def _replace_with_mask(
23412369
replacements = replacements.as_py()
23422370
result = np.array(values, dtype=object)
23432371
result[mask] = replacements
2344-
return pa.array(result, type=values.type, from_pandas=True)
2372+
return pa.array(result, type=values.type)
23452373

23462374
# ------------------------------------------------------------------
23472375
# GroupBy Methods
@@ -2420,7 +2448,7 @@ def _groupby_op(
24202448
return type(self)(pa_result)
24212449
else:
24222450
# DatetimeArray, TimedeltaArray
2423-
pa_result = pa.array(result, from_pandas=True)
2451+
pa_result = pa.array(result)
24242452
return type(self)(pa_result)
24252453

24262454
def _apply_elementwise(self, func: Callable) -> list[list[Any]]:

pandas/core/arrays/string_.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -474,6 +474,12 @@ def _str_map_str_or_object(
474474
if self.dtype.storage == "pyarrow":
475475
import pyarrow as pa
476476

477+
# TODO: shouldn't this already be caught by passed mask?
478+
# it isn't in test_extract_expand_capture_groups_index
479+
# mask = mask | np.array(
480+
# [x is libmissing.NA for x in result], dtype=bool
481+
# )
482+
477483
result = pa.array(
478484
result, mask=mask, type=pa.large_string(), from_pandas=True
479485
)
@@ -726,7 +732,7 @@ def __arrow_array__(self, type=None):
726732

727733
values = self._ndarray.copy()
728734
values[self.isna()] = None
729-
return pa.array(values, type=type, from_pandas=True)
735+
return pa.array(values, type=type)
730736

731737
def _values_for_factorize(self) -> tuple[np.ndarray, libmissing.NAType | float]: # type: ignore[override]
732738
arr = self._ndarray

pandas/core/generic.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9874,7 +9874,7 @@ def where(
98749874
def where(
98759875
self,
98769876
cond,
9877-
other=np.nan,
9877+
other=lib.no_default,
98789878
*,
98799879
inplace: bool = False,
98809880
axis: Axis | None = None,
@@ -10032,6 +10032,23 @@ def where(
1003210032
stacklevel=2,
1003310033
)
1003410034

10035+
if other is lib.no_default:
10036+
if self.ndim == 1:
10037+
if isinstance(self.dtype, ExtensionDtype):
10038+
other = self.dtype.na_value
10039+
else:
10040+
other = np.nan
10041+
else:
10042+
if self._mgr.nblocks == 1 and isinstance(
10043+
self._mgr.blocks[0].values.dtype, ExtensionDtype
10044+
):
10045+
# FIXME: checking this is kludgy!
10046+
other = self._mgr.blocks[0].values.dtype.na_value
10047+
else:
10048+
# FIXME: the same problem we had with Series will now
10049+
# show up column-by-column!
10050+
other = np.nan
10051+
1003510052
other = common.apply_if_callable(other, self)
1003610053
return self._where(cond, other, inplace=inplace, axis=axis, level=level)
1003710054

pandas/tests/extension/test_arrow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -717,7 +717,7 @@ def test_EA_types(self, engine, data, dtype_backend, request):
717717
pytest.mark.xfail(reason="CSV parsers don't correctly handle binary")
718718
)
719719
df = pd.DataFrame({"with_dtype": pd.Series(data, dtype=str(data.dtype))})
720-
csv_output = df.to_csv(index=False, na_rep=np.nan)
720+
csv_output = df.to_csv(index=False, na_rep=np.nan) # should be NA?
721721
if pa.types.is_binary(pa_dtype):
722722
csv_output = BytesIO(csv_output)
723723
else:

pandas/tests/groupby/test_reductions.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -381,8 +381,10 @@ def test_first_last_skipna(any_real_nullable_dtype, sort, skipna, how):
381381
df = DataFrame(
382382
{
383383
"a": [2, 1, 1, 2, 3, 3],
384-
"b": [na_value, 3.0, na_value, 4.0, np.nan, np.nan],
385-
"c": [na_value, 3.0, na_value, 4.0, np.nan, np.nan],
384+
# TODO: test that has mixed na_value and NaN either working for
385+
# float or raising for int?
386+
"b": [na_value, 3.0, na_value, 4.0, na_value, na_value],
387+
"c": [na_value, 3.0, na_value, 4.0, na_value, na_value],
386388
},
387389
dtype=any_real_nullable_dtype,
388390
)

pandas/tests/series/methods/test_rank.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,13 @@ def test_rank_tie_methods(self, ser, results, dtype, using_infer_string):
276276

277277
ser = ser if dtype is None else ser.astype(dtype)
278278
result = ser.rank(method=method)
279+
if dtype == "float64[pyarrow]":
280+
# the NaNs are not treated as NA
281+
exp = exp.copy()
282+
if method == "average":
283+
exp[np.isnan(ser)] = 9.5
284+
elif method == "dense":
285+
exp[np.isnan(ser)] = 6
279286
tm.assert_series_equal(result, Series(exp, dtype=expected_dtype(dtype, method)))
280287

281288
@pytest.mark.parametrize("na_option", ["top", "bottom", "keep"])
@@ -321,6 +328,8 @@ def test_rank_tie_methods_on_infs_nans(
321328
order = [ranks[1], ranks[0], ranks[2]]
322329
elif na_option == "bottom":
323330
order = [ranks[0], ranks[2], ranks[1]]
331+
elif dtype == "float64[pyarrow]":
332+
order = [ranks[0], [NA] * chunk, ranks[1]]
324333
else:
325334
order = [ranks[0], [np.nan] * chunk, ranks[1]]
326335
expected = order if ascending else order[::-1]

0 commit comments

Comments
 (0)