15
15
# See the License for the specific language governing permissions and
16
16
# limitations under the License.
17
17
################################################################################
18
+ import os
18
19
19
20
# pypaimon.api implementation based on Java code & py4j lib
20
21
30
31
table_commit , Schema , predicate )
31
32
from typing import List , Iterator , Optional , Any , TYPE_CHECKING
32
33
34
+ from pypaimon .pynative .common .exception import PyNativeNotImplementedError
35
+ from pypaimon .pynative .common .predicate import PyNativePredicate
36
+ from pypaimon .pynative .common .row .internal_row import InternalRow
37
+ from pypaimon .pynative .util .reader_converter import ReaderConverter
38
+
33
39
if TYPE_CHECKING :
34
40
import ray
35
41
from duckdb .duckdb import DuckDBPyConnection
@@ -72,7 +78,12 @@ def __init__(self, j_table, catalog_options: dict):
72
78
73
79
def new_read_builder(self) -> 'ReadBuilder':
    """Create a ReadBuilder for this table via the Java gateway.

    Also extracts the table's primary keys so the pure-Python (pynative)
    reader can later use them; ``None`` when the table has no primary keys.
    """
    j_read_builder = get_gateway().jvm.InvocationUtil.getReadBuilder(self._j_table)
    j_primary_keys = self._j_table.primaryKeys()
    primary_keys = None if j_primary_keys.isEmpty() else [str(key) for key in j_primary_keys]
    return ReadBuilder(j_read_builder, self._j_table.rowType(),
                       self._catalog_options, primary_keys)
76
87
77
88
def new_batch_write_builder (self ) -> 'BatchWriteBuilder' :
78
89
java_utils .check_batch_write (self ._j_table )
@@ -82,16 +93,21 @@ def new_batch_write_builder(self) -> 'BatchWriteBuilder':
82
93
83
94
class ReadBuilder (read_builder .ReadBuilder ):
84
95
85
def __init__(self, j_read_builder, j_row_type, catalog_options: dict, primary_keys: List[str]):
    """Wrap a Java ReadBuilder plus the metadata the pynative reader needs.

    :param j_read_builder: py4j handle to the Java ReadBuilder.
    :param j_row_type: py4j handle to the table's Java RowType.
    :param catalog_options: catalog configuration passed through to reads.
    :param primary_keys: primary key field names, or None for an append table.
    """
    # Filter / projection start unset; they are recorded by with_filter /
    # with_projection so new_read can replay them to the pynative reader.
    self._predicate = None
    self._projection = None
    self._j_read_builder = j_read_builder
    self._j_row_type = j_row_type
    self._catalog_options = catalog_options
    self._primary_keys = primary_keys
89
103
90
104
def with_filter(self, predicate: 'Predicate'):
    """Push a filter down to the Java read and remember it for the pynative reader."""
    self._predicate = predicate
    j_predicate = predicate.to_j_predicate()
    self._j_read_builder.withFilter(j_predicate)
    return self
93
108
94
109
def with_projection (self , projection : List [str ]) -> 'ReadBuilder' :
110
+ self ._projection = projection
95
111
field_names = list (map (lambda field : field .name (), self ._j_row_type .getFields ()))
96
112
int_projection = list (map (lambda p : field_names .index (p ), projection ))
97
113
gateway = get_gateway ()
@@ -111,7 +127,8 @@ def new_scan(self) -> 'TableScan':
111
127
112
128
def new_read(self) -> 'TableRead':
    """Build a TableRead, forwarding the recorded predicate/projection/keys
    so the pynative reader can be constructed on demand."""
    j_table_read = self._j_read_builder.newRead().executeFilter()
    j_read_type = self._j_read_builder.readType()
    return TableRead(j_table_read, j_read_type, self._catalog_options,
                     self._predicate, self._projection, self._primary_keys)
115
132
116
133
def new_predicate_builder(self) -> 'PredicateBuilder':
    """Create a predicate builder bound to this table's row type."""
    return PredicateBuilder(self._j_row_type)
@@ -185,14 +202,29 @@ def file_paths(self) -> List[str]:
185
202
186
203
class TableRead (table_read .TableRead ):
187
204
188
def __init__(self, j_table_read, j_read_type, catalog_options, predicate, projection,
             primary_keys: List[str]):
    """Hold both the py4j read handles and the state for the pynative path.

    :param j_table_read: py4j handle to the Java TableRead.
    :param j_read_type: py4j handle to the (possibly projected) read RowType.
    :param catalog_options: catalog configuration (supplies max worker count).
    :param predicate: Python-side predicate recorded by the ReadBuilder, or None.
    :param projection: projected field names recorded by the ReadBuilder, or None.
    :param primary_keys: primary key field names, or None for an append table.
    """
    self._j_table_read = j_table_read
    self._j_read_type = j_read_type
    self._catalog_options = catalog_options

    # Replayed onto the pure-Python reader when one can be built.
    self._predicate = predicate
    self._projection = projection
    self._primary_keys = primary_keys

    self._arrow_schema = java_utils.to_arrow_schema(j_read_type)
    max_workers = TableRead._get_max_workers(catalog_options)
    self._j_bytes_reader = get_gateway().jvm.InvocationUtil.createParallelBytesReader(
        j_table_read, j_read_type, max_workers)
192
218
193
def to_arrow(self, splits: List['Split']) -> pa.Table:
    """Read the given splits into a single pyarrow Table.

    Prefers the pure-Python (pynative) reader when it is available;
    otherwise falls back to the py4j batch reader.

    :param splits: the splits produced by a TableScan to read.
    :return: a pyarrow Table with this read's Arrow schema.
    """
    # Check the forced-mode switch BEFORE building the native reader:
    # to_record_generator opens a pynative reader whose close() only runs
    # once the generator is consumed, so creating it and then discarding
    # it (as the old code did in forced-py4j mode) leaked the reader.
    # Set the env constants.IMPLEMENT_MODE to 'py4j' to force the py4j reader.
    if os.environ.get(constants.IMPLEMENT_MODE, '') != 'py4j':
        record_generator = self.to_record_generator(splits)
        if record_generator is not None:
            return TableRead._iterator_to_pyarrow_table(record_generator, self._arrow_schema)
    record_batch_reader = self.to_arrow_batch_reader(splits)
    return pa.Table.from_batches(record_batch_reader, schema=self._arrow_schema)
196
228
197
229
def to_arrow_batch_reader (self , splits ):
198
230
j_splits = list (map (lambda s : s .to_j_split (), splits ))
@@ -219,6 +251,60 @@ def to_ray(self, splits: List[Split]) -> "ray.data.dataset.Dataset":
219
251
220
252
return ray .data .from_arrow (self .to_arrow (splits ))
221
253
254
def to_record_generator(self, splits: List['Split']) -> Optional[Iterator[Any]]:
    """
    Returns a generator for iterating over records in the table.
    If pynative reader is not available, returns None.
    """
    try:
        # Convert the splits to Java and build (then convert) the Java reader.
        j_splits = [split.to_j_split() for split in splits]
        j_reader = get_gateway().jvm.InvocationUtil.createReader(self._j_table_read, j_splits)
        converter = ReaderConverter(self._predicate, self._projection, self._primary_keys)
        pynative_reader = converter.convert_java_reader(j_reader)
    except PyNativeNotImplementedError as e:
        # Some operator in the read path has no native implementation yet;
        # signal the caller to fall back to the py4j reader.
        print(f"Generating pynative reader failed, will use py4j reader instead, "
              f"error message: {str(e)}")
        return None

    def _record_generator():
        # Drain batches one at a time; the reader is closed even if the
        # consumer abandons the generator early.
        try:
            while True:
                batch = pynative_reader.read_batch()
                if batch is None:
                    break
                record = batch.next()
                while record is not None:
                    yield record
                    record = batch.next()
                batch.release_batch()
        finally:
            pynative_reader.close()

    return _record_generator()
284
+
285
@staticmethod
def _iterator_to_pyarrow_table(record_generator, arrow_schema):
    """
    Converts a record generator into a pyarrow Table using the provided Arrow schema.
    """
    batch_size = 1024  # Can be adjusted according to needs for batch size
    field_names = arrow_schema.names
    batches = []
    pending = []

    for record in record_generator:
        # Map each schema field name to the record value at the same position.
        row = {name: record.get_field(idx) for idx, name in enumerate(field_names)}
        pending.append(row)
        if len(pending) >= batch_size:
            batches.append(pa.RecordBatch.from_pylist(pending, schema=arrow_schema))
            pending = []

    # Flush the final partial batch, if any.
    if pending:
        batches.append(pa.RecordBatch.from_pylist(pending, schema=arrow_schema))

    return pa.Table.from_batches(batches, schema=arrow_schema)
307
+
222
308
@staticmethod
223
309
def _get_max_workers (catalog_options ):
224
310
# default is sequential
@@ -317,12 +403,16 @@ def close(self):
317
403
318
404
class Predicate(predicate.Predicate):
    """A predicate carried in two forms: a serialized Java predicate for the
    py4j path and a PyNativePredicate for pure-Python row evaluation."""

    def __init__(self, py_predicate: PyNativePredicate, j_predicate_bytes):
        self.py_predicate = py_predicate
        self._j_predicate_bytes = j_predicate_bytes

    def to_j_predicate(self):
        # Rebuild the Java predicate object from its serialized bytes.
        return deserialize_java_object(self._j_predicate_bytes)

    def test(self, record: InternalRow) -> bool:
        """Evaluate this predicate against a single row on the Python side."""
        return self.py_predicate.test(record)
326
416
327
417
class PredicateBuilder (predicate .PredicateBuilder ):
328
418
@@ -350,7 +440,8 @@ def _build(self, method: str, field: str, literals: Optional[List[Any]] = None):
350
440
index ,
351
441
literals
352
442
)
353
- return Predicate (serialize_java_object (j_predicate ))
443
+ return Predicate (PyNativePredicate (method , index , field , literals ),
444
+ serialize_java_object (j_predicate ))
354
445
355
446
def equal(self, field: str, literal: Any) -> Predicate:
    """Build a ``field == literal`` predicate."""
    return self._build('equal', field, [literal])
@@ -396,11 +487,13 @@ def between(self, field: str, included_lower_bound: Any, included_upper_bound: A
396
487
return self ._build ('between' , field , [included_lower_bound , included_upper_bound ])
397
488
398
489
def and_predicates(self, predicates: List[Predicate]) -> Predicate:
    """Conjunction of the given predicates (all must hold)."""
    j_predicates = [p.to_j_predicate() for p in predicates]
    j_combined = get_gateway().jvm.PredicationUtil.buildAnd(j_predicates)
    # The native form keeps the original Python predicates as children.
    return Predicate(PyNativePredicate('and', None, None, predicates),
                     serialize_java_object(j_combined))
402
494
403
495
def or_predicates(self, predicates: List[Predicate]) -> Predicate:
    """Disjunction of the given predicates (at least one must hold)."""
    j_predicates = [p.to_j_predicate() for p in predicates]
    j_combined = get_gateway().jvm.PredicationUtil.buildOr(j_predicates)
    # The native form keeps the original Python predicates as children.
    return Predicate(PyNativePredicate('or', None, None, predicates),
                     serialize_java_object(j_combined))
0 commit comments