
Commit c505c28

Add table bucket mode check and data schema check for writing (#8)
1 parent b95ee36 commit c505c28

File tree: 5 files changed, +106 −4 lines

java_based_implementation/api_impl.py

Lines changed: 2 additions & 1 deletion

@@ -19,7 +19,7 @@
 import itertools

 from java_based_implementation.java_gateway import get_gateway
-from java_based_implementation.util.java_utils import to_j_catalog_context
+from java_based_implementation.util.java_utils import to_j_catalog_context, check_batch_write
 from paimon_python_api import (catalog, table, read_builder, table_scan, split, table_read,
                                write_builder, table_write, commit_message, table_commit)
 from pyarrow import (RecordBatch, BufferOutputStream, RecordBatchStreamWriter,

@@ -56,6 +56,7 @@ def new_read_builder(self) -> 'ReadBuilder':
         return ReadBuilder(j_read_builder, self._j_table.rowType())

     def new_batch_write_builder(self) -> 'BatchWriteBuilder':
+        check_batch_write(self._j_table)
         j_batch_write_builder = get_gateway().jvm.InvocationUtil.getBatchWriteBuilder(self._j_table)
         return BatchWriteBuilder(j_batch_write_builder, self._j_table.rowType())
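For reference, this is the batch-write call path that now runs the bucket mode check first. A minimal sketch following the tests in this commit; the warehouse path and table name are placeholders, and the Catalog import location is an assumption:

# Sketch of the write path guarded by check_batch_write; placeholders marked.
from java_based_implementation.api_impl import Catalog  # assumed import path

catalog = Catalog.create({'warehouse': '/tmp/warehouse'})  # placeholder path
table = catalog.get_table('default.demo')                  # placeholder table

# Raises TypeError for dynamic-bucket or cross-partition tables;
# otherwise returns the builder exactly as before.
write_builder = table.new_batch_write_builder()
table_write = write_builder.new_write()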

java_based_implementation/paimon-python-java-bridge/src/main/java/org/apache/paimon/python/BytesWriter.java

Lines changed: 41 additions & 0 deletions

@@ -18,6 +18,7 @@

 package org.apache.paimon.python;

+import org.apache.paimon.arrow.ArrowUtils;
 import org.apache.paimon.arrow.reader.ArrowBatchReader;
 import org.apache.paimon.data.InternalRow;
 import org.apache.paimon.table.sink.TableWrite;

@@ -27,26 +28,44 @@
 import org.apache.arrow.memory.RootAllocator;
 import org.apache.arrow.vector.VectorSchemaRoot;
 import org.apache.arrow.vector.ipc.ArrowStreamReader;
+import org.apache.arrow.vector.types.pojo.Field;

 import java.io.ByteArrayInputStream;
+import java.util.List;
+import java.util.Objects;
+import java.util.stream.Collectors;

 /** Write Arrow bytes to Paimon. */
 public class BytesWriter {

     private final TableWrite tableWrite;
     private final ArrowBatchReader arrowBatchReader;
     private final BufferAllocator allocator;
+    private final List<Field> arrowFields;

     public BytesWriter(TableWrite tableWrite, RowType rowType) {
         this.tableWrite = tableWrite;
         this.arrowBatchReader = new ArrowBatchReader(rowType);
         this.allocator = new RootAllocator();
+        arrowFields =
+                rowType.getFields().stream()
+                        .map(f -> ArrowUtils.toArrowField(f.name(), f.type()))
+                        .collect(Collectors.toList());
     }

     public void write(byte[] bytes) throws Exception {
         ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
         ArrowStreamReader arrowStreamReader = new ArrowStreamReader(bais, allocator);
         VectorSchemaRoot vsr = arrowStreamReader.getVectorSchemaRoot();
+        if (!checkSchema(arrowFields, vsr.getSchema().getFields())) {
+            throw new RuntimeException(
+                    String.format(
+                            "Input schema isn't consistent with table schema.\n"
+                                    + "\tTable schema is: %s\n"
+                                    + "\tInput schema is: %s",
+                            arrowFields, vsr.getSchema().getFields()));
+        }
+
         while (arrowStreamReader.loadNextBatch()) {
             Iterable<InternalRow> rows = arrowBatchReader.readBatch(vsr);
             for (InternalRow row : rows) {

@@ -59,4 +78,26 @@ public void write(byte[] bytes) throws Exception {
     public void close() {
         allocator.close();
     }
+
+    private boolean checkSchema(List<Field> expectedFields, List<Field> actualFields) {
+        if (expectedFields.size() != actualFields.size()) {
+            return false;
+        }
+
+        for (int i = 0; i < expectedFields.size(); i++) {
+            Field expectedField = expectedFields.get(i);
+            Field actualField = actualFields.get(i);
+            if (!checkField(expectedField, actualField)
+                    || !checkSchema(expectedField.getChildren(), actualField.getChildren())) {
+                return false;
+            }
+        }
+
+        return true;
+    }
+
+    private boolean checkField(Field expected, Field actual) {
+        return Objects.equals(expected.getName(), actual.getName())
+                && Objects.equals(expected.getType(), actual.getType());
+    }
 }
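The rule enforced here can also be checked client-side before the bytes cross the py4j bridge. A minimal pyarrow sketch of the same positional name-and-type comparison (illustrative only, not part of this commit):

# Positional name/type comparison mirroring BytesWriter.checkSchema.
# pyarrow type equality is deep, so nested children are covered implicitly.
import pyarrow as pa

def schemas_consistent(expected: pa.Schema, actual: pa.Schema) -> bool:
    if len(expected) != len(actual):
        return False
    return all(e.name == a.name and e.type == a.type
               for e, a in zip(expected, actual))

# The mismatch exercised by testWriteWrongSchema below: the table declares
# a 32-bit INT while the input supplies int64.
table_schema = pa.schema([('f0', pa.int32()), ('f1', pa.string())])
input_schema = pa.schema([('f0', pa.int64()), ('f1', pa.string())])
assert not schemas_consistent(table_schema, input_schema)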

java_based_implementation/tests/test_write_and_read.py

Lines changed: 47 additions & 0 deletions

@@ -26,6 +26,7 @@
 from java_based_implementation.java_gateway import get_gateway
 from java_based_implementation.tests.utils import set_bridge_jar, create_simple_table
 from java_based_implementation.util import constants, java_utils
+from py4j.protocol import Py4JJavaError


 class TableWriteReadTest(unittest.TestCase):

@@ -136,3 +137,49 @@ def testWriteReadAppendTable(self):

         # check data
         pd.testing.assert_frame_equal(result, df)
+
+    def testWriteWrongSchema(self):
+        create_simple_table(self.warehouse, 'default', 'test_wrong_schema', False)
+
+        catalog = Catalog.create({'warehouse': self.warehouse})
+        table = catalog.get_table('default.test_wrong_schema')
+
+        data = {
+            'f0': [1, 2, 3],
+            'f1': ['a', 'b', 'c'],
+        }
+        df = pd.DataFrame(data)
+        schema = pa.schema([
+            ('f0', pa.int64()),
+            ('f1', pa.string())
+        ])
+        record_batch = pa.RecordBatch.from_pandas(df, schema)
+
+        write_builder = table.new_batch_write_builder()
+        table_write = write_builder.new_write()
+
+        with self.assertRaises(Py4JJavaError) as e:
+            table_write.write(record_batch)
+        self.assertEqual(
+            str(e.exception.java_exception),
+            '''java.lang.RuntimeException: Input schema isn't consistent with table schema.
+\tTable schema is: [f0: Int(32, true), f1: Utf8]
+\tInput schema is: [f0: Int(64, true), f1: Utf8]''')
+
+    def testCannotWriteDynamicBucketTable(self):
+        create_simple_table(
+            self.warehouse,
+            'default',
+            'test_dynamic_bucket',
+            True,
+            {'bucket': '-1'}
+        )
+
+        catalog = Catalog.create({'warehouse': self.warehouse})
+        table = catalog.get_table('default.test_dynamic_bucket')
+
+        with self.assertRaises(TypeError) as e:
+            table.new_batch_write_builder()
+        self.assertEqual(
+            str(e.exception),
+            "Doesn't support writing dynamic bucket or cross partition table.")
java_based_implementation/tests/utils.py

Lines changed: 8 additions & 3 deletions

@@ -45,7 +45,13 @@ def set_bridge_jar() -> str:
     return os.path.join(temp_dir, jar_name)


-def create_simple_table(warehouse, database, table_name, has_pk):
+def create_simple_table(warehouse, database, table_name, has_pk, options=None):
+    if options is None:
+        options = {
+            'bucket': '1',
+            'bucket-key': 'f0'
+        }
+
     gateway = get_gateway()

     j_catalog_context = to_j_catalog_context({'warehouse': warehouse})

@@ -55,8 +61,7 @@ def create_simple_table(warehouse, database, table_name, has_pk):
         gateway.jvm.Schema.newBuilder()
         .column('f0', gateway.jvm.DataTypes.INT())
         .column('f1', gateway.jvm.DataTypes.STRING())
-        .option('bucket', '1')
-        .option('bucket-key', 'f0')
+        .options(options)
     )
     if has_pk:
         j_schema_builder.primaryKey(['f0'])
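With the new options parameter the helper can create tables in any bucket mode. A short usage sketch (table names are placeholders):

# Defaults preserved: fixed single bucket keyed on f0.
create_simple_table(warehouse, 'default', 'fixed_bucket_table', has_pk=True)

# Override: 'bucket' = '-1' creates a dynamic-bucket table, as used by
# testCannotWriteDynamicBucketTable to trigger the new TypeError.
create_simple_table(warehouse, 'default', 'dynamic_bucket_table', True, {'bucket': '-1'})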

java_based_implementation/util/java_utils.py

Lines changed: 8 additions & 0 deletions

@@ -23,3 +23,11 @@ def to_j_catalog_context(catalog_context: dict):
     gateway = get_gateway()
     j_options = gateway.jvm.Options(catalog_context)
     return gateway.jvm.CatalogContext.create(j_options)
+
+
+def check_batch_write(j_table):
+    gateway = get_gateway()
+    bucket_mode = j_table.bucketMode()
+    if bucket_mode == gateway.jvm.BucketMode.HASH_DYNAMIC \
+            or bucket_mode == gateway.jvm.BucketMode.CROSS_PARTITION:
+        raise TypeError("Doesn't support writing dynamic bucket or cross partition table.")
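Callers normally hit this guard through new_batch_write_builder(), but the bucket mode can also be probed directly. A sketch that assumes access to the table wrapper's internal _j_table handle from api_impl.py:

from java_based_implementation.java_gateway import get_gateway

gateway = get_gateway()
bucket_mode = table._j_table.bucketMode()  # internal handle, see api_impl.py
if bucket_mode in (gateway.jvm.BucketMode.HASH_DYNAMIC,
                   gateway.jvm.BucketMode.CROSS_PARTITION):
    print('table is not writable through the batch write API')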
