Add table bucket mode check and data schema check for writing #8

Merged 2 commits on Sep 3, 2024
3 changes: 2 additions & 1 deletion java_based_implementation/api_impl.py
@@ -19,7 +19,7 @@
import itertools

from java_based_implementation.java_gateway import get_gateway
from java_based_implementation.util.java_utils import to_j_catalog_context
from java_based_implementation.util.java_utils import to_j_catalog_context, check_batch_write
from paimon_python_api import (catalog, table, read_builder, table_scan, split, table_read,
                               write_builder, table_write, commit_message, table_commit)
from pyarrow import (RecordBatch, BufferOutputStream, RecordBatchStreamWriter,
@@ -56,6 +56,7 @@ def new_read_builder(self) -> 'ReadBuilder':
        return ReadBuilder(j_read_builder, self._j_table.rowType())

    def new_batch_write_builder(self) -> 'BatchWriteBuilder':
        check_batch_write(self._j_table)
        j_batch_write_builder = get_gateway().jvm.InvocationUtil.getBatchWriteBuilder(self._j_table)
        return BatchWriteBuilder(j_batch_write_builder, self._j_table.rowType())
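
Because check_batch_write (added to java_utils.py further down) runs before the JVM-side builder is created, unsupported tables now fail fast with a Python TypeError instead of an error later in the Java write path. A minimal usage sketch, assuming a table obtained through Catalog.create / get_table as in the tests below (the table name is only illustrative):

catalog = Catalog.create({'warehouse': warehouse})
table = catalog.get_table('default.my_table')  # hypothetical table name
try:
    write_builder = table.new_batch_write_builder()  # calls check_batch_write first
except TypeError:
    # Raised for dynamic-bucket or cross-partition tables; see java_utils.py below.
    raise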

org/apache/paimon/python/BytesWriter.java
@@ -18,6 +18,7 @@

package org.apache.paimon.python;

import org.apache.paimon.arrow.ArrowUtils;
import org.apache.paimon.arrow.reader.ArrowBatchReader;
import org.apache.paimon.data.InternalRow;
import org.apache.paimon.table.sink.TableWrite;
@@ -27,26 +28,44 @@
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.ipc.ArrowStreamReader;
import org.apache.arrow.vector.types.pojo.Field;

import java.io.ByteArrayInputStream;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;

/** Write Arrow bytes to Paimon. */
public class BytesWriter {

    private final TableWrite tableWrite;
    private final ArrowBatchReader arrowBatchReader;
    private final BufferAllocator allocator;
    private final List<Field> arrowFields;

    public BytesWriter(TableWrite tableWrite, RowType rowType) {
        this.tableWrite = tableWrite;
        this.arrowBatchReader = new ArrowBatchReader(rowType);
        this.allocator = new RootAllocator();
        arrowFields =
                rowType.getFields().stream()
                        .map(f -> ArrowUtils.toArrowField(f.name(), f.type()))
                        .collect(Collectors.toList());
    }

    public void write(byte[] bytes) throws Exception {
        ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
        ArrowStreamReader arrowStreamReader = new ArrowStreamReader(bais, allocator);
        VectorSchemaRoot vsr = arrowStreamReader.getVectorSchemaRoot();
        if (!checkSchema(arrowFields, vsr.getSchema().getFields())) {
            throw new RuntimeException(
                    String.format(
                            "Input schema isn't consistent with table schema.\n"
                                    + "\tTable schema is: %s\n"
                                    + "\tInput schema is: %s",
                            arrowFields, vsr.getSchema().getFields()));
        }

        while (arrowStreamReader.loadNextBatch()) {
            Iterable<InternalRow> rows = arrowBatchReader.readBatch(vsr);
            for (InternalRow row : rows) {
@@ -59,4 +78,26 @@ public void write(byte[] bytes) throws Exception {
    public void close() {
        allocator.close();
    }

    private boolean checkSchema(List<Field> expectedFields, List<Field> actualFields) {
        if (expectedFields.size() != actualFields.size()) {
            return false;
        }

        for (int i = 0; i < expectedFields.size(); i++) {
            Field expectedField = expectedFields.get(i);
            Field actualField = actualFields.get(i);
            if (!checkField(expectedField, actualField)
                    || !checkSchema(expectedField.getChildren(), actualField.getChildren())) {
                return false;
            }
        }

        return true;
    }

    private boolean checkField(Field expected, Field actual) {
        return Objects.equals(expected.getName(), actual.getName())
                && Objects.equals(expected.getType(), actual.getType());
    }
}
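
The new checkSchema/checkField pass compares the incoming Arrow fields against the table's Arrow-converted row type by name and type, recursing into nested children. A minimal sketch of a batch that satisfies the check for the two-column table used in the tests below; it assumes pa.int32() corresponds to the table's INT column (matching the Int(32, true) representation asserted in the test) and that pandas' default int64 inference must be overridden:

import pandas as pd
import pyarrow as pa

# 'table' obtained via Catalog.create(...).get_table(...) as in the tests below.
df = pd.DataFrame({'f0': [1, 2, 3], 'f1': ['a', 'b', 'c']})
# Declare f0 as 32-bit explicitly; pandas infers int64, which the new check
# rejects against an INT (32-bit) table column.
schema = pa.schema([('f0', pa.int32()), ('f1', pa.string())])
record_batch = pa.RecordBatch.from_pandas(df, schema)

write_builder = table.new_batch_write_builder()
table_write = write_builder.new_write()
table_write.write(record_batch)  # passes the schema check
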
47 changes: 47 additions & 0 deletions java_based_implementation/tests/test_write_and_read.py
@@ -26,6 +26,7 @@
from java_based_implementation.java_gateway import get_gateway
from java_based_implementation.tests.utils import set_bridge_jar, create_simple_table
from java_based_implementation.util import constants, java_utils
from py4j.protocol import Py4JJavaError


class TableWriteReadTest(unittest.TestCase):
@@ -136,3 +137,49 @@ def testWriteReadAppendTable(self):

        # check data
        pd.testing.assert_frame_equal(result, df)

    def testWriteWrongSchema(self):
        create_simple_table(self.warehouse, 'default', 'test_wrong_schema', False)

        catalog = Catalog.create({'warehouse': self.warehouse})
        table = catalog.get_table('default.test_wrong_schema')

        data = {
            'f0': [1, 2, 3],
            'f1': ['a', 'b', 'c'],
        }
        df = pd.DataFrame(data)
        schema = pa.schema([
            ('f0', pa.int64()),
            ('f1', pa.string())
        ])
        record_batch = pa.RecordBatch.from_pandas(df, schema)

        write_builder = table.new_batch_write_builder()
        table_write = write_builder.new_write()

        with self.assertRaises(Py4JJavaError) as e:
            table_write.write(record_batch)
        self.assertEqual(
            str(e.exception.java_exception),
            '''java.lang.RuntimeException: Input schema isn't consistent with table schema.
\tTable schema is: [f0: Int(32, true), f1: Utf8]
\tInput schema is: [f0: Int(64, true), f1: Utf8]''')

    def testCannotWriteDynamicBucketTable(self):
        create_simple_table(
            self.warehouse,
            'default',
            'test_dynamic_bucket',
            True,
            {'bucket': '-1'}
        )

        catalog = Catalog.create({'warehouse': self.warehouse})
        table = catalog.get_table('default.test_dynamic_bucket')

        with self.assertRaises(TypeError) as e:
            table.new_batch_write_builder()
        self.assertEqual(
            str(e.exception),
            "Doesn't support writing dynamic bucket or cross partition table.")
11 changes: 8 additions & 3 deletions java_based_implementation/tests/utils.py
@@ -45,7 +45,13 @@ def set_bridge_jar() -> str:
    return os.path.join(temp_dir, jar_name)


def create_simple_table(warehouse, database, table_name, has_pk):
def create_simple_table(warehouse, database, table_name, has_pk, options=None):
    if options is None:
        options = {
            'bucket': '1',
            'bucket-key': 'f0'
        }

    gateway = get_gateway()

    j_catalog_context = to_j_catalog_context({'warehouse': warehouse})
@@ -55,8 +61,7 @@ def create_simple_table(warehouse, database, table_name, has_pk):
        gateway.jvm.Schema.newBuilder()
        .column('f0', gateway.jvm.DataTypes.INT())
        .column('f1', gateway.jvm.DataTypes.STRING())
        .option('bucket', '1')
        .option('bucket-key', 'f0')
        .options(options)
    )
    if has_pk:
        j_schema_builder.primaryKey(['f0'])
8 changes: 8 additions & 0 deletions java_based_implementation/util/java_utils.py
@@ -23,3 +23,11 @@ def to_j_catalog_context(catalog_context: dict):
    gateway = get_gateway()
    j_options = gateway.jvm.Options(catalog_context)
    return gateway.jvm.CatalogContext.create(j_options)


def check_batch_write(j_table):
    gateway = get_gateway()
    bucket_mode = j_table.bucketMode()
    if bucket_mode == gateway.jvm.BucketMode.HASH_DYNAMIC \
            or bucket_mode == gateway.jvm.BucketMode.CROSS_PARTITION:
        raise TypeError("Doesn't support writing dynamic bucket or cross partition table.")
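
A short usage sketch building on create_simple_table above (table names and the bucket count '2' are illustrative): a fixed-bucket table passes this check, while 'bucket': '-1' on a primary-key table selects dynamic bucket mode and makes new_batch_write_builder raise the TypeError:

# Fixed-bucket table: check_batch_write passes and a write builder can be created.
create_simple_table(warehouse, 'default', 'fixed_bucket_table', False,
                    {'bucket': '2', 'bucket-key': 'f0'})
catalog = Catalog.create({'warehouse': warehouse})
table = catalog.get_table('default.fixed_bucket_table')
write_builder = table.new_batch_write_builder()  # OK

# Dynamic-bucket table: the same call raises
# TypeError("Doesn't support writing dynamic bucket or cross partition table.")
create_simple_table(warehouse, 'default', 'dynamic_bucket_table', True, {'bucket': '-1'})
dyn_table = catalog.get_table('default.dynamic_bucket_table')
dyn_table.new_batch_write_builder()  # raises TypeError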