From 7e49b6671e2299d177b967d7319873921d9e6844 Mon Sep 17 00:00:00 2001 From: yuzelin <33053040+yuzelin@users.noreply.github.com> Date: Wed, 9 Oct 2024 11:35:07 +0800 Subject: [PATCH] Catalog supports to create table from pyarrow schema (#12) --- paimon_python_api/__init__.py | 3 +- paimon_python_api/catalog.py | 11 +- paimon_python_api/table.py | 19 ++++ paimon_python_java/pypaimon.py | 17 ++- paimon_python_java/tests/test_data_types.py | 106 ++++++++++++++++++ .../tests/test_write_and_read.py | 71 ++++++------ paimon_python_java/tests/utils.py | 47 -------- paimon_python_java/util/java_utils.py | 60 ++++++++++ 8 files changed, 248 insertions(+), 86 deletions(-) create mode 100644 paimon_python_java/tests/test_data_types.py delete mode 100644 paimon_python_java/tests/utils.py diff --git a/paimon_python_api/__init__.py b/paimon_python_api/__init__.py index b3842d0..86090c9 100644 --- a/paimon_python_api/__init__.py +++ b/paimon_python_api/__init__.py @@ -24,7 +24,7 @@ from .table_commit import BatchTableCommit from .table_write import BatchTableWrite from .write_builder import BatchWriteBuilder -from .table import Table +from .table import Table, Schema from .catalog import Catalog __all__ = [ @@ -38,5 +38,6 @@ 'BatchTableWrite', 'BatchWriteBuilder', 'Table', + 'Schema', 'Catalog' ] diff --git a/paimon_python_api/catalog.py b/paimon_python_api/catalog.py index 412e9f0..a4a863f 100644 --- a/paimon_python_api/catalog.py +++ b/paimon_python_api/catalog.py @@ -17,7 +17,8 @@ ################################################################################# from abc import ABC, abstractmethod -from paimon_python_api import Table +from typing import Optional +from paimon_python_api import Table, Schema class Catalog(ABC): @@ -34,3 +35,11 @@ def create(catalog_options: dict) -> 'Catalog': @abstractmethod def get_table(self, identifier: str) -> Table: """Get paimon table identified by the given Identifier.""" + + @abstractmethod + def create_database(self, name: str, ignore_if_exists: bool, properties: Optional[dict] = None): + """Create a database with properties.""" + + @abstractmethod + def create_table(self, identifier: str, schema: Schema, ignore_if_exists: bool): + """Create table.""" diff --git a/paimon_python_api/table.py b/paimon_python_api/table.py index 35b81ac..0170cb1 100644 --- a/paimon_python_api/table.py +++ b/paimon_python_api/table.py @@ -16,8 +16,11 @@ # limitations under the License. ################################################################################# +import pyarrow as pa + from abc import ABC, abstractmethod from paimon_python_api import ReadBuilder, BatchWriteBuilder +from typing import Optional, List class Table(ABC): @@ -30,3 +33,19 @@ def new_read_builder(self) -> ReadBuilder: @abstractmethod def new_batch_write_builder(self) -> BatchWriteBuilder: """Returns a builder for building batch table write and table commit.""" + + +class Schema: + """Schema of a table.""" + + def __init__(self, + pa_schema: pa.Schema, + partition_keys: Optional[List[str]] = None, + primary_keys: Optional[List[str]] = None, + options: Optional[dict] = None, + comment: Optional[str] = None): + self.pa_schema = pa_schema + self.partition_keys = partition_keys + self.primary_keys = primary_keys + self.options = options + self.comment = comment diff --git a/paimon_python_java/pypaimon.py b/paimon_python_java/pypaimon.py index ee43a17..263491e 100644 --- a/paimon_python_java/pypaimon.py +++ b/paimon_python_java/pypaimon.py @@ -21,8 +21,8 @@ from paimon_python_java.java_gateway import get_gateway from paimon_python_java.util import java_utils, constants from paimon_python_api import (catalog, table, read_builder, table_scan, split, table_read, - write_builder, table_write, commit_message, table_commit) -from typing import List, Iterator + write_builder, table_write, commit_message, table_commit, Schema) +from typing import List, Iterator, Optional class Catalog(catalog.Catalog): @@ -39,11 +39,20 @@ def create(catalog_options: dict) -> 'Catalog': return Catalog(j_catalog, catalog_options) def get_table(self, identifier: str) -> 'Table': - gateway = get_gateway() - j_identifier = gateway.jvm.Identifier.fromString(identifier) + j_identifier = java_utils.to_j_identifier(identifier) j_table = self._j_catalog.getTable(j_identifier) return Table(j_table, self._catalog_options) + def create_database(self, name: str, ignore_if_exists: bool, properties: Optional[dict] = None): + if properties is None: + properties = {} + self._j_catalog.createDatabase(name, ignore_if_exists, properties) + + def create_table(self, identifier: str, schema: Schema, ignore_if_exists: bool): + j_identifier = java_utils.to_j_identifier(identifier) + j_schema = java_utils.to_paimon_schema(schema) + self._j_catalog.createTable(j_identifier, j_schema, ignore_if_exists) + class Table(table.Table): diff --git a/paimon_python_java/tests/test_data_types.py b/paimon_python_java/tests/test_data_types.py new file mode 100644 index 0000000..72a4587 --- /dev/null +++ b/paimon_python_java/tests/test_data_types.py @@ -0,0 +1,106 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +import os +import random +import shutil +import string +import tempfile +import pyarrow as pa +import unittest + + +from paimon_python_api import Schema +from paimon_python_java import Catalog +from paimon_python_java.util import java_utils +from setup_utils import java_setuputils + + +class DataTypesTest(unittest.TestCase): + + @classmethod + def setUpClass(cls): + java_setuputils.setup_java_bridge() + cls.warehouse = tempfile.mkdtemp() + cls.simple_pa_schema = pa.schema([ + ('f0', pa.int32()), + ('f1', pa.string()) + ]) + cls.catalog = Catalog.create({'warehouse': cls.warehouse}) + cls.catalog.create_database('default', False) + + @classmethod + def tearDownClass(cls): + java_setuputils.clean() + if os.path.exists(cls.warehouse): + shutil.rmtree(cls.warehouse) + + def test_int(self): + pa_schema = pa.schema([ + ('_int8', pa.int8()), + ('_int16', pa.int16()), + ('_int32', pa.int32()), + ('_int64', pa.int64()) + ]) + expected_types = ['TINYINT', 'SMALLINT', 'INT', 'BIGINT'] + self._test_impl(pa_schema, expected_types) + + def test_float(self): + pa_schema = pa.schema([ + ('_float16', pa.float16()), + ('_float32', pa.float32()), + ('_float64', pa.float64()) + ]) + expected_types = ['FLOAT', 'FLOAT', 'DOUBLE'] + self._test_impl(pa_schema, expected_types) + + def test_string(self): + pa_schema = pa.schema([ + ('_string', pa.string()), + ('_utf8', pa.utf8()) + ]) + expected_types = ['STRING', 'STRING'] + self._test_impl(pa_schema, expected_types) + + def test_bool(self): + pa_schema = pa.schema([('_bool', pa.bool_())]) + expected_types = ['BOOLEAN'] + self._test_impl(pa_schema, expected_types) + + def test_null(self): + pa_schema = pa.schema([('_null', pa.null())]) + expected_types = ['STRING'] + self._test_impl(pa_schema, expected_types) + + def test_unsupported_type(self): + pa_schema = pa.schema([('_array', pa.list_(pa.int32()))]) + schema = Schema(pa_schema) + with self.assertRaises(ValueError) as e: + java_utils.to_paimon_schema(schema) + self.assertEqual( + str(e.exception), 'Found unsupported data type list for field _array.') + + def _test_impl(self, pa_schema, expected_types): + scheme = Schema(pa_schema) + letters = string.ascii_letters + identifier = 'default.' + ''.join(random.choice(letters) for _ in range(10)) + self.catalog.create_table(identifier, scheme, False) + table = self.catalog.get_table(identifier) + field_types = table._j_table.rowType().getFieldTypes() + actual_types = list(map(lambda t: t.toString(), field_types)) + self.assertListEqual(actual_types, expected_types) diff --git a/paimon_python_java/tests/test_write_and_read.py b/paimon_python_java/tests/test_write_and_read.py index 19c9ea8..cd2bd0e 100644 --- a/paimon_python_java/tests/test_write_and_read.py +++ b/paimon_python_java/tests/test_write_and_read.py @@ -16,15 +16,17 @@ # limitations under the License. ################################################################################ +import os +import shutil import tempfile import unittest import pandas as pd import pyarrow as pa import setup_utils.java_setuputils as setuputils -from paimon_python_java import Catalog, Table +from paimon_python_api import Schema +from paimon_python_java import Catalog from paimon_python_java.java_gateway import get_gateway -from paimon_python_java.tests.utils import create_simple_table from paimon_python_java.util import java_utils from py4j.protocol import Py4JJavaError @@ -35,15 +37,23 @@ class TableWriteReadTest(unittest.TestCase): def setUpClass(cls): setuputils.setup_java_bridge() cls.warehouse = tempfile.mkdtemp() + cls.simple_pa_schema = pa.schema([ + ('f0', pa.int32()), + ('f1', pa.string()) + ]) + cls.catalog = Catalog.create({'warehouse': cls.warehouse}) + cls.catalog.create_database('default', False) @classmethod def tearDownClass(cls): setuputils.clean() + if os.path.exists(cls.warehouse): + shutil.rmtree(cls.warehouse) def testReadEmptyAppendTable(self): - create_simple_table(self.warehouse, 'default', 'empty_append_table', False) - catalog = Catalog.create({'warehouse': self.warehouse}) - table = catalog.get_table('default.empty_append_table') + schema = Schema(self.simple_pa_schema) + self.catalog.create_table('default.empty_append_table', schema, False) + table = self.catalog.get_table('default.empty_append_table') # read data read_builder = table.new_read_builder() @@ -53,7 +63,10 @@ def testReadEmptyAppendTable(self): self.assertTrue(len(splits) == 0) def testReadEmptyPkTable(self): - create_simple_table(self.warehouse, 'default', 'empty_pk_table', True) + schema = Schema(self.simple_pa_schema, primary_keys=['f0'], options={'bucket': '1'}) + self.catalog.create_table('default.empty_pk_table', schema, False) + + # use Java API to generate data gateway = get_gateway() j_catalog_context = java_utils.to_j_catalog_context({'warehouse': self.warehouse}) j_catalog = gateway.jvm.CatalogFactory.createCatalog(j_catalog_context) @@ -84,7 +97,7 @@ def testReadEmptyPkTable(self): table_commit.close() # read data - table = Table(j_table, {}) + table = self.catalog.get_table('default.empty_pk_table') read_builder = table.new_read_builder() table_scan = read_builder.new_scan() table_read = read_builder.new_read() @@ -98,10 +111,9 @@ def testReadEmptyPkTable(self): self.assertEqual(len(data_frames), 0) def testWriteReadAppendTable(self): - create_simple_table(self.warehouse, 'default', 'simple_append_table', False) - - catalog = Catalog.create({'warehouse': self.warehouse}) - table = catalog.get_table('default.simple_append_table') + schema = Schema(self.simple_pa_schema) + self.catalog.create_table('default.simple_append_table', schema, False) + table = self.catalog.get_table('default.simple_append_table') # prepare data data = { @@ -109,8 +121,7 @@ def testWriteReadAppendTable(self): 'f1': ['a', 'b', 'c'], } df = pd.DataFrame(data) - df['f0'] = df['f0'].astype('int32') - record_batch = pa.RecordBatch.from_pandas(df) + record_batch = pa.RecordBatch.from_pandas(df, schema=self.simple_pa_schema) # write and commit data write_builder = table.new_batch_write_builder() @@ -138,13 +149,15 @@ def testWriteReadAppendTable(self): result = pd.concat(data_frames) # check data (ignore index) - pd.testing.assert_frame_equal(result.reset_index(drop=True), df.reset_index(drop=True)) + expected = df + expected['f0'] = df['f0'].astype('int32') + pd.testing.assert_frame_equal( + result.reset_index(drop=True), expected.reset_index(drop=True)) def testWriteWrongSchema(self): - create_simple_table(self.warehouse, 'default', 'test_wrong_schema', False) - - catalog = Catalog.create({'warehouse': self.warehouse}) - table = catalog.get_table('default.test_wrong_schema') + schema = Schema(self.simple_pa_schema) + self.catalog.create_table('default.test_wrong_schema', schema, False) + table = self.catalog.get_table('default.test_wrong_schema') data = { 'f0': [1, 2, 3], @@ -155,7 +168,7 @@ def testWriteWrongSchema(self): ('f0', pa.int64()), ('f1', pa.string()) ]) - record_batch = pa.RecordBatch.from_pandas(df, schema) + record_batch = pa.RecordBatch.from_pandas(df, schema=schema) write_builder = table.new_batch_write_builder() table_write = write_builder.new_write() @@ -169,16 +182,9 @@ def testWriteWrongSchema(self): \tInput schema is: [f0: Int(64, true), f1: Utf8]''') def testCannotWriteDynamicBucketTable(self): - create_simple_table( - self.warehouse, - 'default', - 'test_dynamic_bucket', - True, - {'bucket': '-1'} - ) - - catalog = Catalog.create({'warehouse': self.warehouse}) - table = catalog.get_table('default.test_dynamic_bucket') + schema = Schema(self.simple_pa_schema, primary_keys=['f0']) + self.catalog.create_table('default.test_dynamic_bucket', schema, False) + table = self.catalog.get_table('default.test_dynamic_bucket') with self.assertRaises(TypeError) as e: table.new_batch_write_builder() @@ -187,9 +193,9 @@ def testCannotWriteDynamicBucketTable(self): "Doesn't support writing dynamic bucket or cross partition table.") def testParallelRead(self): - create_simple_table(self.warehouse, 'default', 'test_parallel_read', False) - catalog = Catalog.create({'warehouse': self.warehouse, 'max-workers': '2'}) + schema = Schema(self.simple_pa_schema) + catalog.create_table('default.test_parallel_read', schema, False) table = catalog.get_table('default.test_parallel_read') # prepare data @@ -207,8 +213,7 @@ def testParallelRead(self): expected_data['f1'].append(str(i * 2)) df = pd.DataFrame(data) - df['f0'] = df['f0'].astype('int32') - record_batch = pa.RecordBatch.from_pandas(df) + record_batch = pa.RecordBatch.from_pandas(df, schema=self.simple_pa_schema) # write and commit data write_builder = table.new_batch_write_builder() diff --git a/paimon_python_java/tests/utils.py b/paimon_python_java/tests/utils.py deleted file mode 100644 index e0a79d5..0000000 --- a/paimon_python_java/tests/utils.py +++ /dev/null @@ -1,47 +0,0 @@ -################################################################################ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -################################################################################ - -from paimon_python_java.java_gateway import get_gateway -from paimon_python_java.util import java_utils - - -def create_simple_table(warehouse, database, table_name, has_pk, options=None): - if options is None: - options = { - 'bucket': '1', - 'bucket-key': 'f0' - } - - gateway = get_gateway() - - j_catalog_context = java_utils.to_j_catalog_context({'warehouse': warehouse}) - j_catalog = gateway.jvm.CatalogFactory.createCatalog(j_catalog_context) - - j_schema_builder = ( - gateway.jvm.Schema.newBuilder() - .column('f0', gateway.jvm.DataTypes.INT()) - .column('f1', gateway.jvm.DataTypes.STRING()) - .options(options) - ) - if has_pk: - j_schema_builder.primaryKey(['f0']) - j_schema = j_schema_builder.build() - - j_catalog.createDatabase(database, True) - j_identifier = gateway.jvm.Identifier(database, table_name) - j_catalog.createTable(j_identifier, j_schema, False) diff --git a/paimon_python_java/util/java_utils.py b/paimon_python_java/util/java_utils.py index 6ce6ede..8c4f276 100644 --- a/paimon_python_java/util/java_utils.py +++ b/paimon_python_java/util/java_utils.py @@ -16,6 +16,9 @@ # limitations under the License. ################################################################################ +import pyarrow as pa + +from paimon_python_api import Schema from paimon_python_java.java_gateway import get_gateway @@ -25,9 +28,66 @@ def to_j_catalog_context(catalog_options: dict): return gateway.jvm.CatalogContext.create(j_options) +def to_j_identifier(identifier: str): + return get_gateway().jvm.Identifier.fromString(identifier) + + +def to_paimon_schema(schema: Schema): + j_schema_builder = get_gateway().jvm.Schema.newBuilder() + + if schema.partition_keys is not None: + j_schema_builder.partitionKeys(schema.partition_keys) + + if schema.primary_keys is not None: + j_schema_builder.primaryKey(schema.primary_keys) + + if schema.options is not None: + j_schema_builder.options(schema.options) + + j_schema_builder.comment(schema.comment) + + for field in schema.pa_schema: + column_name = field.name + column_type = _to_j_type(column_name, field.type) + j_schema_builder.column(column_name, column_type) + return j_schema_builder.build() + + def check_batch_write(j_table): gateway = get_gateway() bucket_mode = j_table.bucketMode() if bucket_mode == gateway.jvm.BucketMode.HASH_DYNAMIC \ or bucket_mode == gateway.jvm.BucketMode.CROSS_PARTITION: raise TypeError("Doesn't support writing dynamic bucket or cross partition table.") + + +def _to_j_type(name, pa_type): + jvm = get_gateway().jvm + # int + if pa.types.is_int8(pa_type): + return jvm.DataTypes.TINYINT() + elif pa.types.is_int16(pa_type): + return jvm.DataTypes.SMALLINT() + elif pa.types.is_int32(pa_type): + return jvm.DataTypes.INT() + elif pa.types.is_int64(pa_type): + return jvm.DataTypes.BIGINT() + # float + elif pa.types.is_float16(pa_type) or pa.types.is_float32(pa_type): + return jvm.DataTypes.FLOAT() + elif pa.types.is_float64(pa_type): + return jvm.DataTypes.DOUBLE() + # string + elif pa.types.is_string(pa_type): + return jvm.DataTypes.STRING() + # bool + elif pa.types.is_boolean(pa_type): + return jvm.DataTypes.BOOLEAN() + elif pa.types.is_null(pa_type): + print(f"WARN: The type of column '{name}' is null, " + "and it will be converted to string type by default. " + "Please check if the original type is string. " + f"If not, please manually specify the type of '{name}'.") + return jvm.DataTypes.STRING() + else: + raise ValueError(f'Found unsupported data type {str(pa_type)} for field {name}.')