Skip to content

Commit

Permalink
Fix overwrite without partition (#18)
Browse files Browse the repository at this point in the history
  • Loading branch information
yuzelin authored Oct 9, 2024
1 parent 7703bcd commit 9fee9b6
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 3 deletions.
5 changes: 3 additions & 2 deletions paimon_python_api/write_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,17 @@

from abc import ABC, abstractmethod
from paimon_python_api import BatchTableCommit, BatchTableWrite
from typing import Optional


class BatchWriteBuilder(ABC):
"""An interface for building the BatchTableWrite and BatchTableCommit."""

@abstractmethod
def with_overwrite(self, static_partition: dict) -> 'BatchWriteBuilder':
def overwrite(self, static_partition: Optional[dict] = None) -> 'BatchWriteBuilder':
    """
    Switch this builder to overwrite mode, following the
    'INSERT OVERWRITE T PARTITION (...)' semantics of SQL.

    :param static_partition: partition spec to overwrite, mapping partition
        column names to values. None (the default) or an empty dict means
        OVERWRITE the whole table.
    :return: a BatchWriteBuilder, so that calls can be chained.
    """

@abstractmethod
Expand Down
4 changes: 3 additions & 1 deletion paimon_python_java/pypaimon.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,9 @@ def __init__(self, j_batch_write_builder, j_row_type, arrow_schema: pa.Schema):
self._j_row_type = j_row_type
self._arrow_schema = arrow_schema

def with_overwrite(self, static_partition: dict) -> 'BatchWriteBuilder':
def overwrite(self, static_partition: Optional[dict] = None) -> 'BatchWriteBuilder':
    """
    Enable overwrite mode on the wrapped Java batch write builder.

    :param static_partition: partition spec to overwrite; None is forwarded
        to the Java builder as an empty dict.
    :return: this builder, so that calls can be chained.
    """
    # The Java side is always handed a dict, never None.
    partition_spec = {} if static_partition is None else static_partition
    self._j_batch_write_builder.withOverwrite(partition_spec)
    return self

Expand Down
74 changes: 74 additions & 0 deletions paimon_python_java/tests/test_write_and_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,3 +297,77 @@ def testAllWriteAndReadApi(self):
actual = table_read.to_pandas(splits)
pd.testing.assert_frame_equal(
actual.reset_index(drop=True), expected.reset_index(drop=True))

def test_overwrite(self):
    """
    Exercise BatchWriteBuilder.overwrite() on a table partitioned by 'f0'.

    Three commits are made and the full table content is checked after each:
      1. a plain write of rows in partitions f0=1 and f0=2;
      2. overwrite({'f0': '1'}), expected to replace only partition f0=1;
      3. overwrite() without arguments, expected to replace the whole table.

    The table is created with 'dynamic-partition-overwrite' set to 'false'.
    """
    schema = Schema(self.simple_pa_schema, partition_keys=['f0'],
                    options={'dynamic-partition-overwrite': 'false'})
    self.catalog.create_table('default.test_overwrite', schema, False)
    table = self.catalog.get_table('default.test_overwrite')
    read_builder = table.new_read_builder()

    def commit_frame(write_builder, frame):
        """Write and commit *frame*, closing both handles even on failure."""
        table_write = write_builder.new_write()
        table_commit = write_builder.new_commit()
        try:
            table_write.write_pandas(frame)
            table_commit.commit(table_write.prepare_commit())
        finally:
            try:
                table_write.close()
            finally:
                table_commit.close()

    def normalize(frame):
        """Sort by every column so comparisons ignore the read order."""
        return frame.sort_values(by=['f0', 'f1']).reset_index(drop=True)

    def read_all():
        """Read the whole table with a fresh scan and return it normalized."""
        table_scan = read_builder.new_scan()
        table_read = read_builder.new_read()
        return normalize(table_read.to_pandas(table_scan.plan().splits()))

    def expected(rows):
        """Build the expected frame; only the expected side is cast to
        int32, so the dtype of the data read back is still verified."""
        frame = pd.DataFrame(rows)
        frame['f0'] = frame['f0'].astype('int32')
        return normalize(frame)

    # 1. Plain write into two partitions.
    rows0 = {
        'f0': [1, 1, 2, 2],
        'f1': ['apple', 'banana', 'dog', 'cat'],
    }
    commit_frame(table.new_batch_write_builder(), pd.DataFrame(rows0))
    pd.testing.assert_frame_equal(read_all(), expected(rows0))

    # 2. Static-partition overwrite: only f0=1 is replaced, f0=2 survives.
    rows1 = {
        'f0': [1],
        'f1': ['watermelon'],
    }
    commit_frame(table.new_batch_write_builder().overwrite({'f0': '1'}),
                 pd.DataFrame(rows1))
    pd.testing.assert_frame_equal(
        read_all(),
        expected({
            'f0': [2, 2, 1],
            'f1': ['dog', 'cat', 'watermelon'],
        }))

    # 3. Overwrite without a partition: the whole table is replaced.
    rows2 = {
        'f0': [3],
        'f1': ['Neo'],
    }
    commit_frame(table.new_batch_write_builder().overwrite(),
                 pd.DataFrame(rows2))
    pd.testing.assert_frame_equal(read_all(), expected(rows2))

0 comments on commit 9fee9b6

Please sign in to comment.