diff --git a/paimon_python_api/write_builder.py b/paimon_python_api/write_builder.py index caa23b9..7835179 100644 --- a/paimon_python_api/write_builder.py +++ b/paimon_python_api/write_builder.py @@ -18,16 +18,17 @@ from abc import ABC, abstractmethod from paimon_python_api import BatchTableCommit, BatchTableWrite +from typing import Optional class BatchWriteBuilder(ABC): """An interface for building the TableScan and TableRead.""" @abstractmethod - def with_overwrite(self, static_partition: dict) -> 'BatchWriteBuilder': + def overwrite(self, static_partition: Optional[dict] = None) -> 'BatchWriteBuilder': """ Overwrite writing, same as the 'INSERT OVERWRITE T PARTITION (...)' semantics of SQL. - If you pass an empty dict, it means OVERWRITE whole table. + If you pass None, it means OVERWRITE whole table. """ @abstractmethod diff --git a/paimon_python_java/pypaimon.py b/paimon_python_java/pypaimon.py index a007add..fcf0695 100644 --- a/paimon_python_java/pypaimon.py +++ b/paimon_python_java/pypaimon.py @@ -185,7 +185,9 @@ def __init__(self, j_batch_write_builder, j_row_type, arrow_schema: pa.Schema): self._j_row_type = j_row_type self._arrow_schema = arrow_schema - def with_overwrite(self, static_partition: dict) -> 'BatchWriteBuilder': + def overwrite(self, static_partition: Optional[dict] = None) -> 'BatchWriteBuilder': + if static_partition is None: + static_partition = {} self._j_batch_write_builder.withOverwrite(static_partition) return self diff --git a/paimon_python_java/tests/test_write_and_read.py b/paimon_python_java/tests/test_write_and_read.py index c4e78c9..4cf7548 100644 --- a/paimon_python_java/tests/test_write_and_read.py +++ b/paimon_python_java/tests/test_write_and_read.py @@ -297,3 +297,77 @@ def testAllWriteAndReadApi(self): actual = table_read.to_pandas(splits) pd.testing.assert_frame_equal( actual.reset_index(drop=True), expected.reset_index(drop=True)) + + def test_overwrite(self): + schema = Schema(self.simple_pa_schema, partition_keys=['f0'], + options={'dynamic-partition-overwrite': 'false'}) + self.catalog.create_table('default.test_overwrite', schema, False) + table = self.catalog.get_table('default.test_overwrite') + read_builder = table.new_read_builder() + + write_builder = table.new_batch_write_builder() + table_write = write_builder.new_write() + table_commit = write_builder.new_commit() + + df0 = pd.DataFrame({ + 'f0': [1, 1, 2, 2], + 'f1': ['apple', 'banana', 'dog', 'cat'], + }) + + table_write.write_pandas(df0) + table_commit.commit(table_write.prepare_commit()) + table_write.close() + table_commit.close() + + table_scan = read_builder.new_scan() + table_read = read_builder.new_read() + actual_df0 = table_read.to_pandas(table_scan.plan().splits()).sort_values(by='f0') + df0['f0'] = df0['f0'].astype('int32') + pd.testing.assert_frame_equal( + actual_df0.reset_index(drop=True), df0.reset_index(drop=True)) + + write_builder = table.new_batch_write_builder().overwrite({'f0': '1'}) + table_write = write_builder.new_write() + table_commit = write_builder.new_commit() + + df1 = pd.DataFrame({ + 'f0': [1], + 'f1': ['watermelon'], + }) + + table_write.write_pandas(df1) + table_commit.commit(table_write.prepare_commit()) + table_write.close() + table_commit.close() + + table_scan = read_builder.new_scan() + table_read = read_builder.new_read() + actual_df1 = table_read.to_pandas(table_scan.plan().splits()) + expected_df1 = pd.DataFrame({ + 'f0': [2, 2, 1], + 'f1': ['dog', 'cat', 'watermelon'] + }) + expected_df1['f0'] = expected_df1['f0'].astype('int32') + pd.testing.assert_frame_equal( + actual_df1.reset_index(drop=True), expected_df1.reset_index(drop=True)) + + write_builder = table.new_batch_write_builder().overwrite() + table_write = write_builder.new_write() + table_commit = write_builder.new_commit() + + df2 = pd.DataFrame({ + 'f0': [3], + 'f1': ['Neo'], + }) + + table_write.write_pandas(df2) + table_commit.commit(table_write.prepare_commit()) + table_write.close() + table_commit.close() + + table_scan = read_builder.new_scan() + table_read = read_builder.new_read() + actual_df2 = table_read.to_pandas(table_scan.plan().splits()) + df2['f0'] = df2['f0'].astype('int32') + pd.testing.assert_frame_equal( + actual_df2.reset_index(drop=True), df2.reset_index(drop=True))