Commit d69407c

Move writes to the transaction class (#571)
1 parent a892309 commit d69407c

2 files changed: +127 −61 lines changed

pyiceberg/table/__init__.py

Lines changed: 99 additions & 61 deletions
@@ -356,6 +356,100 @@ def update_snapshot(self, snapshot_properties: Dict[str, str] = EMPTY_DICT) -> U
         """
         return UpdateSnapshot(self, io=self._table.io, snapshot_properties=snapshot_properties)
 
+    def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT) -> None:
+        """
+        Shorthand API for appending a PyArrow table to a table transaction.
+
+        Args:
+            df: The Arrow dataframe that will be appended to overwrite the table
+            snapshot_properties: Custom properties to be added to the snapshot summary
+        """
+        try:
+            import pyarrow as pa
+        except ModuleNotFoundError as e:
+            raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e
+
+        if not isinstance(df, pa.Table):
+            raise ValueError(f"Expected PyArrow table, got: {df}")
+
+        if len(self._table.spec().fields) > 0:
+            raise ValueError("Cannot write to partitioned tables")
+
+        _check_schema_compatible(self._table.schema(), other_schema=df.schema)
+        # cast if the two schemas are compatible but not equal
+        table_arrow_schema = self._table.schema().as_arrow()
+        if table_arrow_schema != df.schema:
+            df = df.cast(table_arrow_schema)
+
+        with self.update_snapshot(snapshot_properties=snapshot_properties).fast_append() as update_snapshot:
+            # skip writing data files if the dataframe is empty
+            if df.shape[0] > 0:
+                data_files = _dataframe_to_data_files(
+                    table_metadata=self._table.metadata, write_uuid=update_snapshot.commit_uuid, df=df, io=self._table.io
+                )
+                for data_file in data_files:
+                    update_snapshot.append_data_file(data_file)
+
+    def overwrite(
+        self, df: pa.Table, overwrite_filter: BooleanExpression = ALWAYS_TRUE, snapshot_properties: Dict[str, str] = EMPTY_DICT
+    ) -> None:
+        """
+        Shorthand for adding a table overwrite with a PyArrow table to the transaction.
+
+        Args:
+            df: The Arrow dataframe that will be used to overwrite the table
+            overwrite_filter: ALWAYS_TRUE when you overwrite all the data,
+                or a boolean expression in case of a partial overwrite
+            snapshot_properties: Custom properties to be added to the snapshot summary
+        """
+        try:
+            import pyarrow as pa
+        except ModuleNotFoundError as e:
+            raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e
+
+        if not isinstance(df, pa.Table):
+            raise ValueError(f"Expected PyArrow table, got: {df}")
+
+        if overwrite_filter != AlwaysTrue():
+            raise NotImplementedError("Cannot overwrite a subset of a table")
+
+        if len(self._table.spec().fields) > 0:
+            raise ValueError("Cannot write to partitioned tables")
+
+        _check_schema_compatible(self._table.schema(), other_schema=df.schema)
+        # cast if the two schemas are compatible but not equal
+        table_arrow_schema = self._table.schema().as_arrow()
+        if table_arrow_schema != df.schema:
+            df = df.cast(table_arrow_schema)
+
+        with self.update_snapshot(snapshot_properties=snapshot_properties).overwrite() as update_snapshot:
+            # skip writing data files if the dataframe is empty
+            if df.shape[0] > 0:
+                data_files = _dataframe_to_data_files(
+                    table_metadata=self._table.metadata, write_uuid=update_snapshot.commit_uuid, df=df, io=self._table.io
+                )
+                for data_file in data_files:
+                    update_snapshot.append_data_file(data_file)
+
+    def add_files(self, file_paths: List[str]) -> None:
+        """
+        Shorthand API for adding files as data files to the table transaction.
+
+        Args:
+            file_paths: The list of full file paths to be added as data files to the table
+
+        Raises:
+            FileNotFoundError: If the file does not exist.
+        """
+        if self._table.name_mapping() is None:
+            self.set_properties(**{TableProperties.DEFAULT_NAME_MAPPING: self._table.schema().name_mapping.model_dump_json()})
+        with self.update_snapshot().fast_append() as update_snapshot:
+            data_files = _parquet_files_to_data_files(
+                table_metadata=self._table.metadata, file_paths=file_paths, io=self._table.io
+            )
+            for data_file in data_files:
+                update_snapshot.append_data_file(data_file)
+
     def update_spec(self) -> UpdateSpec:
         """Create a new UpdateSpec to update the partitioning of the table.
 
@@ -1219,32 +1313,8 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT)
             df: The Arrow dataframe that will be appended to overwrite the table
             snapshot_properties: Custom properties to be added to the snapshot summary
         """
-        try:
-            import pyarrow as pa
-        except ModuleNotFoundError as e:
-            raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e
-
-        if not isinstance(df, pa.Table):
-            raise ValueError(f"Expected PyArrow table, got: {df}")
-
-        if len(self.spec().fields) > 0:
-            raise ValueError("Cannot write to partitioned tables")
-
-        _check_schema_compatible(self.schema(), other_schema=df.schema)
-        # cast if the two schemas are compatible but not equal
-        table_arrow_schema = self.schema().as_arrow()
-        if table_arrow_schema != df.schema:
-            df = df.cast(table_arrow_schema)
-
-        with self.transaction() as txn:
-            with txn.update_snapshot(snapshot_properties=snapshot_properties).fast_append() as update_snapshot:
-                # skip writing data files if the dataframe is empty
-                if df.shape[0] > 0:
-                    data_files = _dataframe_to_data_files(
-                        table_metadata=self.metadata, write_uuid=update_snapshot.commit_uuid, df=df, io=self.io
-                    )
-                    for data_file in data_files:
-                        update_snapshot.append_data_file(data_file)
+        with self.transaction() as tx:
+            tx.append(df=df, snapshot_properties=snapshot_properties)
 
     def overwrite(
         self, df: pa.Table, overwrite_filter: BooleanExpression = ALWAYS_TRUE, snapshot_properties: Dict[str, str] = EMPTY_DICT
@@ -1258,35 +1328,8 @@ def overwrite(
                 or a boolean expression in case of a partial overwrite
             snapshot_properties: Custom properties to be added to the snapshot summary
         """
-        try:
-            import pyarrow as pa
-        except ModuleNotFoundError as e:
-            raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e
-
-        if not isinstance(df, pa.Table):
-            raise ValueError(f"Expected PyArrow table, got: {df}")
-
-        if overwrite_filter != AlwaysTrue():
-            raise NotImplementedError("Cannot overwrite a subset of a table")
-
-        if len(self.spec().fields) > 0:
-            raise ValueError("Cannot write to partitioned tables")
-
-        _check_schema_compatible(self.schema(), other_schema=df.schema)
-        # cast if the two schemas are compatible but not equal
-        table_arrow_schema = self.schema().as_arrow()
-        if table_arrow_schema != df.schema:
-            df = df.cast(table_arrow_schema)
-
-        with self.transaction() as txn:
-            with txn.update_snapshot(snapshot_properties=snapshot_properties).overwrite() as update_snapshot:
-                # skip writing data files if the dataframe is empty
-                if df.shape[0] > 0:
-                    data_files = _dataframe_to_data_files(
-                        table_metadata=self.metadata, write_uuid=update_snapshot.commit_uuid, df=df, io=self.io
-                    )
-                    for data_file in data_files:
-                        update_snapshot.append_data_file(data_file)
+        with self.transaction() as tx:
+            tx.overwrite(df=df, overwrite_filter=overwrite_filter, snapshot_properties=snapshot_properties)
 
     def add_files(self, file_paths: List[str]) -> None:
         """
@@ -1299,12 +1342,7 @@ def add_files(self, file_paths: List[str]) -> None:
             FileNotFoundError: If the file does not exist.
         """
         with self.transaction() as tx:
-            if self.name_mapping() is None:
-                tx.set_properties(**{TableProperties.DEFAULT_NAME_MAPPING: self.schema().name_mapping.model_dump_json()})
-            with tx.update_snapshot().fast_append() as update_snapshot:
-                data_files = _parquet_files_to_data_files(table_metadata=self.metadata, file_paths=file_paths, io=self.io)
-                for data_file in data_files:
-                    update_snapshot.append_data_file(data_file)
+            tx.add_files(file_paths=file_paths)
 
     def update_spec(self, case_sensitive: bool = True) -> UpdateSpec:
         return UpdateSpec(Transaction(self, autocommit=True), case_sensitive=case_sensitive)
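
With this change, appends, overwrites, and file additions can be staged on an open Transaction and committed together with any other pending updates, while the Table-level methods become thin wrappers that open a transaction and delegate. A minimal usage sketch, not part of the commit itself; the catalog name "default", the table "default.events", and its schema are assumptions for illustration, and the table is assumed to exist and be unpartitioned:

import pyarrow as pa

from pyiceberg.catalog import load_catalog

# Hypothetical catalog and table; the dataframe schema is assumed to
# match the table's schema.
catalog = load_catalog("default")
tbl = catalog.load_table("default.events")

df = pa.Table.from_pylist([{"id": 1, "name": "a"}])

# Both updates are staged on the same transaction and committed as a
# single metadata change when the block exits.
with tbl.transaction() as tx:
    tx.set_properties({"written-by": "example"})
    tx.append(df)

# The Table-level shorthand still works; it now opens a transaction
# internally and delegates to Transaction.append.
tbl.append(df)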

tests/integration/test_writes.py

Lines changed: 28 additions & 0 deletions
@@ -832,3 +832,31 @@ def test_inspect_snapshots(
                continue
 
        assert left == right, f"Difference in column {column}: {left} != {right}"
+
+
+@pytest.mark.integration
+def test_write_within_transaction(spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None:
+    identifier = "default.write_in_open_transaction"
+    tbl = _create_table(session_catalog, identifier, {"format-version": "1"}, [])
+
+    def get_metadata_entries_count(identifier: str) -> int:
+        return spark.sql(
+            f"""
+            SELECT *
+            FROM {identifier}.metadata_log_entries
+            """
+        ).count()
+
+    # one metadata entry from table creation
+    assert get_metadata_entries_count(identifier) == 1
+
+    # one more metadata entry from transaction
+    with tbl.transaction() as tx:
+        tx.set_properties({"test": "1"})
+        tx.append(arrow_table_with_null)
+    assert get_metadata_entries_count(identifier) == 2
+
+    # two more metadata entries added from two separate transactions
+    tbl.transaction().set_properties({"test": "2"}).commit_transaction()
+    tbl.append(arrow_table_with_null)
+    assert get_metadata_entries_count(identifier) == 4
