16
16
# limitations under the License.
17
17
################################################################################
18
18
19
+ import itertools
20
+
19
21
from java_based_implementation .java_gateway import get_gateway
20
22
from java_based_implementation .util .java_utils import to_j_catalog_context
21
23
from paimon_python_api import (catalog , table , read_builder , table_scan , split , table_read ,
22
24
write_builder , table_write , commit_message , table_commit )
23
- from pyarrow import RecordBatchReader , RecordBatch
24
- from typing import List
25
+ from pyarrow import (RecordBatch , BufferOutputStream , RecordBatchStreamWriter ,
26
+ RecordBatchStreamReader , BufferReader , RecordBatchReader )
27
+ from typing import List , Iterator
25
28
26
29
27
30
class Catalog (catalog .Catalog ):
@@ -49,18 +52,19 @@ def __init__(self, j_table):
49
52
self ._j_table = j_table
50
53
51
54
def new_read_builder (self ) -> 'ReadBuilder' :
52
- j_read_builder = self . _j_table . newReadBuilder ( )
53
- return ReadBuilder (j_read_builder )
55
+ j_read_builder = get_gateway (). jvm . InvocationUtil . getReadBuilder ( self . _j_table )
56
+ return ReadBuilder (j_read_builder , self . _j_table . rowType () )
54
57
55
58
def new_batch_write_builder (self ) -> 'BatchWriteBuilder' :
56
- j_batch_write_builder = self . _j_table . newBatchWriteBuilder ( )
57
- return BatchWriteBuilder (j_batch_write_builder )
59
+ j_batch_write_builder = get_gateway (). jvm . InvocationUtil . getBatchWriteBuilder ( self . _j_table )
60
+ return BatchWriteBuilder (j_batch_write_builder , self . _j_table . rowType () )
58
61
59
62
60
63
class ReadBuilder (read_builder .ReadBuilder ):
61
64
62
- def __init__ (self , j_read_builder ):
65
+ def __init__ (self , j_read_builder , j_row_type ):
63
66
self ._j_read_builder = j_read_builder
67
+ self ._j_row_type = j_row_type
64
68
65
69
def with_projection (self , projection : List [List [int ]]) -> 'ReadBuilder' :
66
70
self ._j_read_builder .withProjection (projection )
@@ -75,8 +79,8 @@ def new_scan(self) -> 'TableScan':
75
79
return TableScan (j_table_scan )
76
80
77
81
def new_read (self ) -> 'TableRead' :
78
- # TODO
79
- pass
82
+ j_table_read = self . _j_read_builder . newRead ()
83
+ return TableRead ( j_table_read , self . _j_row_type )
80
84
81
85
82
86
class TableScan (table_scan .TableScan ):
@@ -110,23 +114,56 @@ def to_j_split(self):
110
114
111
115
class TableRead (table_read .TableRead ):
112
116
113
- def create_reader (self , split : Split ) -> RecordBatchReader :
114
- # TODO
115
- pass
117
+ def __init__ (self , j_table_read , j_row_type ):
118
+ self ._j_table_read = j_table_read
119
+ self ._j_bytes_reader = get_gateway ().jvm .InvocationUtil .createBytesReader (
120
+ j_table_read , j_row_type )
121
+ self ._arrow_schema = None
122
+
123
+ def create_reader (self , split : Split ):
124
+ self ._j_bytes_reader .setSplit (split .to_j_split ())
125
+ batch_iterator = self ._batch_generator ()
126
+ # to init arrow schema
127
+ try :
128
+ first_batch = next (batch_iterator )
129
+ except StopIteration :
130
+ return self ._empty_batch_reader ()
131
+
132
+ batches = itertools .chain ((b for b in [first_batch ]), batch_iterator )
133
+ return RecordBatchReader .from_batches (self ._arrow_schema , batches )
134
+
135
+ def _batch_generator (self ) -> Iterator [RecordBatch ]:
136
+ while True :
137
+ next_bytes = self ._j_bytes_reader .next ()
138
+ if next_bytes is None :
139
+ break
140
+ else :
141
+ stream_reader = RecordBatchStreamReader (BufferReader (next_bytes ))
142
+ if self ._arrow_schema is None :
143
+ self ._arrow_schema = stream_reader .schema
144
+ yield from stream_reader
145
+
146
+ def _empty_batch_reader (self ):
147
+ import pyarrow as pa
148
+ schema = pa .schema ([])
149
+ empty_batch = pa .RecordBatch .from_arrays ([], schema = schema )
150
+ empty_reader = pa .RecordBatchReader .from_batches (schema , [empty_batch ])
151
+ return empty_reader
116
152
117
153
118
154
class BatchWriteBuilder (write_builder .BatchWriteBuilder ):
119
155
120
- def __init__ (self , j_batch_write_builder ):
156
+ def __init__ (self , j_batch_write_builder , j_row_type ):
121
157
self ._j_batch_write_builder = j_batch_write_builder
158
+ self ._j_row_type = j_row_type
122
159
123
160
def with_overwrite (self , static_partition : dict ) -> 'BatchWriteBuilder' :
124
161
self ._j_batch_write_builder .withOverwrite (static_partition )
125
162
return self
126
163
127
164
def new_write (self ) -> 'BatchTableWrite' :
128
165
j_batch_table_write = self ._j_batch_write_builder .newWrite ()
129
- return BatchTableWrite (j_batch_table_write )
166
+ return BatchTableWrite (j_batch_table_write , self . _j_row_type )
130
167
131
168
def new_commit (self ) -> 'BatchTableCommit' :
132
169
j_batch_table_commit = self ._j_batch_write_builder .newCommit ()
@@ -135,17 +172,27 @@ def new_commit(self) -> 'BatchTableCommit':
135
172
136
173
class BatchTableWrite (table_write .BatchTableWrite ):
137
174
138
- def __init__ (self , j_batch_table_write ):
175
+ def __init__ (self , j_batch_table_write , j_row_type ):
139
176
self ._j_batch_table_write = j_batch_table_write
177
+ self ._j_bytes_writer = get_gateway ().jvm .InvocationUtil .createBytesWriter (
178
+ j_batch_table_write , j_row_type )
140
179
141
180
def write (self , record_batch : RecordBatch ):
142
- # TODO
143
- pass
181
+ stream = BufferOutputStream ()
182
+ with RecordBatchStreamWriter (stream , record_batch .schema ) as writer :
183
+ writer .write (record_batch )
184
+ writer .close ()
185
+ arrow_bytes = stream .getvalue ().to_pybytes ()
186
+ self ._j_bytes_writer .write (arrow_bytes )
144
187
145
188
def prepare_commit (self ) -> List ['CommitMessage' ]:
146
189
j_commit_messages = self ._j_batch_table_write .prepareCommit ()
147
190
return list (map (lambda cm : CommitMessage (cm ), j_commit_messages ))
148
191
192
+ def close (self ):
193
+ self ._j_batch_table_write .close ()
194
+ self ._j_bytes_writer .close ()
195
+
149
196
150
197
class CommitMessage (commit_message .CommitMessage ):
151
198
@@ -164,3 +211,6 @@ def __init__(self, j_batch_table_commit):
164
211
def commit (self , commit_messages : List [CommitMessage ]):
165
212
j_commit_messages = list (map (lambda cm : cm .to_j_commit_message (), commit_messages ))
166
213
self ._j_batch_table_commit .commit (j_commit_messages )
214
+
215
+ def close (self ):
216
+ self ._j_batch_table_commit .close ()
0 commit comments