Refactor reader and writer (#16)
yuzelin authored Oct 9, 2024
1 parent 6581f65 commit 953f30f
Showing 8 changed files with 114 additions and 116 deletions.
13 changes: 11 additions & 2 deletions paimon_python_api/table_read.py
@@ -16,6 +16,7 @@
# limitations under the License.
#################################################################################

import pandas as pd
import pyarrow as pa

from abc import ABC, abstractmethod
@@ -27,5 +28,13 @@ class TableRead(ABC):
"""To read data from data splits."""

@abstractmethod
def create_reader(self, splits: List[Split]) -> pa.RecordBatchReader:
"""Return a reader containing batches of pyarrow format."""
def to_arrow(self, splits: List[Split]) -> pa.Table:
"""Read data from splits and converted to pyarrow.Table format."""

@abstractmethod
def to_arrow_batch_reader(self, splits: List[Split]) -> pa.RecordBatchReader:
"""Read data from splits and converted to pyarrow.RecordBatchReader format."""

@abstractmethod
def to_pandas(self, splits: List[Split]) -> pd.DataFrame:
"""Read data from splits and converted to pandas.DataFrame format."""
13 changes: 11 additions & 2 deletions paimon_python_api/table_write.py
@@ -16,6 +16,7 @@
# limitations under the License.
#################################################################################

import pandas as pd
import pyarrow as pa

from abc import ABC, abstractmethod
@@ -27,8 +28,16 @@ class BatchTableWrite(ABC):
"""A table write for batch processing. Recommended for one-time committing."""

@abstractmethod
def write(self, record_batch: pa.RecordBatch):
""" Write a batch to the writer. */"""
def write_arrow(self, table: pa.Table):
""" Write an arrow table to the writer."""

@abstractmethod
def write_arrow_batch(self, record_batch: pa.RecordBatch):
""" Write an arrow record batch to the writer."""

@abstractmethod
def write_pandas(self, dataframe: pd.DataFrame):
""" Write a pandas dataframe to the writer."""

@abstractmethod
def prepare_commit(self) -> List[CommitMessage]:
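
A corresponding write-side sketch (only the write_* methods are new here; the builder methods and BatchTableCommit.commit are assumed to behave as in the existing API):

    write_builder = table.new_batch_write_builder()
    table_write = write_builder.new_write()
    table_commit = write_builder.new_commit()

    table_write.write_pandas(df)          # or write_arrow(pa_table) / write_arrow_batch(record_batch)
    commit_messages = table_write.prepare_commit()
    table_commit.commit(commit_messages)  # assumed commit entry point; not part of this diff
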
2 changes: 1 addition & 1 deletion paimon_python_java/java_gateway.py
@@ -107,7 +107,7 @@ def import_paimon_view(gateway):
java_import(gateway.jvm, "org.apache.paimon.catalog.*")
java_import(gateway.jvm, "org.apache.paimon.schema.Schema*")
java_import(gateway.jvm, 'org.apache.paimon.types.*')
java_import(gateway.jvm, 'org.apache.paimon.python.InvocationUtil')
java_import(gateway.jvm, 'org.apache.paimon.python.*')
java_import(gateway.jvm, "org.apache.paimon.data.*")


BytesWriter.java
@@ -18,7 +18,6 @@

package org.apache.paimon.python;

import org.apache.paimon.arrow.ArrowUtils;
import org.apache.paimon.arrow.reader.ArrowBatchReader;
import org.apache.paimon.data.InternalRow;
import org.apache.paimon.table.sink.TableWrite;
@@ -28,43 +27,26 @@
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.ipc.ArrowStreamReader;
import org.apache.arrow.vector.types.pojo.Field;

import java.io.ByteArrayInputStream;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;

/** Write Arrow bytes to Paimon. */
public class BytesWriter {

private final TableWrite tableWrite;
private final ArrowBatchReader arrowBatchReader;
private final BufferAllocator allocator;
private final List<Field> arrowFields;

public BytesWriter(TableWrite tableWrite, RowType rowType) {
this.tableWrite = tableWrite;
this.arrowBatchReader = new ArrowBatchReader(rowType);
this.allocator = new RootAllocator();
arrowFields =
rowType.getFields().stream()
.map(f -> ArrowUtils.toArrowField(f.name(), f.type()))
.collect(Collectors.toList());
}

public void write(byte[] bytes) throws Exception {
ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
ArrowStreamReader arrowStreamReader = new ArrowStreamReader(bais, allocator);
VectorSchemaRoot vsr = arrowStreamReader.getVectorSchemaRoot();
if (!checkSchema(arrowFields, vsr.getSchema().getFields())) {
throw new RuntimeException(
String.format(
"Input schema isn't consistent with table schema.\n"
+ "\tTable schema is: %s\n"
+ "\tInput schema is: %s",
arrowFields, vsr.getSchema().getFields()));
}

while (arrowStreamReader.loadNextBatch()) {
Iterable<InternalRow> rows = arrowBatchReader.readBatch(vsr);
@@ -78,26 +60,4 @@ public void write(byte[] bytes) throws Exception {
public void close() {
allocator.close();
}

private boolean checkSchema(List<Field> expectedFields, List<Field> actualFields) {
if (expectedFields.size() != actualFields.size()) {
return false;
}

for (int i = 0; i < expectedFields.size(); i++) {
Field expectedField = expectedFields.get(i);
Field actualField = actualFields.get(i);
if (!checkField(expectedField, actualField)
|| !checkSchema(expectedField.getChildren(), actualField.getChildren())) {
return false;
}
}

return true;
}

private boolean checkField(Field expected, Field actual) {
return Objects.equals(expected.getName(), actual.getName())
&& Objects.equals(expected.getType(), actual.getType());
}
}
ParallelBytesReader.java
@@ -18,7 +18,6 @@

package org.apache.paimon.python;

import org.apache.paimon.arrow.ArrowUtils;
import org.apache.paimon.arrow.vector.ArrowFormatWriter;
import org.apache.paimon.data.InternalRow;
import org.apache.paimon.reader.RecordReader;
@@ -30,11 +29,8 @@

import org.apache.paimon.shade.guava30.com.google.common.collect.Iterators;

import org.apache.arrow.vector.VectorSchemaRoot;

import javax.annotation.Nullable;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.util.ArrayDeque;
import java.util.ArrayList;
@@ -77,15 +73,6 @@ public void setSplits(List<Split> splits) {
bytesIterator = randomlyExecute(getExecutor(), makeProcessor(), splits);
}

public byte[] serializeSchema() {
ArrowFormatWriter arrowFormatWriter = newWriter();
VectorSchemaRoot vsr = arrowFormatWriter.getVectorSchemaRoot();
ByteArrayOutputStream out = new ByteArrayOutputStream();
ArrowUtils.serializeToIpc(vsr, out);
arrowFormatWriter.close();
return out.toByteArray();
}

@Nullable
public byte[] next() {
if (bytesIterator.hasNext()) {
@@ -110,7 +97,8 @@ private Function<Split, Iterator<byte[]>> makeProcessor() {
RecordReaderIterator<InternalRow> iterator =
new RecordReaderIterator<>(recordReader);
iterators.add(iterator);
ArrowFormatWriter arrowFormatWriter = newWriter();
ArrowFormatWriter arrowFormatWriter =
new ArrowFormatWriter(rowType, DEFAULT_WRITE_BATCH_SIZE, true);
arrowFormatWriters.add(arrowFormatWriter);
return new RecordBytesIterator(iterator, arrowFormatWriter);
} catch (IOException e) {
@@ -164,8 +152,4 @@ private void closeResources() {
}
arrowFormatWriters.clear();
}

private ArrowFormatWriter newWriter() {
return new ArrowFormatWriter(rowType, DEFAULT_WRITE_BATCH_SIZE, true);
}
}
SchemaUtil.java (new file)
@@ -0,0 +1,41 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.paimon.python;

import org.apache.paimon.arrow.ArrowUtils;
import org.apache.paimon.types.RowType;

import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.VectorSchemaRoot;

import java.io.ByteArrayOutputStream;

/** Util to get arrow schema from row type. */
public class SchemaUtil {
public static byte[] getArrowSchema(RowType rowType) {
BufferAllocator allocator = new RootAllocator();
VectorSchemaRoot emptyRoot = ArrowUtils.createVectorSchemaRoot(rowType, allocator, true);
ByteArrayOutputStream out = new ByteArrayOutputStream();
ArrowUtils.serializeToIpc(emptyRoot, out);
emptyRoot.close();
allocator.close();
return out.toByteArray();
}
}
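
The returned bytes are an Arrow IPC stream carrying only a schema message; on the Python side they can be deserialized back into a pyarrow.Schema, roughly as below (a sketch mirroring the pypaimon.py change further down; schema_bytes stands for the value returned by getArrowSchema over Py4J):

    import pyarrow as pa

    def arrow_schema_from_ipc_bytes(schema_bytes: bytes) -> pa.Schema:
        # The stream contains the schema only, no record batches.
        reader = pa.RecordBatchStreamReader(pa.BufferReader(schema_bytes))
        schema = reader.schema
        reader.close()
        return schema
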
63 changes: 43 additions & 20 deletions paimon_python_java/pypaimon.py
@@ -16,6 +16,7 @@
# limitations under the License.
################################################################################

import pandas as pd
import pyarrow as pa

from paimon_python_java.java_gateway import get_gateway
@@ -59,23 +60,30 @@ class Table(table.Table):
def __init__(self, j_table, catalog_options: dict):
self._j_table = j_table
self._catalog_options = catalog_options
# init arrow schema
schema_bytes = get_gateway().jvm.SchemaUtil.getArrowSchema(j_table.rowType())
schema_reader = pa.RecordBatchStreamReader(pa.BufferReader(schema_bytes))
self._arrow_schema = schema_reader.schema
schema_reader.close()

def new_read_builder(self) -> 'ReadBuilder':
j_read_builder = get_gateway().jvm.InvocationUtil.getReadBuilder(self._j_table)
return ReadBuilder(j_read_builder, self._j_table.rowType(), self._catalog_options)
return ReadBuilder(
j_read_builder, self._j_table.rowType(), self._catalog_options, self._arrow_schema)

def new_batch_write_builder(self) -> 'BatchWriteBuilder':
java_utils.check_batch_write(self._j_table)
j_batch_write_builder = get_gateway().jvm.InvocationUtil.getBatchWriteBuilder(self._j_table)
return BatchWriteBuilder(j_batch_write_builder, self._j_table.rowType())
return BatchWriteBuilder(j_batch_write_builder, self._j_table.rowType(), self._arrow_schema)


class ReadBuilder(read_builder.ReadBuilder):

def __init__(self, j_read_builder, j_row_type, catalog_options: dict):
def __init__(self, j_read_builder, j_row_type, catalog_options: dict, arrow_schema: pa.Schema):
self._j_read_builder = j_read_builder
self._j_row_type = j_row_type
self._catalog_options = catalog_options
self._arrow_schema = arrow_schema

def with_projection(self, projection: List[List[int]]) -> 'ReadBuilder':
self._j_read_builder.withProjection(projection)
@@ -91,7 +99,7 @@ def new_scan(self) -> 'TableScan':

def new_read(self) -> 'TableRead':
j_table_read = self._j_read_builder.newRead()
return TableRead(j_table_read, self._j_row_type, self._catalog_options)
return TableRead(j_table_read, self._j_row_type, self._catalog_options, self._arrow_schema)


class TableScan(table_scan.TableScan):
@@ -125,20 +133,27 @@ def to_j_split(self):

class TableRead(table_read.TableRead):

def __init__(self, j_table_read, j_row_type, catalog_options):
def __init__(self, j_table_read, j_row_type, catalog_options, arrow_schema):
self._j_table_read = j_table_read
self._j_row_type = j_row_type
self._catalog_options = catalog_options
self._j_bytes_reader = None
self._arrow_schema = None
self._arrow_schema = arrow_schema

def create_reader(self, splits):
def to_arrow(self, splits):
record_batch_reader = self.to_arrow_batch_reader(splits)
return pa.Table.from_batches(record_batch_reader, schema=self._arrow_schema)

def to_arrow_batch_reader(self, splits):
self._init()
j_splits = list(map(lambda s: s.to_j_split(), splits))
self._j_bytes_reader.setSplits(j_splits)
batch_iterator = self._batch_generator()
return pa.RecordBatchReader.from_batches(self._arrow_schema, batch_iterator)

def to_pandas(self, splits: List[Split]) -> pd.DataFrame:
return self.to_arrow(splits).to_pandas()

def _init(self):
if self._j_bytes_reader is None:
# get thread num
@@ -153,12 +168,6 @@ def _init(self):
self._j_bytes_reader = get_gateway().jvm.InvocationUtil.createParallelBytesReader(
self._j_table_read, self._j_row_type, max_workers)

if self._arrow_schema is None:
schema_bytes = self._j_bytes_reader.serializeSchema()
schema_reader = pa.RecordBatchStreamReader(pa.BufferReader(schema_bytes))
self._arrow_schema = schema_reader.schema
schema_reader.close()

def _batch_generator(self) -> Iterator[pa.RecordBatch]:
while True:
next_bytes = self._j_bytes_reader.next()
@@ -171,17 +180,18 @@ def _batch_generator(self) -> Iterator[pa.RecordBatch]:

class BatchWriteBuilder(write_builder.BatchWriteBuilder):

def __init__(self, j_batch_write_builder, j_row_type):
def __init__(self, j_batch_write_builder, j_row_type, arrow_schema: pa.Schema):
self._j_batch_write_builder = j_batch_write_builder
self._j_row_type = j_row_type
self._arrow_schema = arrow_schema

def with_overwrite(self, static_partition: dict) -> 'BatchWriteBuilder':
self._j_batch_write_builder.withOverwrite(static_partition)
return self

def new_write(self) -> 'BatchTableWrite':
j_batch_table_write = self._j_batch_write_builder.newWrite()
return BatchTableWrite(j_batch_table_write, self._j_row_type)
return BatchTableWrite(j_batch_table_write, self._j_row_type, self._arrow_schema)

def new_commit(self) -> 'BatchTableCommit':
j_batch_table_commit = self._j_batch_write_builder.newCommit()
@@ -190,19 +200,32 @@ def new_commit(self) -> 'BatchTableCommit':

class BatchTableWrite(table_write.BatchTableWrite):

def __init__(self, j_batch_table_write, j_row_type):
def __init__(self, j_batch_table_write, j_row_type, arrow_schema: pa.Schema):
self._j_batch_table_write = j_batch_table_write
self._j_bytes_writer = get_gateway().jvm.InvocationUtil.createBytesWriter(
j_batch_table_write, j_row_type)

def write(self, record_batch: pa.RecordBatch):
self._arrow_schema = arrow_schema

def write_arrow(self, table):
for record_batch in table.to_reader():
# TODO: can we use a reusable stream?
stream = pa.BufferOutputStream()
with pa.RecordBatchStreamWriter(stream, self._arrow_schema) as writer:
writer.write(record_batch)
arrow_bytes = stream.getvalue().to_pybytes()
self._j_bytes_writer.write(arrow_bytes)

def write_arrow_batch(self, record_batch):
stream = pa.BufferOutputStream()
with pa.RecordBatchStreamWriter(stream, record_batch.schema) as writer:
with pa.RecordBatchStreamWriter(stream, self._arrow_schema) as writer:
writer.write(record_batch)
writer.close()
arrow_bytes = stream.getvalue().to_pybytes()
self._j_bytes_writer.write(arrow_bytes)

def write_pandas(self, dataframe: pd.DataFrame):
record_batch = pa.RecordBatch.from_pandas(dataframe, schema=self._arrow_schema)
self.write_arrow_batch(record_batch)

def prepare_commit(self) -> List['CommitMessage']:
j_commit_messages = self._j_batch_table_write.prepareCommit()
return list(map(lambda cm: CommitMessage(cm), j_commit_messages))
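
Putting the refactored pieces together, a hedged end-to-end sketch (Catalog.create, get_table, the warehouse path, table name, and column names are illustrative assumptions, not part of this commit):

    import pandas as pd
    from paimon_python_java import Catalog   # assumed import path

    catalog = Catalog.create({'warehouse': '/tmp/warehouse'})   # assumed filesystem warehouse
    table = catalog.get_table('default.demo')                   # assumed existing table

    # write with the new pandas-friendly API; columns must match the table schema,
    # which is now enforced via the arrow schema cached on the Table
    write_builder = table.new_batch_write_builder()
    table_write = write_builder.new_write()
    table_commit = write_builder.new_commit()
    table_write.write_pandas(pd.DataFrame({'id': [1, 2], 'name': ['a', 'b']}))
    table_commit.commit(table_write.prepare_commit())

    # read it back with the new read API
    read_builder = table.new_read_builder()
    splits = read_builder.new_scan().plan().splits()   # assumed way to obtain splits
    print(read_builder.new_read().to_pandas(splits))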