Fix: BytesReader flush ArrowFormatWriter and pass schema #9

Merged · 1 commit · Sep 3, 2024
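Summary (reconstructed from the diff): this PR fixes the Arrow handoff between the Java `BytesReader` and the Python API in two ways. `BytesReader` now calls `arrowFormatWriter.flush()` before serializing a batch, so rows buffered in the `ArrowFormatWriter` actually land in the `VectorSchemaRoot`, and it exposes a new `serializeSchema()` method. With the schema available up front, `create_reader` no longer has to pull a first batch just to learn the schema, and the `_empty_batch_reader` workaround can be dropped: an empty split now simply produces a reader with the real schema and zero batches.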
java_based_implementation/api_impl.py (8 additions, 19 deletions)
@@ -16,8 +16,6 @@
 # limitations under the License.
 ################################################################################
 
-import itertools
-
 from java_based_implementation.java_gateway import get_gateway
 from java_based_implementation.util.java_utils import to_j_catalog_context, check_batch_write
 from paimon_python_api import (catalog, table, read_builder, table_scan, split, table_read,
@@ -123,15 +121,15 @@ def __init__(self, j_table_read, j_row_type):
 
     def create_reader(self, split: Split):
         self._j_bytes_reader.setSplit(split.to_j_split())
-        batch_iterator = self._batch_generator()
-        # to init arrow schema
-        try:
-            first_batch = next(batch_iterator)
-        except StopIteration:
-            return self._empty_batch_reader()
+        # get schema
+        if self._arrow_schema is None:
+            schema_bytes = self._j_bytes_reader.serializeSchema()
+            schema_reader = RecordBatchStreamReader(BufferReader(schema_bytes))
+            self._arrow_schema = schema_reader.schema
+            schema_reader.close()
 
-        batches = itertools.chain((b for b in [first_batch]), batch_iterator)
-        return RecordBatchReader.from_batches(self._arrow_schema, batches)
+        batch_iterator = self._batch_generator()
+        return RecordBatchReader.from_batches(self._arrow_schema, batch_iterator)
 
     def _batch_generator(self) -> Iterator[RecordBatch]:
         while True:
@@ -140,17 +138,8 @@ def _batch_generator(self) -> Iterator[RecordBatch]:
                 break
             else:
                 stream_reader = RecordBatchStreamReader(BufferReader(next_bytes))
-                if self._arrow_schema is None:
-                    self._arrow_schema = stream_reader.schema
                 yield from stream_reader
 
-    def _empty_batch_reader(self):
-        import pyarrow as pa
-        schema = pa.schema([])
-        empty_batch = pa.RecordBatch.from_arrays([], schema=schema)
-        empty_reader = pa.RecordBatchReader.from_batches(schema, [empty_batch])
-        return empty_reader
-
 
 class BatchWriteBuilder(write_builder.BatchWriteBuilder):
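For illustration, here is a minimal sketch of the schema handshake that `create_reader` now performs, using plain pyarrow to stand in for the Java side (in the real code the bytes come from `BytesReader.serializeSchema()`):

```python
import pyarrow as pa
from pyarrow import BufferReader, RecordBatchStreamReader

schema = pa.schema([('f0', pa.int32()), ('f1', pa.string())])

# Java side (simulated): an IPC stream that carries only the schema message.
sink = pa.BufferOutputStream()
with pa.ipc.new_stream(sink, schema):
    pass  # no batches written; the stream still opens with the schema
schema_bytes = sink.getvalue()

# Python side: recover the schema exactly as create_reader now does.
schema_reader = RecordBatchStreamReader(BufferReader(schema_bytes))
assert schema_reader.schema.equals(schema)
schema_reader.close()
```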
BytesReader.java
@@ -47,7 +47,7 @@ public class BytesReader {
 
     public BytesReader(TableRead tableRead, RowType rowType) {
         this.tableRead = tableRead;
-        this.arrowFormatWriter = new ArrowFormatWriter(rowType, DEFAULT_WRITE_BATCH_SIZE);
+        this.arrowFormatWriter = new ArrowFormatWriter(rowType, DEFAULT_WRITE_BATCH_SIZE, true);
    }
 
    public void setSplit(Split split) throws IOException {
@@ -56,6 +56,13 @@ public void setSplit(Split split) throws IOException {
         nextRow();
     }
 
+    public byte[] serializeSchema() {
+        VectorSchemaRoot vsr = arrowFormatWriter.getVectorSchemaRoot();
+        ByteArrayOutputStream out = new ByteArrayOutputStream();
+        ArrowUtils.serializeToIpc(vsr, out);
+        return out.toByteArray();
+    }
+
     @Nullable
     public byte[] next() throws Exception {
         if (nextRow == null) {
@@ -68,6 +75,7 @@ public byte[] next() throws Exception {
             rowCount++;
         }
 
+        arrowFormatWriter.flush();
         VectorSchemaRoot vsr = arrowFormatWriter.getVectorSchemaRoot();
         vsr.setRowCount(rowCount);
         ByteArrayOutputStream out = new ByteArrayOutputStream();
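To make the Java contract concrete from the Python side, here is a hedged mock (`FakeBytesReader` is illustrative, not part of the codebase): `serializeSchema()` returns IPC bytes carrying the schema, and `next()` returns one complete IPC stream per batch, or `None` (Java `null`) when the split is exhausted.

```python
import pyarrow as pa
from pyarrow import BufferReader, RecordBatchStreamReader

class FakeBytesReader:
    """Illustrative stand-in for the Java BytesReader; method names mirror it."""

    def __init__(self, schema, batches):
        self._schema = schema
        self._batches = list(batches)

    def serializeSchema(self):
        # Assumption: serializing the still-empty VectorSchemaRoot yields an
        # IPC stream whose usable payload is the schema message.
        sink = pa.BufferOutputStream()
        with pa.ipc.new_stream(sink, self._schema):
            pass
        return sink.getvalue()

    def next(self):
        # The Java side now flushes the ArrowFormatWriter before serializing,
        # so every buffered row is included; None marks exhaustion.
        if not self._batches:
            return None
        sink = pa.BufferOutputStream()
        with pa.ipc.new_stream(sink, self._schema) as writer:
            writer.write_batch(self._batches.pop(0))
        return sink.getvalue()

# Usage, mirroring create_reader: schema first, then drain the batches.
schema = pa.schema([('f0', pa.int32())])
batch = pa.record_batch([pa.array([1, 2], type=pa.int32())], schema=schema)
reader = FakeBytesReader(schema, [batch])
assert RecordBatchStreamReader(BufferReader(reader.serializeSchema())).schema.equals(schema)
while (payload := reader.next()) is not None:
    for b in RecordBatchStreamReader(BufferReader(payload)):
        print(b.num_rows)  # 2
```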
java_based_implementation/tests/test_write_and_read.py (3 additions, 4 deletions)
@@ -92,8 +92,7 @@ def testReadEmptyPkTable(self):
             for split in splits
             for batch in table_read.create_reader(split)
         ]
-        result = pd.concat(data_frames)
-        self.assertEqual(result.shape, (0, 0))
+        self.assertEqual(len(data_frames), 0)
 
     def testWriteReadAppendTable(self):
         create_simple_table(self.warehouse, 'default', 'simple_append_table', False)
@@ -135,8 +134,8 @@ def testWriteReadAppendTable(self):
         ]
         result = pd.concat(data_frames)
 
-        # check data
-        pd.testing.assert_frame_equal(result, df)
+        # check data (ignore index)
+        pd.testing.assert_frame_equal(result.reset_index(drop=True), df.reset_index(drop=True))
 
     def testWriteWrongSchema(self):
         create_simple_table(self.warehouse, 'default', 'test_wrong_schema', False)
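The `reset_index` change in testWriteReadAppendTable is needed because `pd.concat` preserves each per-batch frame's own 0-based index, so the concatenated result's index differs from the original `df` even when the data matches. A small illustration:

```python
import pandas as pd

a = pd.DataFrame({'x': [1, 2]})
b = pd.DataFrame({'x': [3]})
result = pd.concat([a, b])                 # index: 0, 1, 0
expected = pd.DataFrame({'x': [1, 2, 3]})  # index: 0, 1, 2

# assert_frame_equal(result, expected) would fail on the index alone;
# dropping the index on both sides compares just the data.
pd.testing.assert_frame_equal(
    result.reset_index(drop=True), expected.reset_index(drop=True))
```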