Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
yuzelin committed Aug 20, 2024
1 parent 698305b commit f32b22e
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 8 deletions.
13 changes: 12 additions & 1 deletion java_based_implementation/api_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,11 @@ def create_reader(self, split: Split):
self._j_bytes_reader.setSplit(split.to_j_split())
batch_iterator = self._batch_generator()
# to init arrow schema
first_batch = next(batch_iterator)
try:
first_batch = next(batch_iterator)
except StopIteration:
return self._empty_batch_reader()

batches = itertools.chain((b for b in [first_batch]), batch_iterator)
return RecordBatchReader.from_batches(self._arrow_schema, batches)

Expand All @@ -139,6 +143,13 @@ def _batch_generator(self) -> Iterator[RecordBatch]:
self._arrow_schema = stream_reader.schema
yield from stream_reader

def _empty_batch_reader(self):
import pyarrow as pa
schema = pa.schema([])
empty_batch = pa.RecordBatch.from_arrays([], schema=schema)
empty_reader = pa.RecordBatchReader.from_batches(schema, [empty_batch])
return empty_reader

def close(self):
self._j_bytes_reader.close()

Expand Down
3 changes: 2 additions & 1 deletion java_based_implementation/java_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,9 @@ def import_paimon_view(gateway):
java_import(gateway.jvm, "org.apache.paimon.options.Options")
java_import(gateway.jvm, "org.apache.paimon.catalog.*")
java_import(gateway.jvm, "org.apache.paimon.schema.Schema*")
java_import(gateway.jvm, 'org.apache.paimon.types.DataTypes')
java_import(gateway.jvm, 'org.apache.paimon.types.*')
java_import(gateway.jvm, 'org.apache.paimon.python.InvocationUtil')
java_import(gateway.jvm, "org.apache.paimon.data.*")


class Watchdog(object):
Expand Down
69 changes: 64 additions & 5 deletions java_based_implementation/tests/test_write_and_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@
import pandas as pd
import pyarrow as pa

from java_based_implementation.api_impl import Catalog
from java_based_implementation.api_impl import Catalog, Table
from java_based_implementation.java_gateway import get_gateway
from java_based_implementation.tests.utils import set_bridge_jar, create_simple_table
from java_based_implementation.util import constants
from java_based_implementation.util import constants, java_utils


class TableWriteReadTest(unittest.TestCase):
Expand All @@ -33,12 +34,70 @@ class TableWriteReadTest(unittest.TestCase):
def setUpClass(cls):
classpath = set_bridge_jar()
os.environ[constants.PYPAIMON_JAVA_CLASSPATH] = classpath
cls.warehouse = tempfile.mkdtemp()

def testReadEmptyAppendTable(self):
create_simple_table(self.warehouse, 'default', 'empty_append_table', False)
catalog = Catalog.create({'warehouse': self.warehouse})
table = catalog.get_table('default.empty_append_table')

# read data
read_builder = table.new_read_builder()
table_scan = read_builder.new_scan()
splits = table_scan.plan().splits()

self.assertTrue(len(splits) == 0)

def testReadEmptyPkTable(self):
create_simple_table(self.warehouse, 'default', 'empty_pk_table', True)
gateway = get_gateway()
j_catalog_context = java_utils.to_j_catalog_context({'warehouse': self.warehouse})
j_catalog = gateway.jvm.CatalogFactory.createCatalog(j_catalog_context)
j_identifier = gateway.jvm.Identifier.fromString('default.empty_pk_table')
j_table = j_catalog.getTable(j_identifier)
j_write_builder = gateway.jvm.InvocationUtil.getBatchWriteBuilder(j_table)

# first commit
generic_row = gateway.jvm.GenericRow(gateway.jvm.RowKind.INSERT, 2)
generic_row.setField(0, 1)
generic_row.setField(1, gateway.jvm.BinaryString.fromString('a'))
table_write = j_write_builder.newWrite()
table_write.write(generic_row)
table_commit = j_write_builder.newCommit()
table_commit.commit(table_write.prepareCommit())
table_write.close()
table_commit.close()

# second commit
generic_row = gateway.jvm.GenericRow(gateway.jvm.RowKind.DELETE, 2)
generic_row.setField(0, 1)
generic_row.setField(1, gateway.jvm.BinaryString.fromString('a'))
table_write = j_write_builder.newWrite()
table_write.write(generic_row)
table_commit = j_write_builder.newCommit()
table_commit.commit(table_write.prepareCommit())
table_write.close()
table_commit.close()

# read data
table = Table(j_table)
read_builder = table.new_read_builder()
table_scan = read_builder.new_scan()
table_read = read_builder.new_read()
splits = table_scan.plan().splits()

data_frames = [
batch.to_pandas()
for split in splits
for batch in table_read.create_reader(split)
]
result = pd.concat(data_frames)
self.assertEqual(result.shape, (0, 0))

def testWriteReadAppendTable(self):
warehouse = tempfile.mkdtemp()
create_simple_table(warehouse, 'default', 'simple_append_table', False)
create_simple_table(self.warehouse, 'default', 'simple_append_table', False)

catalog = Catalog.create({'warehouse': warehouse})
catalog = Catalog.create({'warehouse': self.warehouse})
table = catalog.get_table('default.simple_append_table')

# prepare data
Expand Down
2 changes: 1 addition & 1 deletion java_based_implementation/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def create_simple_table(warehouse, database, table_name, has_pk):
.option('bucket-key', 'f0')
)
if has_pk:
j_schema_builder.primaryKey('f0')
j_schema_builder.primaryKey(['f0'])
j_schema = j_schema_builder.build()

j_catalog.createDatabase(database, True)
Expand Down

0 comments on commit f32b22e

Please sign in to comment.