@@ -107,41 +107,6 @@ def _is_array_with_nulls(pa_array: pa.Array) -> bool:
107
107
return pa_array .null_count > 0
108
108
109
109
110
- def _arrow_array_to_numpy (pa_array : pa .Array ) -> np .ndarray :
111
- if isinstance (pa_array , pa .ChunkedArray ):
112
- if isinstance (pa_array .type , _ArrayXDExtensionType ):
113
- # don't call to_pylist() to preserve dtype of the fixed-size array
114
- zero_copy_only = _is_zero_copy_only (pa_array .type .storage_dtype , unnest = True )
115
- array : List = [row for chunk in pa_array .chunks for row in chunk .to_numpy (zero_copy_only = zero_copy_only )]
116
- else :
117
- zero_copy_only = _is_zero_copy_only (pa_array .type ) and all (
118
- not _is_array_with_nulls (chunk ) for chunk in pa_array .chunks
119
- )
120
- array : List = [row for chunk in pa_array .chunks for row in chunk .to_numpy (zero_copy_only = zero_copy_only )]
121
- else :
122
- if isinstance (pa_array .type , _ArrayXDExtensionType ):
123
- # don't call to_pylist() to preserve dtype of the fixed-size array
124
- zero_copy_only = _is_zero_copy_only (pa_array .type .storage_dtype , unnest = True )
125
- array : List = pa_array .to_numpy (zero_copy_only = zero_copy_only )
126
- else :
127
- zero_copy_only = _is_zero_copy_only (pa_array .type ) and not _is_array_with_nulls (pa_array )
128
- array : List = pa_array .to_numpy (zero_copy_only = zero_copy_only ).tolist ()
129
-
130
- if len (array ) > 0 :
131
- if any (
132
- (isinstance (x , np .ndarray ) and (x .dtype == object or x .shape != array [0 ].shape ))
133
- or (isinstance (x , float ) and np .isnan (x ))
134
- for x in array
135
- ):
136
- if np .lib .NumpyVersion (np .__version__ ) >= "2.0.0b1" :
137
- return np .asarray (array , dtype = object )
138
- return np .array (array , copy = False , dtype = object )
139
- if np .lib .NumpyVersion (np .__version__ ) >= "2.0.0b1" :
140
- return np .asarray (array )
141
- else :
142
- return np .array (array , copy = False )
143
-
144
-
145
110
def dict_of_lists_to_list_of_dicts (dict_of_lists : Dict [str , List [T ]]) -> List [Dict [str , T ]]:
146
111
# convert to list of dicts
147
112
list_of_dicts = []
@@ -231,7 +196,41 @@ def extract_column(self, pa_table: pa.Table) -> np.ndarray:
231
196
return self ._arrow_array_to_numpy (pa_table [pa_table .column_names [0 ]])
232
197
233
198
def extract_batch(self, pa_table: pa.Table) -> dict:
    """Convert every column of ``pa_table`` to NumPy.

    Returns:
        A dict mapping each name in ``pa_table.column_names`` to the result
        of ``self._arrow_array_to_numpy`` applied to that column.
    """
    batch = {}
    for column_name in pa_table.column_names:
        batch[column_name] = self._arrow_array_to_numpy(pa_table[column_name])
    return batch
200
+
201
def _arrow_array_to_numpy(self, pa_array: pa.Array) -> np.ndarray:
    """Convert an Arrow array (or chunked array) to a NumPy array.

    Args:
        pa_array: A ``pa.Array`` or ``pa.ChunkedArray``. Columns whose type is
            an ``_ArrayXDExtensionType`` are converted through ``to_numpy``
            without going through Python lists.

    Returns:
        A NumPy array built from the rows. The result has ``dtype=object``
        when the rows cannot be combined uniformly, i.e. when any row is an
        ndarray of object dtype, an ndarray whose shape differs from the
        first row's, or a float NaN. Otherwise NumPy infers the dtype.
    """
    is_extension = isinstance(pa_array.type, _ArrayXDExtensionType)

    if isinstance(pa_array, pa.ChunkedArray):
        if is_extension:
            # don't call to_pylist() to preserve dtype of the fixed-size array
            zero_copy_only = _is_zero_copy_only(pa_array.type.storage_dtype, unnest=True)
        else:
            # Zero-copy is only requested when no chunk holds nulls.
            zero_copy_only = _is_zero_copy_only(pa_array.type) and all(
                not _is_array_with_nulls(chunk) for chunk in pa_array.chunks
            )
        # Flatten the chunks into a single list of rows.
        array = [row for chunk in pa_array.chunks for row in chunk.to_numpy(zero_copy_only=zero_copy_only)]
    elif is_extension:
        # don't call to_pylist() to preserve dtype of the fixed-size array
        zero_copy_only = _is_zero_copy_only(pa_array.type.storage_dtype, unnest=True)
        array = pa_array.to_numpy(zero_copy_only=zero_copy_only)
    else:
        zero_copy_only = _is_zero_copy_only(pa_array.type) and not _is_array_with_nulls(pa_array)
        array = pa_array.to_numpy(zero_copy_only=zero_copy_only).tolist()

    dtype = None  # let NumPy infer the dtype unless the rows are heterogeneous
    if len(array) > 0:
        # The first row is not guaranteed to be an ndarray (presumably it is
        # None for a null row), so read its shape defensively instead of
        # raising AttributeError; a missing shape never equals a real one,
        # which routes such input to the object-dtype fallback.
        first_shape = getattr(array[0], "shape", None)
        if any(
            (isinstance(x, np.ndarray) and (x.dtype == object or x.shape != first_shape))
            or (isinstance(x, float) and np.isnan(x))
            for x in array
        ):
            dtype = object

    # NumPy >= 2.0 gets np.asarray; older versions keep np.array(copy=False).
    if np.lib.NumpyVersion(np.__version__) >= "2.0.0b1":
        return np.asarray(array, dtype=dtype)
    return np.array(array, copy=False, dtype=dtype)
235
234
236
235
237
236
class PandasArrowExtractor (BaseArrowExtractor [pd .DataFrame , pd .Series , pd .DataFrame ]):
0 commit comments