From 698305b6f213fff67f76072e997518b5d32c13d8 Mon Sep 17 00:00:00 2001 From: yuzelin Date: Tue, 20 Aug 2024 12:03:03 +0800 Subject: [PATCH] fix --- java_based_implementation/api_impl.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/java_based_implementation/api_impl.py b/java_based_implementation/api_impl.py index eeccca3..1d04300 100644 --- a/java_based_implementation/api_impl.py +++ b/java_based_implementation/api_impl.py @@ -16,6 +16,8 @@ # limitations under the License. ################################################################################ +import itertools + from java_based_implementation.java_gateway import get_gateway from java_based_implementation.util.java_utils import to_j_catalog_context from paimon_python_api import (catalog, table, read_builder, table_scan, split, table_read, @@ -116,25 +118,26 @@ def __init__(self, j_table_read, j_row_type): self._j_table_read = j_table_read self._j_bytes_reader = get_gateway().jvm.InvocationUtil.createBytesReader( j_table_read, j_row_type) + self._arrow_schema = None def create_reader(self, split: Split): self._j_bytes_reader.setSplit(split.to_j_split()) - batches = [] - schema = None - for arrow_bytes in self._bytes_generator(): - stream_reader = RecordBatchStreamReader(BufferReader(arrow_bytes)) - if schema is None: - schema = stream_reader.schema - batches.extend(batch for batch in stream_reader) - return RecordBatchReader.from_batches(schema, batches) - - def _bytes_generator(self) -> Iterator[bytes]: + batch_iterator = self._batch_generator() + # to init arrow schema + first_batch = next(batch_iterator) + batches = itertools.chain((b for b in [first_batch]), batch_iterator) + return RecordBatchReader.from_batches(self._arrow_schema, batches) + + def _batch_generator(self) -> Iterator[RecordBatch]: while True: next_bytes = self._j_bytes_reader.next() if next_bytes is None: break else: - yield next_bytes + stream_reader = RecordBatchStreamReader(BufferReader(next_bytes)) + if self._arrow_schema is None: + self._arrow_schema = stream_reader.schema + yield from stream_reader def close(self): self._j_bytes_reader.close()