From 953f30fa9f696c8f077559c9c05f5c904e83f62b Mon Sep 17 00:00:00 2001
From: yuzelin <33053040+yuzelin@users.noreply.github.com>
Date: Wed, 9 Oct 2024 17:12:36 +0800
Subject: [PATCH] Refactor reader and writer (#16)

---
 paimon_python_api/table_read.py               | 13 +++-
 paimon_python_api/table_write.py              | 13 +++-
 paimon_python_java/java_gateway.py            |  2 +-
 .../org/apache/paimon/python/BytesWriter.java | 40 ------------
 .../paimon/python/ParallelBytesReader.java    | 20 +-----
 .../org/apache/paimon/python/SchemaUtil.java  | 41 ++++++++++++
 paimon_python_java/pypaimon.py                | 63 +++++++++++++------
 .../tests/test_write_and_read.py              | 38 ++---------
 8 files changed, 114 insertions(+), 116 deletions(-)
 create mode 100644 paimon_python_java/paimon-python-java-bridge/src/main/java/org/apache/paimon/python/SchemaUtil.java

diff --git a/paimon_python_api/table_read.py b/paimon_python_api/table_read.py
index 4fa12c1..24095b4 100644
--- a/paimon_python_api/table_read.py
+++ b/paimon_python_api/table_read.py
@@ -16,6 +16,7 @@
 # limitations under the License.
 #################################################################################

+import pandas as pd
 import pyarrow as pa

 from abc import ABC, abstractmethod
@@ -27,5 +28,13 @@ class TableRead(ABC):
     """To read data from data splits."""

     @abstractmethod
-    def create_reader(self, splits: List[Split]) -> pa.RecordBatchReader:
-        """Return a reader containing batches of pyarrow format."""
+    def to_arrow(self, splits: List[Split]) -> pa.Table:
+        """Read data from splits and convert it to pyarrow.Table format."""
+
+    @abstractmethod
+    def to_arrow_batch_reader(self, splits: List[Split]) -> pa.RecordBatchReader:
+        """Read data from splits and convert it to pyarrow.RecordBatchReader format."""
+
+    @abstractmethod
+    def to_pandas(self, splits: List[Split]) -> pd.DataFrame:
+        """Read data from splits and convert it to pandas.DataFrame format."""
diff --git a/paimon_python_api/table_write.py b/paimon_python_api/table_write.py
index 167ceeb..d1d39a7 100644
--- a/paimon_python_api/table_write.py
+++ b/paimon_python_api/table_write.py
@@ -16,6 +16,7 @@
 # limitations under the License.
 #################################################################################

+import pandas as pd
 import pyarrow as pa

 from abc import ABC, abstractmethod
@@ -27,8 +28,16 @@ class BatchTableWrite(ABC):
     """A table write for batch processing. Recommended for one-time committing."""

     @abstractmethod
-    def write(self, record_batch: pa.RecordBatch):
-        """ Write a batch to the writer. */"""
*/""" + def write_arrow(self, table: pa.Table): + """ Write an arrow table to the writer.""" + + @abstractmethod + def write_arrow_batch(self, record_batch: pa.RecordBatch): + """ Write an arrow record batch to the writer.""" + + @abstractmethod + def write_pandas(self, dataframe: pd.DataFrame): + """ Write a pandas dataframe to the writer.""" @abstractmethod def prepare_commit(self) -> List[CommitMessage]: diff --git a/paimon_python_java/java_gateway.py b/paimon_python_java/java_gateway.py index 372ae4a..f2b1621 100644 --- a/paimon_python_java/java_gateway.py +++ b/paimon_python_java/java_gateway.py @@ -107,7 +107,7 @@ def import_paimon_view(gateway): java_import(gateway.jvm, "org.apache.paimon.catalog.*") java_import(gateway.jvm, "org.apache.paimon.schema.Schema*") java_import(gateway.jvm, 'org.apache.paimon.types.*') - java_import(gateway.jvm, 'org.apache.paimon.python.InvocationUtil') + java_import(gateway.jvm, 'org.apache.paimon.python.*') java_import(gateway.jvm, "org.apache.paimon.data.*") diff --git a/paimon_python_java/paimon-python-java-bridge/src/main/java/org/apache/paimon/python/BytesWriter.java b/paimon_python_java/paimon-python-java-bridge/src/main/java/org/apache/paimon/python/BytesWriter.java index 7ad74c1..7cf6267 100644 --- a/paimon_python_java/paimon-python-java-bridge/src/main/java/org/apache/paimon/python/BytesWriter.java +++ b/paimon_python_java/paimon-python-java-bridge/src/main/java/org/apache/paimon/python/BytesWriter.java @@ -18,7 +18,6 @@ package org.apache.paimon.python; -import org.apache.paimon.arrow.ArrowUtils; import org.apache.paimon.arrow.reader.ArrowBatchReader; import org.apache.paimon.data.InternalRow; import org.apache.paimon.table.sink.TableWrite; @@ -28,12 +27,8 @@ import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.ipc.ArrowStreamReader; -import org.apache.arrow.vector.types.pojo.Field; import java.io.ByteArrayInputStream; -import java.util.List; -import java.util.Objects; -import java.util.stream.Collectors; /** Write Arrow bytes to Paimon. 
 public class BytesWriter {
@@ -41,30 +36,17 @@ public class BytesWriter {
     private final TableWrite tableWrite;
     private final ArrowBatchReader arrowBatchReader;
     private final BufferAllocator allocator;
-    private final List<Field> arrowFields;

     public BytesWriter(TableWrite tableWrite, RowType rowType) {
         this.tableWrite = tableWrite;
         this.arrowBatchReader = new ArrowBatchReader(rowType);
         this.allocator = new RootAllocator();
-        arrowFields =
-                rowType.getFields().stream()
-                        .map(f -> ArrowUtils.toArrowField(f.name(), f.type()))
-                        .collect(Collectors.toList());
     }

     public void write(byte[] bytes) throws Exception {
         ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
         ArrowStreamReader arrowStreamReader = new ArrowStreamReader(bais, allocator);
         VectorSchemaRoot vsr = arrowStreamReader.getVectorSchemaRoot();
-        if (!checkSchema(arrowFields, vsr.getSchema().getFields())) {
-            throw new RuntimeException(
-                    String.format(
-                            "Input schema isn't consistent with table schema.\n"
-                                    + "\tTable schema is: %s\n"
-                                    + "\tInput schema is: %s",
-                            arrowFields, vsr.getSchema().getFields()));
-        }

         while (arrowStreamReader.loadNextBatch()) {
             Iterable<InternalRow> rows = arrowBatchReader.readBatch(vsr);
@@ -78,26 +60,4 @@ public void write(byte[] bytes) throws Exception {
     public void close() {
         allocator.close();
     }
-
-    private boolean checkSchema(List<Field> expectedFields, List<Field> actualFields) {
-        if (expectedFields.size() != actualFields.size()) {
-            return false;
-        }
-
-        for (int i = 0; i < expectedFields.size(); i++) {
-            Field expectedField = expectedFields.get(i);
-            Field actualField = actualFields.get(i);
-            if (!checkField(expectedField, actualField)
-                    || !checkSchema(expectedField.getChildren(), actualField.getChildren())) {
-                return false;
-            }
-        }
-
-        return true;
-    }
-
-    private boolean checkField(Field expected, Field actual) {
-        return Objects.equals(expected.getName(), actual.getName())
-                && Objects.equals(expected.getType(), actual.getType());
-    }
 }
diff --git a/paimon_python_java/paimon-python-java-bridge/src/main/java/org/apache/paimon/python/ParallelBytesReader.java b/paimon_python_java/paimon-python-java-bridge/src/main/java/org/apache/paimon/python/ParallelBytesReader.java
index 9f1b390..4c6ff00 100644
--- a/paimon_python_java/paimon-python-java-bridge/src/main/java/org/apache/paimon/python/ParallelBytesReader.java
+++ b/paimon_python_java/paimon-python-java-bridge/src/main/java/org/apache/paimon/python/ParallelBytesReader.java
@@ -18,7 +18,6 @@

 package org.apache.paimon.python;

-import org.apache.paimon.arrow.ArrowUtils;
 import org.apache.paimon.arrow.vector.ArrowFormatWriter;
 import org.apache.paimon.data.InternalRow;
 import org.apache.paimon.reader.RecordReader;
@@ -30,11 +29,8 @@

 import org.apache.paimon.shade.guava30.com.google.common.collect.Iterators;

-import org.apache.arrow.vector.VectorSchemaRoot;
-
 import javax.annotation.Nullable;

-import java.io.ByteArrayOutputStream;
 import java.io.IOException;
 import java.util.ArrayDeque;
 import java.util.ArrayList;
@@ -77,15 +73,6 @@ public void setSplits(List<Split> splits) {
         bytesIterator = randomlyExecute(getExecutor(), makeProcessor(), splits);
     }

-    public byte[] serializeSchema() {
-        ArrowFormatWriter arrowFormatWriter = newWriter();
-        VectorSchemaRoot vsr = arrowFormatWriter.getVectorSchemaRoot();
-        ByteArrayOutputStream out = new ByteArrayOutputStream();
-        ArrowUtils.serializeToIpc(vsr, out);
-        arrowFormatWriter.close();
-        return out.toByteArray();
-    }
-
     @Nullable
     public byte[] next() {
         if (bytesIterator.hasNext()) {
@@ -110,7 +97,8 @@ private Function<Split, Iterator<byte[]>> makeProcessor() {
             RecordReaderIterator<InternalRow> iterator = new RecordReaderIterator<>(recordReader);
             iterators.add(iterator);
-            ArrowFormatWriter arrowFormatWriter = newWriter();
+            ArrowFormatWriter arrowFormatWriter =
+                    new ArrowFormatWriter(rowType, DEFAULT_WRITE_BATCH_SIZE, true);
             arrowFormatWriters.add(arrowFormatWriter);
             return new RecordBytesIterator(iterator, arrowFormatWriter);
         } catch (IOException e) {
@@ -164,8 +152,4 @@ private void closeResources() {
         }
         arrowFormatWriters.clear();
     }
-
-    private ArrowFormatWriter newWriter() {
-        return new ArrowFormatWriter(rowType, DEFAULT_WRITE_BATCH_SIZE, true);
-    }
 }
diff --git a/paimon_python_java/paimon-python-java-bridge/src/main/java/org/apache/paimon/python/SchemaUtil.java b/paimon_python_java/paimon-python-java-bridge/src/main/java/org/apache/paimon/python/SchemaUtil.java
new file mode 100644
index 0000000..64f6ce0
--- /dev/null
+++ b/paimon_python_java/paimon-python-java-bridge/src/main/java/org/apache/paimon/python/SchemaUtil.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.python;
+
+import org.apache.paimon.arrow.ArrowUtils;
+import org.apache.paimon.types.RowType;
+
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.VectorSchemaRoot;
+
+import java.io.ByteArrayOutputStream;
+
+/** Util to get arrow schema from row type. */
+public class SchemaUtil {
+    public static byte[] getArrowSchema(RowType rowType) {
+        BufferAllocator allocator = new RootAllocator();
+        VectorSchemaRoot emptyRoot = ArrowUtils.createVectorSchemaRoot(rowType, allocator, true);
+        ByteArrayOutputStream out = new ByteArrayOutputStream();
+        ArrowUtils.serializeToIpc(emptyRoot, out);
+        emptyRoot.close();
+        allocator.close();
+        return out.toByteArray();
+    }
+}
diff --git a/paimon_python_java/pypaimon.py b/paimon_python_java/pypaimon.py
index 263491e..a007add 100644
--- a/paimon_python_java/pypaimon.py
+++ b/paimon_python_java/pypaimon.py
@@ -16,6 +16,7 @@
 # limitations under the License.
 ################################################################################

+import pandas as pd
 import pyarrow as pa

 from paimon_python_java.java_gateway import get_gateway
@@ -59,23 +60,30 @@ class Table(table.Table):
     def __init__(self, j_table, catalog_options: dict):
         self._j_table = j_table
         self._catalog_options = catalog_options
+        # init arrow schema
+        schema_bytes = get_gateway().jvm.SchemaUtil.getArrowSchema(j_table.rowType())
+        schema_reader = pa.RecordBatchStreamReader(pa.BufferReader(schema_bytes))
+        self._arrow_schema = schema_reader.schema
+        schema_reader.close()

     def new_read_builder(self) -> 'ReadBuilder':
         j_read_builder = get_gateway().jvm.InvocationUtil.getReadBuilder(self._j_table)
-        return ReadBuilder(j_read_builder, self._j_table.rowType(), self._catalog_options)
+        return ReadBuilder(
+            j_read_builder, self._j_table.rowType(), self._catalog_options, self._arrow_schema)

     def new_batch_write_builder(self) -> 'BatchWriteBuilder':
         java_utils.check_batch_write(self._j_table)
         j_batch_write_builder = get_gateway().jvm.InvocationUtil.getBatchWriteBuilder(self._j_table)
-        return BatchWriteBuilder(j_batch_write_builder, self._j_table.rowType())
+        return BatchWriteBuilder(j_batch_write_builder, self._j_table.rowType(), self._arrow_schema)


 class ReadBuilder(read_builder.ReadBuilder):

-    def __init__(self, j_read_builder, j_row_type, catalog_options: dict):
+    def __init__(self, j_read_builder, j_row_type, catalog_options: dict, arrow_schema: pa.Schema):
         self._j_read_builder = j_read_builder
         self._j_row_type = j_row_type
         self._catalog_options = catalog_options
+        self._arrow_schema = arrow_schema

     def with_projection(self, projection: List[List[int]]) -> 'ReadBuilder':
         self._j_read_builder.withProjection(projection)
@@ -91,7 +99,7 @@ def new_scan(self) -> 'TableScan':

     def new_read(self) -> 'TableRead':
         j_table_read = self._j_read_builder.newRead()
-        return TableRead(j_table_read, self._j_row_type, self._catalog_options)
+        return TableRead(j_table_read, self._j_row_type, self._catalog_options, self._arrow_schema)


 class TableScan(table_scan.TableScan):
@@ -125,20 +133,27 @@ def to_j_split(self):

 class TableRead(table_read.TableRead):

-    def __init__(self, j_table_read, j_row_type, catalog_options):
+    def __init__(self, j_table_read, j_row_type, catalog_options, arrow_schema):
         self._j_table_read = j_table_read
         self._j_row_type = j_row_type
         self._catalog_options = catalog_options
         self._j_bytes_reader = None
-        self._arrow_schema = None
+        self._arrow_schema = arrow_schema

-    def create_reader(self, splits):
+    def to_arrow(self, splits):
+        record_batch_reader = self.to_arrow_batch_reader(splits)
+        return pa.Table.from_batches(record_batch_reader, schema=self._arrow_schema)
+
+    def to_arrow_batch_reader(self, splits):
         self._init()
         j_splits = list(map(lambda s: s.to_j_split(), splits))
         self._j_bytes_reader.setSplits(j_splits)
         batch_iterator = self._batch_generator()
         return pa.RecordBatchReader.from_batches(self._arrow_schema, batch_iterator)

+    def to_pandas(self, splits: List[Split]) -> pd.DataFrame:
+        return self.to_arrow(splits).to_pandas()
+
     def _init(self):
         if self._j_bytes_reader is None:
             # get thread num
@@ -153,12 +168,6 @@ def _init(self):
             self._j_bytes_reader = get_gateway().jvm.InvocationUtil.createParallelBytesReader(
                 self._j_table_read, self._j_row_type, max_workers)

-        if self._arrow_schema is None:
-            schema_bytes = self._j_bytes_reader.serializeSchema()
-            schema_reader = pa.RecordBatchStreamReader(pa.BufferReader(schema_bytes))
-            self._arrow_schema = schema_reader.schema
-            schema_reader.close()
-
     def _batch_generator(self) -> Iterator[pa.RecordBatch]:
         while True:
             next_bytes = self._j_bytes_reader.next()
@@ -171,9 +180,10 @@ def _batch_generator(self) -> Iterator[pa.RecordBatch]:

 class BatchWriteBuilder(write_builder.BatchWriteBuilder):

-    def __init__(self, j_batch_write_builder, j_row_type):
+    def __init__(self, j_batch_write_builder, j_row_type, arrow_schema: pa.Schema):
         self._j_batch_write_builder = j_batch_write_builder
         self._j_row_type = j_row_type
+        self._arrow_schema = arrow_schema

     def with_overwrite(self, static_partition: dict) -> 'BatchWriteBuilder':
         self._j_batch_write_builder.withOverwrite(static_partition)
@@ -181,7 +191,7 @@ def with_overwrite(self, static_partition: dict) -> 'BatchWriteBuilder':

     def new_write(self) -> 'BatchTableWrite':
         j_batch_table_write = self._j_batch_write_builder.newWrite()
-        return BatchTableWrite(j_batch_table_write, self._j_row_type)
+        return BatchTableWrite(j_batch_table_write, self._j_row_type, self._arrow_schema)

     def new_commit(self) -> 'BatchTableCommit':
         j_batch_table_commit = self._j_batch_write_builder.newCommit()
@@ -190,19 +200,32 @@ def new_commit(self) -> 'BatchTableCommit':

 class BatchTableWrite(table_write.BatchTableWrite):

-    def __init__(self, j_batch_table_write, j_row_type):
+    def __init__(self, j_batch_table_write, j_row_type, arrow_schema: pa.Schema):
         self._j_batch_table_write = j_batch_table_write
         self._j_bytes_writer = get_gateway().jvm.InvocationUtil.createBytesWriter(
             j_batch_table_write, j_row_type)
-
-    def write(self, record_batch: pa.RecordBatch):
+        self._arrow_schema = arrow_schema
+
+    def write_arrow(self, table):
+        for record_batch in table.to_reader():
+            # TODO: can we use a reusable stream?
+            stream = pa.BufferOutputStream()
+            with pa.RecordBatchStreamWriter(stream, self._arrow_schema) as writer:
+                writer.write(record_batch)
+            arrow_bytes = stream.getvalue().to_pybytes()
+            self._j_bytes_writer.write(arrow_bytes)
+
+    def write_arrow_batch(self, record_batch):
         stream = pa.BufferOutputStream()
-        with pa.RecordBatchStreamWriter(stream, record_batch.schema) as writer:
+        with pa.RecordBatchStreamWriter(stream, self._arrow_schema) as writer:
             writer.write(record_batch)
-            writer.close()
         arrow_bytes = stream.getvalue().to_pybytes()
         self._j_bytes_writer.write(arrow_bytes)

+    def write_pandas(self, dataframe: pd.DataFrame):
+        record_batch = pa.RecordBatch.from_pandas(dataframe, schema=self._arrow_schema)
+        self.write_arrow_batch(record_batch)
+
     def prepare_commit(self) -> List['CommitMessage']:
         j_commit_messages = self._j_batch_table_write.prepareCommit()
         return list(map(lambda cm: CommitMessage(cm), j_commit_messages))
diff --git a/paimon_python_java/tests/test_write_and_read.py b/paimon_python_java/tests/test_write_and_read.py
index 593fe7e..06cadd9 100644
--- a/paimon_python_java/tests/test_write_and_read.py
+++ b/paimon_python_java/tests/test_write_and_read.py
@@ -28,7 +28,6 @@
 from paimon_python_java.java_gateway import get_gateway
 from paimon_python_java.tests import utils
 from paimon_python_java.util import java_utils
-from py4j.protocol import Py4JJavaError
 from setup_utils import java_setuputils


@@ -111,7 +110,7 @@ def testReadEmptyPkTable(self):
         data_frames = [
             batch.to_pandas()
             for split in splits
-            for batch in table_read.create_reader([split])
+            for batch in table_read.to_arrow_batch_reader([split])
         ]

         self.assertEqual(len(data_frames), 0)
@@ -133,7 +132,7 @@ def testWriteReadAppendTable(self):
         table_write = write_builder.new_write()
         table_commit = write_builder.new_commit()

-        table_write.write(record_batch)
+        table_write.write_arrow_batch(record_batch)

         commit_messages = table_write.prepare_commit()
         table_commit.commit(commit_messages)
@@ -149,7 +148,7 @@ def testWriteReadAppendTable(self):
         data_frames = [
             batch.to_pandas()
             for split in splits
-            for batch in table_read.create_reader([split])
+            for batch in table_read.to_arrow_batch_reader([split])
         ]

         result = pd.concat(data_frames)
@@ -159,33 +158,6 @@ def testWriteReadAppendTable(self):
         pd.testing.assert_frame_equal(
             result.reset_index(drop=True), expected.reset_index(drop=True))

-    def testWriteWrongSchema(self):
-        schema = Schema(self.simple_pa_schema)
-        self.catalog.create_table('default.test_wrong_schema', schema, False)
-        table = self.catalog.get_table('default.test_wrong_schema')
-
-        data = {
-            'f0': [1, 2, 3],
-            'f1': ['a', 'b', 'c'],
-        }
-        df = pd.DataFrame(data)
-        schema = pa.schema([
-            ('f0', pa.int64()),
-            ('f1', pa.string())
-        ])
-        record_batch = pa.RecordBatch.from_pandas(df, schema=schema)
-
-        write_builder = table.new_batch_write_builder()
-        table_write = write_builder.new_write()
-
-        with self.assertRaises(Py4JJavaError) as e:
-            table_write.write(record_batch)
-        self.assertEqual(
-            str(e.exception.java_exception),
-            '''java.lang.RuntimeException: Input schema isn't consistent with table schema.
-\tTable schema is: [f0: Int(32, true), f1: Utf8]
-\tInput schema is: [f0: Int(64, true), f1: Utf8]''')
-
     def testCannotWriteDynamicBucketTable(self):
         schema = Schema(self.simple_pa_schema, primary_keys=['f0'])
         self.catalog.create_table('default.test_dynamic_bucket', schema, False)
@@ -225,7 +197,7 @@ def testParallelRead(self):
         table_write = write_builder.new_write()
         table_commit = write_builder.new_commit()

-        table_write.write(record_batch)
+        table_write.write_arrow_batch(record_batch)

         commit_messages = table_write.prepare_commit()
         table_commit.commit(commit_messages)
@@ -237,7 +209,7 @@ def testParallelRead(self):

         data_frames = [
             batch.to_pandas()
-            for batch in table_read.create_reader(splits)
+            for batch in table_read.to_arrow_batch_reader(splits)
         ]

         result = pd.concat(data_frames)
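
Usage note (illustrative sketch, not part of the patch): the snippet below shows how the refactored write/read API introduced above is meant to be used end to end. The warehouse path, table name, and schema are assumptions made up for the example, and the catalog/scan bootstrap calls (Catalog.create, create_database, new_scan().plan().splits()) follow the project's existing Python API rather than anything changed in this diff.

    import pandas as pd
    import pyarrow as pa
    from paimon_python_api import Schema
    from paimon_python_java import Catalog

    # Hypothetical setup: a local filesystem warehouse and a simple two-column append table.
    catalog = Catalog.create({'warehouse': '/tmp/warehouse'})
    catalog.create_database('default', False)
    pa_schema = pa.schema([('f0', pa.int32()), ('f1', pa.string())])
    catalog.create_table('default.demo', Schema(pa_schema), False)
    table = catalog.get_table('default.demo')

    # Batch write: pandas/Arrow input is serialized to Arrow IPC bytes and handed to BytesWriter.
    write_builder = table.new_batch_write_builder()
    table_write = write_builder.new_write()
    table_commit = write_builder.new_commit()
    table_write.write_pandas(pd.DataFrame({'f0': [1, 2], 'f1': ['a', 'b']}))
    table_commit.commit(table_write.prepare_commit())

    # Batch read: splits are planned on the JVM side and come back as Arrow batches,
    # which to_pandas()/to_arrow() assemble using the table's Arrow schema from SchemaUtil.
    read_builder = table.new_read_builder()
    splits = read_builder.new_scan().plan().splits()
    df = read_builder.new_read().to_pandas(splits)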