From f32b22ec517ef62c47b4e11956812c2dc4b4425a Mon Sep 17 00:00:00 2001
From: yuzelin
Date: Tue, 20 Aug 2024 14:27:32 +0800
Subject: [PATCH] Fix reading empty tables in java_based_implementation

create_reader() called next() on the batch generator without handling
StopIteration, so a split whose generator yields no record batch raised
instead of returning a reader. Catch StopIteration and return a
RecordBatchReader built from an empty schema and a single empty batch.

Also:
- java_gateway: import org.apache.paimon.types.* (instead of only
  DataTypes) and org.apache.paimon.data.*; the new primary-key test
  accesses RowKind, GenericRow and BinaryString through gateway.jvm.
- tests: create the warehouse directory once in setUpClass and share it
  between tests; add testReadEmptyAppendTable (no splits planned for a
  freshly created table) and testReadEmptyPkTable (insert a row, delete
  it, then expect an empty result).
- tests/utils: pass the primary key to primaryKey() as a list.
---
NOTE(review): this patch was recovered from a copy whose line breaks and
leading whitespace had been collapsed. Hunk line counts match the hunk
headers and the diffstat, but indentation and blank-line placement in
context lines are reconstructed -- verify against the tree before
applying (git apply --ignore-whitespace tolerates indentation
differences in context lines).

Changes since the original submission:
- descriptive subject and commit message
- testReadEmptyAppendTable: assertEqual(len(splits), 0) instead of
  assertTrue(len(splits) == 0), so a failure reports the actual length

 java_based_implementation/api_impl.py         | 13 +++-
 java_based_implementation/java_gateway.py     |  3 +-
 .../tests/test_write_and_read.py              | 69 +++++++++++++++++--
 java_based_implementation/tests/utils.py      |  2 +-
 4 files changed, 79 insertions(+), 8 deletions(-)

diff --git a/java_based_implementation/api_impl.py b/java_based_implementation/api_impl.py
index 1d04300..6c1be64 100644
--- a/java_based_implementation/api_impl.py
+++ b/java_based_implementation/api_impl.py
@@ -124,7 +124,11 @@ def create_reader(self, split: Split):
         self._j_bytes_reader.setSplit(split.to_j_split())
         batch_iterator = self._batch_generator()
         # to init arrow schema
-        first_batch = next(batch_iterator)
+        try:
+            first_batch = next(batch_iterator)
+        except StopIteration:
+            return self._empty_batch_reader()
+
         batches = itertools.chain((b for b in [first_batch]), batch_iterator)
         return RecordBatchReader.from_batches(self._arrow_schema, batches)

@@ -139,6 +143,13 @@ def _batch_generator(self) -> Iterator[RecordBatch]:
                     self._arrow_schema = stream_reader.schema
                 yield from stream_reader

+    def _empty_batch_reader(self):
+        import pyarrow as pa
+        schema = pa.schema([])
+        empty_batch = pa.RecordBatch.from_arrays([], schema=schema)
+        empty_reader = pa.RecordBatchReader.from_batches(schema, [empty_batch])
+        return empty_reader
+
     def close(self):
         self._j_bytes_reader.close()

diff --git a/java_based_implementation/java_gateway.py b/java_based_implementation/java_gateway.py
index f9652f2..64f40b4 100644
--- a/java_based_implementation/java_gateway.py
+++ b/java_based_implementation/java_gateway.py
@@ -106,8 +106,9 @@ def import_paimon_view(gateway):
     java_import(gateway.jvm, "org.apache.paimon.options.Options")
     java_import(gateway.jvm, "org.apache.paimon.catalog.*")
     java_import(gateway.jvm, "org.apache.paimon.schema.Schema*")
-    java_import(gateway.jvm, 'org.apache.paimon.types.DataTypes')
+    java_import(gateway.jvm, 'org.apache.paimon.types.*')
     java_import(gateway.jvm, 'org.apache.paimon.python.InvocationUtil')
+    java_import(gateway.jvm, "org.apache.paimon.data.*")


 class Watchdog(object):
diff --git a/java_based_implementation/tests/test_write_and_read.py b/java_based_implementation/tests/test_write_and_read.py
index acfdba0..cfb239c 100644
--- a/java_based_implementation/tests/test_write_and_read.py
+++ b/java_based_implementation/tests/test_write_and_read.py
@@ -22,9 +22,10 @@
 import pandas as pd
 import pyarrow as pa

-from java_based_implementation.api_impl import Catalog
+from java_based_implementation.api_impl import Catalog, Table
+from java_based_implementation.java_gateway import get_gateway
 from java_based_implementation.tests.utils import set_bridge_jar, create_simple_table
-from java_based_implementation.util import constants
+from java_based_implementation.util import constants, java_utils


 class TableWriteReadTest(unittest.TestCase):
@@ -33,12 +34,70 @@ class TableWriteReadTest(unittest.TestCase):
     def setUpClass(cls):
         classpath = set_bridge_jar()
         os.environ[constants.PYPAIMON_JAVA_CLASSPATH] = classpath
+        cls.warehouse = tempfile.mkdtemp()
+
+    def testReadEmptyAppendTable(self):
+        create_simple_table(self.warehouse, 'default', 'empty_append_table', False)
+        catalog = Catalog.create({'warehouse': self.warehouse})
+        table = catalog.get_table('default.empty_append_table')
+
+        # read data
+        read_builder = table.new_read_builder()
+        table_scan = read_builder.new_scan()
+        splits = table_scan.plan().splits()
+
+        self.assertEqual(len(splits), 0)
+
+    def testReadEmptyPkTable(self):
+        create_simple_table(self.warehouse, 'default', 'empty_pk_table', True)
+        gateway = get_gateway()
+        j_catalog_context = java_utils.to_j_catalog_context({'warehouse': self.warehouse})
+        j_catalog = gateway.jvm.CatalogFactory.createCatalog(j_catalog_context)
+        j_identifier = gateway.jvm.Identifier.fromString('default.empty_pk_table')
+        j_table = j_catalog.getTable(j_identifier)
+        j_write_builder = gateway.jvm.InvocationUtil.getBatchWriteBuilder(j_table)
+
+        # first commit
+        generic_row = gateway.jvm.GenericRow(gateway.jvm.RowKind.INSERT, 2)
+        generic_row.setField(0, 1)
+        generic_row.setField(1, gateway.jvm.BinaryString.fromString('a'))
+        table_write = j_write_builder.newWrite()
+        table_write.write(generic_row)
+        table_commit = j_write_builder.newCommit()
+        table_commit.commit(table_write.prepareCommit())
+        table_write.close()
+        table_commit.close()
+
+        # second commit
+        generic_row = gateway.jvm.GenericRow(gateway.jvm.RowKind.DELETE, 2)
+        generic_row.setField(0, 1)
+        generic_row.setField(1, gateway.jvm.BinaryString.fromString('a'))
+        table_write = j_write_builder.newWrite()
+        table_write.write(generic_row)
+        table_commit = j_write_builder.newCommit()
+        table_commit.commit(table_write.prepareCommit())
+        table_write.close()
+        table_commit.close()
+
+        # read data
+        table = Table(j_table)
+        read_builder = table.new_read_builder()
+        table_scan = read_builder.new_scan()
+        table_read = read_builder.new_read()
+        splits = table_scan.plan().splits()
+
+        data_frames = [
+            batch.to_pandas()
+            for split in splits
+            for batch in table_read.create_reader(split)
+        ]
+        result = pd.concat(data_frames)
+        self.assertEqual(result.shape, (0, 0))

     def testWriteReadAppendTable(self):
-        warehouse = tempfile.mkdtemp()
-        create_simple_table(warehouse, 'default', 'simple_append_table', False)
+        create_simple_table(self.warehouse, 'default', 'simple_append_table', False)

-        catalog = Catalog.create({'warehouse': warehouse})
+        catalog = Catalog.create({'warehouse': self.warehouse})
         table = catalog.get_table('default.simple_append_table')

         # prepare data
diff --git a/java_based_implementation/tests/utils.py b/java_based_implementation/tests/utils.py
index ce2b664..86276f2 100644
--- a/java_based_implementation/tests/utils.py
+++ b/java_based_implementation/tests/utils.py
@@ -59,7 +59,7 @@ def create_simple_table(warehouse, database, table_name, has_pk):
         .option('bucket-key', 'f0')
     )
     if has_pk:
-        j_schema_builder.primaryKey('f0')
+        j_schema_builder.primaryKey(['f0'])
     j_schema = j_schema_builder.build()

     j_catalog.createDatabase(database, True)