Commit

Catalog supports creating tables from pyarrow schema (#12)
yuzelin authored Oct 9, 2024
1 parent a7b752a commit 7e49b66
Showing 8 changed files with 248 additions and 86 deletions.
3 changes: 2 additions & 1 deletion paimon_python_api/__init__.py
@@ -24,7 +24,7 @@
 from .table_commit import BatchTableCommit
 from .table_write import BatchTableWrite
 from .write_builder import BatchWriteBuilder
-from .table import Table
+from .table import Table, Schema
 from .catalog import Catalog

__all__ = [
@@ -38,5 +38,6 @@
     'BatchTableWrite',
     'BatchWriteBuilder',
     'Table',
+    'Schema',
     'Catalog'
 ]
11 changes: 10 additions & 1 deletion paimon_python_api/catalog.py
@@ -17,7 +17,8 @@
 #################################################################################
 
 from abc import ABC, abstractmethod
-from paimon_python_api import Table
+from typing import Optional
+from paimon_python_api import Table, Schema
 
 
 class Catalog(ABC):
@@ -34,3 +35,11 @@ def create(catalog_options: dict) -> 'Catalog':
     @abstractmethod
     def get_table(self, identifier: str) -> Table:
         """Get paimon table identified by the given Identifier."""
+
+    @abstractmethod
+    def create_database(self, name: str, ignore_if_exists: bool, properties: Optional[dict] = None):
+        """Create a database with properties."""
+
+    @abstractmethod
+    def create_table(self, identifier: str, schema: Schema, ignore_if_exists: bool):
+        """Create table."""
19 changes: 19 additions & 0 deletions paimon_python_api/table.py
@@ -16,8 +16,11 @@
 # limitations under the License.
 #################################################################################
 
+import pyarrow as pa
+
 from abc import ABC, abstractmethod
 from paimon_python_api import ReadBuilder, BatchWriteBuilder
+from typing import Optional, List
 
 
 class Table(ABC):
@@ -30,3 +33,19 @@ def new_read_builder(self) -> ReadBuilder:
     @abstractmethod
     def new_batch_write_builder(self) -> BatchWriteBuilder:
         """Returns a builder for building batch table write and table commit."""
+
+
+class Schema:
+    """Schema of a table."""
+
+    def __init__(self,
+                 pa_schema: pa.Schema,
+                 partition_keys: Optional[List[str]] = None,
+                 primary_keys: Optional[List[str]] = None,
+                 options: Optional[dict] = None,
+                 comment: Optional[str] = None):
+        self.pa_schema = pa_schema
+        self.partition_keys = partition_keys
+        self.primary_keys = primary_keys
+        self.options = options
+        self.comment = comment
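
Beyond the bare pa_schema case, Schema also carries partition keys, primary keys, table options and a comment. A sketch of a partitioned primary-key table definition; the field names and option values are illustrative, the 'bucket' option follows the fixed-bucket pattern used in the tests below, and the primary key includes the partition key because Paimon requires that:

    import pyarrow as pa

    from paimon_python_api import Schema

    pa_schema = pa.schema([
        ('dt', pa.string()),
        ('f0', pa.int32()),
        ('f1', pa.string())
    ])

    # Partition by 'dt'; the primary key covers the partition key plus 'f0';
    # 'bucket': '1' requests a fixed-bucket layout as in testReadEmptyPkTable.
    schema = Schema(
        pa_schema,
        partition_keys=['dt'],
        primary_keys=['dt', 'f0'],
        options={'bucket': '1'},
        comment='illustrative partitioned primary-key table')
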
17 changes: 13 additions & 4 deletions paimon_python_java/pypaimon.py
@@ -21,8 +21,8 @@
 from paimon_python_java.java_gateway import get_gateway
 from paimon_python_java.util import java_utils, constants
 from paimon_python_api import (catalog, table, read_builder, table_scan, split, table_read,
-                               write_builder, table_write, commit_message, table_commit)
-from typing import List, Iterator
+                               write_builder, table_write, commit_message, table_commit, Schema)
+from typing import List, Iterator, Optional
 
 
 class Catalog(catalog.Catalog):
@@ -39,11 +39,20 @@ def create(catalog_options: dict) -> 'Catalog':
         return Catalog(j_catalog, catalog_options)
 
     def get_table(self, identifier: str) -> 'Table':
-        gateway = get_gateway()
-        j_identifier = gateway.jvm.Identifier.fromString(identifier)
+        j_identifier = java_utils.to_j_identifier(identifier)
         j_table = self._j_catalog.getTable(j_identifier)
         return Table(j_table, self._catalog_options)
+
+    def create_database(self, name: str, ignore_if_exists: bool, properties: Optional[dict] = None):
+        if properties is None:
+            properties = {}
+        self._j_catalog.createDatabase(name, ignore_if_exists, properties)
+
+    def create_table(self, identifier: str, schema: Schema, ignore_if_exists: bool):
+        j_identifier = java_utils.to_j_identifier(identifier)
+        j_schema = java_utils.to_paimon_schema(schema)
+        self._j_catalog.createTable(j_identifier, j_schema, ignore_if_exists)
 
 
 class Table(table.Table):

106 changes: 106 additions & 0 deletions paimon_python_java/tests/test_data_types.py
@@ -0,0 +1,106 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+import os
+import random
+import shutil
+import string
+import tempfile
+import pyarrow as pa
+import unittest
+
+
+from paimon_python_api import Schema
+from paimon_python_java import Catalog
+from paimon_python_java.util import java_utils
+from setup_utils import java_setuputils
+
+
+class DataTypesTest(unittest.TestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        java_setuputils.setup_java_bridge()
+        cls.warehouse = tempfile.mkdtemp()
+        cls.simple_pa_schema = pa.schema([
+            ('f0', pa.int32()),
+            ('f1', pa.string())
+        ])
+        cls.catalog = Catalog.create({'warehouse': cls.warehouse})
+        cls.catalog.create_database('default', False)
+
+    @classmethod
+    def tearDownClass(cls):
+        java_setuputils.clean()
+        if os.path.exists(cls.warehouse):
+            shutil.rmtree(cls.warehouse)
+
+    def test_int(self):
+        pa_schema = pa.schema([
+            ('_int8', pa.int8()),
+            ('_int16', pa.int16()),
+            ('_int32', pa.int32()),
+            ('_int64', pa.int64())
+        ])
+        expected_types = ['TINYINT', 'SMALLINT', 'INT', 'BIGINT']
+        self._test_impl(pa_schema, expected_types)
+
+    def test_float(self):
+        pa_schema = pa.schema([
+            ('_float16', pa.float16()),
+            ('_float32', pa.float32()),
+            ('_float64', pa.float64())
+        ])
+        expected_types = ['FLOAT', 'FLOAT', 'DOUBLE']
+        self._test_impl(pa_schema, expected_types)
+
+    def test_string(self):
+        pa_schema = pa.schema([
+            ('_string', pa.string()),
+            ('_utf8', pa.utf8())
+        ])
+        expected_types = ['STRING', 'STRING']
+        self._test_impl(pa_schema, expected_types)
+
+    def test_bool(self):
+        pa_schema = pa.schema([('_bool', pa.bool_())])
+        expected_types = ['BOOLEAN']
+        self._test_impl(pa_schema, expected_types)
+
+    def test_null(self):
+        pa_schema = pa.schema([('_null', pa.null())])
+        expected_types = ['STRING']
+        self._test_impl(pa_schema, expected_types)
+
+    def test_unsupported_type(self):
+        pa_schema = pa.schema([('_array', pa.list_(pa.int32()))])
+        schema = Schema(pa_schema)
+        with self.assertRaises(ValueError) as e:
+            java_utils.to_paimon_schema(schema)
+        self.assertEqual(
+            str(e.exception), 'Found unsupported data type list<item: int32> for field _array.')
+
+    def _test_impl(self, pa_schema, expected_types):
+        scheme = Schema(pa_schema)
+        letters = string.ascii_letters
+        identifier = 'default.' + ''.join(random.choice(letters) for _ in range(10))
+        self.catalog.create_table(identifier, scheme, False)
+        table = self.catalog.get_table(identifier)
+        field_types = table._j_table.rowType().getFieldTypes()
+        actual_types = list(map(lambda t: t.toString(), field_types))
+        self.assertListEqual(actual_types, expected_types)
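
The tests above pin down the pyarrow-to-Paimon type mapping. For reference, the mapping can be restated as a standalone table; this sketch is distilled from the expected values in DataTypesTest, not taken from java_utils.to_paimon_schema, whose internals are outside this diff:

    import pyarrow as pa

    # Distilled from the expectations in DataTypesTest; illustrative only.
    PA_TO_PAIMON = {
        pa.int8(): 'TINYINT',
        pa.int16(): 'SMALLINT',
        pa.int32(): 'INT',
        pa.int64(): 'BIGINT',
        pa.float16(): 'FLOAT',
        pa.float32(): 'FLOAT',
        pa.float64(): 'DOUBLE',
        pa.string(): 'STRING',  # pa.utf8() is an alias of pa.string()
        pa.bool_(): 'BOOLEAN',
        pa.null(): 'STRING',
    }

    def to_paimon_type_name(field: pa.Field) -> str:
        # Nested types such as pa.list_(pa.int32()) are unsupported,
        # matching test_unsupported_type above.
        try:
            return PA_TO_PAIMON[field.type]
        except KeyError:
            raise ValueError(
                f'Found unsupported data type {field.type} for field {field.name}.')
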
71 changes: 38 additions & 33 deletions paimon_python_java/tests/test_write_and_read.py
@@ -16,15 +16,17 @@
 # limitations under the License.
 ################################################################################
 
+import os
+import shutil
 import tempfile
 import unittest
 import pandas as pd
 import pyarrow as pa
 import setup_utils.java_setuputils as setuputils
 
-from paimon_python_java import Catalog, Table
+from paimon_python_api import Schema
+from paimon_python_java import Catalog
 from paimon_python_java.java_gateway import get_gateway
-from paimon_python_java.tests.utils import create_simple_table
 from paimon_python_java.util import java_utils
 from py4j.protocol import Py4JJavaError
 
@@ -35,15 +37,23 @@ class TableWriteReadTest(unittest.TestCase):
     def setUpClass(cls):
         setuputils.setup_java_bridge()
         cls.warehouse = tempfile.mkdtemp()
+        cls.simple_pa_schema = pa.schema([
+            ('f0', pa.int32()),
+            ('f1', pa.string())
+        ])
+        cls.catalog = Catalog.create({'warehouse': cls.warehouse})
+        cls.catalog.create_database('default', False)
 
     @classmethod
     def tearDownClass(cls):
         setuputils.clean()
+        if os.path.exists(cls.warehouse):
+            shutil.rmtree(cls.warehouse)
 
     def testReadEmptyAppendTable(self):
-        create_simple_table(self.warehouse, 'default', 'empty_append_table', False)
-        catalog = Catalog.create({'warehouse': self.warehouse})
-        table = catalog.get_table('default.empty_append_table')
+        schema = Schema(self.simple_pa_schema)
+        self.catalog.create_table('default.empty_append_table', schema, False)
+        table = self.catalog.get_table('default.empty_append_table')
 
         # read data
         read_builder = table.new_read_builder()
@@ -53,7 +63,10 @@ def testReadEmptyAppendTable(self):
         self.assertTrue(len(splits) == 0)
 
     def testReadEmptyPkTable(self):
-        create_simple_table(self.warehouse, 'default', 'empty_pk_table', True)
+        schema = Schema(self.simple_pa_schema, primary_keys=['f0'], options={'bucket': '1'})
+        self.catalog.create_table('default.empty_pk_table', schema, False)
+
+        # use Java API to generate data
         gateway = get_gateway()
         j_catalog_context = java_utils.to_j_catalog_context({'warehouse': self.warehouse})
         j_catalog = gateway.jvm.CatalogFactory.createCatalog(j_catalog_context)
@@ -84,7 +97,7 @@ def testReadEmptyPkTable(self):
         table_commit.close()
 
         # read data
-        table = Table(j_table, {})
+        table = self.catalog.get_table('default.empty_pk_table')
         read_builder = table.new_read_builder()
         table_scan = read_builder.new_scan()
         table_read = read_builder.new_read()
@@ -98,19 +111,17 @@ def testReadEmptyPkTable(self):
         self.assertEqual(len(data_frames), 0)
 
     def testWriteReadAppendTable(self):
-        create_simple_table(self.warehouse, 'default', 'simple_append_table', False)
-
-        catalog = Catalog.create({'warehouse': self.warehouse})
-        table = catalog.get_table('default.simple_append_table')
+        schema = Schema(self.simple_pa_schema)
+        self.catalog.create_table('default.simple_append_table', schema, False)
+        table = self.catalog.get_table('default.simple_append_table')
 
         # prepare data
         data = {
             'f0': [1, 2, 3],
             'f1': ['a', 'b', 'c'],
         }
         df = pd.DataFrame(data)
-        df['f0'] = df['f0'].astype('int32')
-        record_batch = pa.RecordBatch.from_pandas(df)
+        record_batch = pa.RecordBatch.from_pandas(df, schema=self.simple_pa_schema)
 
         # write and commit data
         write_builder = table.new_batch_write_builder()
@@ -138,13 +149,15 @@ def testWriteReadAppendTable(self):
         result = pd.concat(data_frames)
 
         # check data (ignore index)
-        pd.testing.assert_frame_equal(result.reset_index(drop=True), df.reset_index(drop=True))
+        expected = df
+        expected['f0'] = df['f0'].astype('int32')
+        pd.testing.assert_frame_equal(
+            result.reset_index(drop=True), expected.reset_index(drop=True))
 
     def testWriteWrongSchema(self):
-        create_simple_table(self.warehouse, 'default', 'test_wrong_schema', False)
-
-        catalog = Catalog.create({'warehouse': self.warehouse})
-        table = catalog.get_table('default.test_wrong_schema')
+        schema = Schema(self.simple_pa_schema)
+        self.catalog.create_table('default.test_wrong_schema', schema, False)
+        table = self.catalog.get_table('default.test_wrong_schema')
 
         data = {
             'f0': [1, 2, 3],
@@ -155,7 +168,7 @@ def testWriteWrongSchema(self):
             ('f0', pa.int64()),
             ('f1', pa.string())
         ])
-        record_batch = pa.RecordBatch.from_pandas(df, schema)
+        record_batch = pa.RecordBatch.from_pandas(df, schema=schema)
 
         write_builder = table.new_batch_write_builder()
         table_write = write_builder.new_write()
@@ -169,16 +182,9 @@ def testWriteWrongSchema(self):
 \tInput schema is: [f0: Int(64, true), f1: Utf8]''')
 
     def testCannotWriteDynamicBucketTable(self):
-        create_simple_table(
-            self.warehouse,
-            'default',
-            'test_dynamic_bucket',
-            True,
-            {'bucket': '-1'}
-        )
-
-        catalog = Catalog.create({'warehouse': self.warehouse})
-        table = catalog.get_table('default.test_dynamic_bucket')
+        schema = Schema(self.simple_pa_schema, primary_keys=['f0'])
+        self.catalog.create_table('default.test_dynamic_bucket', schema, False)
+        table = self.catalog.get_table('default.test_dynamic_bucket')
 
         with self.assertRaises(TypeError) as e:
             table.new_batch_write_builder()
@@ -187,9 +193,9 @@ def testCannotWriteDynamicBucketTable(self):
             "Doesn't support writing dynamic bucket or cross partition table.")
 
     def testParallelRead(self):
-        create_simple_table(self.warehouse, 'default', 'test_parallel_read', False)
-
         catalog = Catalog.create({'warehouse': self.warehouse, 'max-workers': '2'})
+        schema = Schema(self.simple_pa_schema)
+        catalog.create_table('default.test_parallel_read', schema, False)
         table = catalog.get_table('default.test_parallel_read')
 
         # prepare data
@@ -207,8 +213,7 @@ def testParallelRead(self):
             expected_data['f1'].append(str(i * 2))
 
         df = pd.DataFrame(data)
-        df['f0'] = df['f0'].astype('int32')
-        record_batch = pa.RecordBatch.from_pandas(df)
+        record_batch = pa.RecordBatch.from_pandas(df, schema=self.simple_pa_schema)
 
         # write and commit data
         write_builder = table.new_batch_write_builder()
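
For context, these tests drive the batch write path through a BatchTableWrite/BatchTableCommit pair. Parts of that loop are elided in this view; the sketch below fills them in under the assumption that new_commit, prepare_commit and commit behave as their names suggest (only new_batch_write_builder, new_write, write_arrow_batch and table_commit.close() are visible in the diff):

    # Assumed shape of the elided write-and-commit loop; method names beyond
    # those visible above are inferred from the paimon_python_api classes.
    write_builder = table.new_batch_write_builder()
    table_write = write_builder.new_write()
    table_commit = write_builder.new_commit()

    table_write.write_arrow_batch(record_batch)     # stage pyarrow data
    commit_messages = table_write.prepare_commit()  # flush, collect messages
    table_commit.commit(commit_messages)            # publish one snapshot

    table_write.close()
    table_commit.close()
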