Skip to content

Commit a89ef52

Browse files
committed
restore arrow array to numpy to numpy extractor
1 parent ac5a46d commit a89ef52

File tree

1 file changed

+35
-36
lines changed

1 file changed

+35
-36
lines changed

src/datasets/formatting/formatting.py

Lines changed: 35 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -107,41 +107,6 @@ def _is_array_with_nulls(pa_array: pa.Array) -> bool:
107107
return pa_array.null_count > 0
108108

109109

110-
def _arrow_array_to_numpy(pa_array: pa.Array) -> np.ndarray:
111-
if isinstance(pa_array, pa.ChunkedArray):
112-
if isinstance(pa_array.type, _ArrayXDExtensionType):
113-
# don't call to_pylist() to preserve dtype of the fixed-size array
114-
zero_copy_only = _is_zero_copy_only(pa_array.type.storage_dtype, unnest=True)
115-
array: List = [row for chunk in pa_array.chunks for row in chunk.to_numpy(zero_copy_only=zero_copy_only)]
116-
else:
117-
zero_copy_only = _is_zero_copy_only(pa_array.type) and all(
118-
not _is_array_with_nulls(chunk) for chunk in pa_array.chunks
119-
)
120-
array: List = [row for chunk in pa_array.chunks for row in chunk.to_numpy(zero_copy_only=zero_copy_only)]
121-
else:
122-
if isinstance(pa_array.type, _ArrayXDExtensionType):
123-
# don't call to_pylist() to preserve dtype of the fixed-size array
124-
zero_copy_only = _is_zero_copy_only(pa_array.type.storage_dtype, unnest=True)
125-
array: List = pa_array.to_numpy(zero_copy_only=zero_copy_only)
126-
else:
127-
zero_copy_only = _is_zero_copy_only(pa_array.type) and not _is_array_with_nulls(pa_array)
128-
array: List = pa_array.to_numpy(zero_copy_only=zero_copy_only).tolist()
129-
130-
if len(array) > 0:
131-
if any(
132-
(isinstance(x, np.ndarray) and (x.dtype == object or x.shape != array[0].shape))
133-
or (isinstance(x, float) and np.isnan(x))
134-
for x in array
135-
):
136-
if np.lib.NumpyVersion(np.__version__) >= "2.0.0b1":
137-
return np.asarray(array, dtype=object)
138-
return np.array(array, copy=False, dtype=object)
139-
if np.lib.NumpyVersion(np.__version__) >= "2.0.0b1":
140-
return np.asarray(array)
141-
else:
142-
return np.array(array, copy=False)
143-
144-
145110
def dict_of_lists_to_list_of_dicts(dict_of_lists: Dict[str, List[T]]) -> List[Dict[str, T]]:
146111
# convert to list of dicts
147112
list_of_dicts = []
@@ -231,7 +196,41 @@ def extract_column(self, pa_table: pa.Table) -> np.ndarray:
231196
return self._arrow_array_to_numpy(pa_table[pa_table.column_names[0]])
232197

233198
def extract_batch(self, pa_table: pa.Table) -> dict:
234-
return {col: _arrow_array_to_numpy(pa_table[col]) for col in pa_table.column_names}
199+
return {col: self._arrow_array_to_numpy(pa_table[col]) for col in pa_table.column_names}
200+
201+
def _arrow_array_to_numpy(self, pa_array: pa.Array) -> np.ndarray:
202+
if isinstance(pa_array, pa.ChunkedArray):
203+
if isinstance(pa_array.type, _ArrayXDExtensionType):
204+
# don't call to_pylist() to preserve dtype of the fixed-size array
205+
zero_copy_only = _is_zero_copy_only(pa_array.type.storage_dtype, unnest=True)
206+
array: List = [row for chunk in pa_array.chunks for row in chunk.to_numpy(zero_copy_only=zero_copy_only)]
207+
else:
208+
zero_copy_only = _is_zero_copy_only(pa_array.type) and all(
209+
not _is_array_with_nulls(chunk) for chunk in pa_array.chunks
210+
)
211+
array: List = [row for chunk in pa_array.chunks for row in chunk.to_numpy(zero_copy_only=zero_copy_only)]
212+
else:
213+
if isinstance(pa_array.type, _ArrayXDExtensionType):
214+
# don't call to_pylist() to preserve dtype of the fixed-size array
215+
zero_copy_only = _is_zero_copy_only(pa_array.type.storage_dtype, unnest=True)
216+
array: List = pa_array.to_numpy(zero_copy_only=zero_copy_only)
217+
else:
218+
zero_copy_only = _is_zero_copy_only(pa_array.type) and not _is_array_with_nulls(pa_array)
219+
array: List = pa_array.to_numpy(zero_copy_only=zero_copy_only).tolist()
220+
221+
if len(array) > 0:
222+
if any(
223+
(isinstance(x, np.ndarray) and (x.dtype == object or x.shape != array[0].shape))
224+
or (isinstance(x, float) and np.isnan(x))
225+
for x in array
226+
):
227+
if np.lib.NumpyVersion(np.__version__) >= "2.0.0b1":
228+
return np.asarray(array, dtype=object)
229+
return np.array(array, copy=False, dtype=object)
230+
if np.lib.NumpyVersion(np.__version__) >= "2.0.0b1":
231+
return np.asarray(array)
232+
else:
233+
return np.array(array, copy=False)
235234

236235

237236
class PandasArrowExtractor(BaseArrowExtractor[pd.DataFrame, pd.Series, pd.DataFrame]):

0 commit comments

Comments
 (0)