Skip to content

Commit cbb8cec

Browse files
committed
add merge_append
1 parent f0fc260 commit cbb8cec

File tree

2 files changed

+59
-10
lines changed

2 files changed

+59
-10
lines changed

pyiceberg/table/__init__.py

+49
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,44 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT)
428428
for data_file in data_files:
429429
update_snapshot.append_data_file(data_file)
430430

431+
def merge_append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT) -> None:
    """
    Shorthand API for appending a PyArrow table to a table transaction.

    Commits through ``update_snapshot(...).merge_append()``, which — unlike the
    plain ``append`` fast-append path — may compact existing small manifests
    together with the new one, subject to the table's manifest-merge properties.

    Args:
        df: The Arrow dataframe that will be appended to the table
        snapshot_properties: Custom properties to be added to the snapshot summary

    Raises:
        ModuleNotFoundError: If PyArrow is not installed.
        ValueError: If ``df`` is not a PyArrow table, or if the table's
            partition spec contains a transform that cannot be written
            with PyArrow.
    """
    # Deferred import: PyArrow is an optional dependency, only needed for writes.
    try:
        import pyarrow as pa
    except ModuleNotFoundError as e:
        raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e

    if not isinstance(df, pa.Table):
        raise ValueError(f"Expected PyArrow table, got: {df}")

    # Reject up front any partition transform PyArrow cannot evaluate;
    # otherwise partition values could not be derived for the data files.
    if unsupported_partitions := [
        field for field in self.table_metadata.spec().fields if not field.transform.supports_pyarrow_transform
    ]:
        raise ValueError(
            f"Not all partition types are supported for writes. Following partitions cannot be written using pyarrow: {unsupported_partitions}."
        )

    _check_schema_compatible(self._table.schema(), other_schema=df.schema)
    # cast if the two schemas are compatible but not equal
    table_arrow_schema = self._table.schema().as_arrow()
    if table_arrow_schema != df.schema:
        df = df.cast(table_arrow_schema)

    with self.update_snapshot(snapshot_properties=snapshot_properties).merge_append() as update_snapshot:
        # skip writing data files if the dataframe is empty
        if df.shape[0] > 0:
            data_files = _dataframe_to_data_files(
                table_metadata=self._table.metadata, write_uuid=update_snapshot.commit_uuid, df=df,
                io=self._table.io
            )
            for data_file in data_files:
                update_snapshot.append_data_file(data_file)
431469
def overwrite(
432470
self, df: pa.Table, overwrite_filter: BooleanExpression = ALWAYS_TRUE, snapshot_properties: Dict[str, str] = EMPTY_DICT
433471
) -> None:
@@ -1352,6 +1390,17 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT)
13521390
with self.transaction() as tx:
13531391
tx.append(df=df, snapshot_properties=snapshot_properties)
13541392

1393+
def merge_append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT) -> None:
    """
    Shorthand API for appending a PyArrow table to the table.

    Opens a transaction and delegates to ``Transaction.merge_append``, so the
    whole append is committed as a single snapshot whose manifests may be
    merged according to the table's manifest-merge properties.

    Args:
        df: The Arrow dataframe that will be appended to the table
        snapshot_properties: Custom properties to be added to the snapshot summary
    """
    with self.transaction() as tx:
        tx.merge_append(df=df, snapshot_properties=snapshot_properties)
1403+
13551404
def overwrite(
13561405
self, df: pa.Table, overwrite_filter: BooleanExpression = ALWAYS_TRUE, snapshot_properties: Dict[str, str] = EMPTY_DICT
13571406
) -> None:

tests/integration/test_writes/test_writes.py

+10-10
Original file line numberDiff line numberDiff line change
@@ -876,7 +876,7 @@ def table_write_subset_of_schema(session_catalog: Catalog, arrow_table_with_null
876876
@pytest.mark.integration
877877
@pytest.mark.parametrize("format_version", [1, 2])
878878
def test_merge_manifest_min_count_to_merge(
879-
spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int
879+
session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int
880880
) -> None:
881881
tbl_a = _create_table(
882882
session_catalog,
@@ -898,19 +898,19 @@ def test_merge_manifest_min_count_to_merge(
898898
)
899899

900900
# tbl_a should merge all manifests into 1
901-
tbl_a.append(arrow_table_with_null)
902-
tbl_a.append(arrow_table_with_null)
903-
tbl_a.append(arrow_table_with_null)
901+
tbl_a.merge_append(arrow_table_with_null)
902+
tbl_a.merge_append(arrow_table_with_null)
903+
tbl_a.merge_append(arrow_table_with_null)
904904

905905
# tbl_b should not merge any manifests because the target size is too small
906-
tbl_b.append(arrow_table_with_null)
907-
tbl_b.append(arrow_table_with_null)
908-
tbl_b.append(arrow_table_with_null)
906+
tbl_b.merge_append(arrow_table_with_null)
907+
tbl_b.merge_append(arrow_table_with_null)
908+
tbl_b.merge_append(arrow_table_with_null)
909909

910910
# tbl_c should not merge any manifests because merging is disabled
911-
tbl_c.append(arrow_table_with_null)
912-
tbl_c.append(arrow_table_with_null)
913-
tbl_c.append(arrow_table_with_null)
911+
tbl_c.merge_append(arrow_table_with_null)
912+
tbl_c.merge_append(arrow_table_with_null)
913+
tbl_c.merge_append(arrow_table_with_null)
914914

915915
assert len(tbl_a.current_snapshot().manifests(tbl_a.io)) == 1 # type: ignore
916916
assert len(tbl_b.current_snapshot().manifests(tbl_b.io)) == 3 # type: ignore

0 commit comments

Comments
 (0)