Move writes to Transaction #571
Merged · 2 commits · Apr 4, 2024
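This PR moves the write path onto `Transaction`, so `append`, `overwrite`, and `add_files` can be issued inside an open transaction and committed together with other updates. Below is a minimal usage sketch of the new transaction-level API; the catalog name, table identifier, and column are illustrative assumptions, and the table is assumed to be unpartitioned with a schema compatible with `df`:

```python
import pyarrow as pa
from pyiceberg.catalog import load_catalog

# Illustrative setup: the catalog name and table identifier are assumptions.
catalog = load_catalog("default")
tbl = catalog.load_table("default.events")

df = pa.table({"id": pa.array([1, 2, 3], type=pa.int64())})

# The property update and the append share one transaction, so they are
# committed together and add a single new metadata entry.
with tbl.transaction() as tx:
    tx.set_properties({"written-by": "example"})
    tx.append(df)
```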
160 changes: 99 additions & 61 deletions pyiceberg/table/__init__.py
@@ -356,6 +356,100 @@ def update_snapshot(self, snapshot_properties: Dict[str, str] = EMPTY_DICT) -> UpdateSnapshot:
"""
return UpdateSnapshot(self, io=self._table.io, snapshot_properties=snapshot_properties)

def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT) -> None:
"""
Shorthand API for appending a PyArrow table to a table transaction.

Args:
df: The Arrow dataframe that will be appended to the table
snapshot_properties: Custom properties to be added to the snapshot summary
"""
try:
import pyarrow as pa
except ModuleNotFoundError as e:
raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e

if not isinstance(df, pa.Table):
raise ValueError(f"Expected PyArrow table, got: {df}")

if len(self._table.spec().fields) > 0:
raise ValueError("Cannot write to partitioned tables")

_check_schema_compatible(self._table.schema(), other_schema=df.schema)
# cast if the two schemas are compatible but not equal
table_arrow_schema = self._table.schema().as_arrow()
if table_arrow_schema != df.schema:
df = df.cast(table_arrow_schema)

with self.update_snapshot(snapshot_properties=snapshot_properties).fast_append() as update_snapshot:
# skip writing data files if the dataframe is empty
if df.shape[0] > 0:
data_files = _dataframe_to_data_files(
table_metadata=self._table.metadata, write_uuid=update_snapshot.commit_uuid, df=df, io=self._table.io
)
for data_file in data_files:
update_snapshot.append_data_file(data_file)

def overwrite(
self, df: pa.Table, overwrite_filter: BooleanExpression = ALWAYS_TRUE, snapshot_properties: Dict[str, str] = EMPTY_DICT
) -> None:
"""
Shorthand for adding a table overwrite with a PyArrow table to the transaction.

Args:
df: The Arrow dataframe that will be used to overwrite the table
overwrite_filter: ALWAYS_TRUE when you overwrite all the data,
or a boolean expression in case of a partial overwrite
snapshot_properties: Custom properties to be added to the snapshot summary
"""
try:
import pyarrow as pa
except ModuleNotFoundError as e:
raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e

if not isinstance(df, pa.Table):
raise ValueError(f"Expected PyArrow table, got: {df}")

if overwrite_filter != AlwaysTrue():
raise NotImplementedError("Cannot overwrite a subset of a table")

if len(self._table.spec().fields) > 0:
raise ValueError("Cannot write to partitioned tables")

_check_schema_compatible(self._table.schema(), other_schema=df.schema)
# cast if the two schemas are compatible but not equal
table_arrow_schema = self._table.schema().as_arrow()
if table_arrow_schema != df.schema:
df = df.cast(table_arrow_schema)

with self.update_snapshot(snapshot_properties=snapshot_properties).overwrite() as update_snapshot:
# skip writing data files if the dataframe is empty
if df.shape[0] > 0:
data_files = _dataframe_to_data_files(
table_metadata=self._table.metadata, write_uuid=update_snapshot.commit_uuid, df=df, io=self._table.io
)
for data_file in data_files:
update_snapshot.append_data_file(data_file)

def add_files(self, file_paths: List[str]) -> None:
"""
Shorthand API for adding files as data files to the table transaction.

Args:
file_paths: The list of full file paths to be added as data files to the table

Raises:
FileNotFoundError: If the file does not exist.
"""
if self._table.name_mapping() is None:
self.set_properties(**{TableProperties.DEFAULT_NAME_MAPPING: self._table.schema().name_mapping.model_dump_json()})
with self.update_snapshot().fast_append() as update_snapshot:
data_files = _parquet_files_to_data_files(
table_metadata=self._table.metadata, file_paths=file_paths, io=self._table.io
)
for data_file in data_files:
update_snapshot.append_data_file(data_file)

def update_spec(self) -> UpdateSpec:
"""Create a new UpdateSpec to update the partitioning of the table.

@@ -1124,32 +1218,8 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT) -> None:
df: The Arrow dataframe that will be appended to the table
snapshot_properties: Custom properties to be added to the snapshot summary
"""
try:
import pyarrow as pa
except ModuleNotFoundError as e:
raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e

if not isinstance(df, pa.Table):
raise ValueError(f"Expected PyArrow table, got: {df}")

if len(self.spec().fields) > 0:
raise ValueError("Cannot write to partitioned tables")

_check_schema_compatible(self.schema(), other_schema=df.schema)
# cast if the two schemas are compatible but not equal
table_arrow_schema = self.schema().as_arrow()
if table_arrow_schema != df.schema:
df = df.cast(table_arrow_schema)

with self.transaction() as txn:
with txn.update_snapshot(snapshot_properties=snapshot_properties).fast_append() as update_snapshot:
# skip writing data files if the dataframe is empty
if df.shape[0] > 0:
data_files = _dataframe_to_data_files(
table_metadata=self.metadata, write_uuid=update_snapshot.commit_uuid, df=df, io=self.io
)
for data_file in data_files:
update_snapshot.append_data_file(data_file)
with self.transaction() as tx:
tx.append(df=df, snapshot_properties=snapshot_properties)

def overwrite(
self, df: pa.Table, overwrite_filter: BooleanExpression = ALWAYS_TRUE, snapshot_properties: Dict[str, str] = EMPTY_DICT
@@ -1163,35 +1233,8 @@ def overwrite(
or a boolean expression in case of a partial overwrite
snapshot_properties: Custom properties to be added to the snapshot summary
"""
try:
import pyarrow as pa
except ModuleNotFoundError as e:
raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e

if not isinstance(df, pa.Table):
raise ValueError(f"Expected PyArrow table, got: {df}")

if overwrite_filter != AlwaysTrue():
raise NotImplementedError("Cannot overwrite a subset of a table")

if len(self.spec().fields) > 0:
raise ValueError("Cannot write to partitioned tables")

_check_schema_compatible(self.schema(), other_schema=df.schema)
# cast if the two schemas are compatible but not equal
table_arrow_schema = self.schema().as_arrow()
if table_arrow_schema != df.schema:
df = df.cast(table_arrow_schema)

with self.transaction() as txn:
with txn.update_snapshot(snapshot_properties=snapshot_properties).overwrite() as update_snapshot:
# skip writing data files if the dataframe is empty
if df.shape[0] > 0:
data_files = _dataframe_to_data_files(
table_metadata=self.metadata, write_uuid=update_snapshot.commit_uuid, df=df, io=self.io
)
for data_file in data_files:
update_snapshot.append_data_file(data_file)
with self.transaction() as tx:
tx.overwrite(df=df, overwrite_filter=overwrite_filter, snapshot_properties=snapshot_properties)

def add_files(self, file_paths: List[str]) -> None:
"""
@@ -1204,12 +1247,7 @@ def add_files(self, file_paths: List[str]) -> None:
FileNotFoundError: If the file does not exist.
"""
with self.transaction() as tx:
if self.name_mapping() is None:
tx.set_properties(**{TableProperties.DEFAULT_NAME_MAPPING: self.schema().name_mapping.model_dump_json()})
with tx.update_snapshot().fast_append() as update_snapshot:
data_files = _parquet_files_to_data_files(table_metadata=self.metadata, file_paths=file_paths, io=self.io)
for data_file in data_files:
update_snapshot.append_data_file(data_file)
tx.add_files(file_paths=file_paths)

def update_spec(self, case_sensitive: bool = True) -> UpdateSpec:
return UpdateSpec(Transaction(self, autocommit=True), case_sensitive=case_sensitive)
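With the logic moved into `Transaction`, the `Table`-level methods above become thin wrappers that open a transaction and delegate to it, so the two forms below are equivalent. A sketch, assuming `tbl` is a loaded unpartitioned table and `df` is a schema-compatible `pyarrow.Table`:

```python
# Convenience form: Table.append opens and commits a transaction internally.
tbl.append(df)

# Explicit form: the same write inside a caller-managed transaction, which can
# be combined with further updates before the single commit at exit.
with tbl.transaction() as tx:
    tx.append(df)
```

Both paths keep the existing guards from the diff: non-PyArrow inputs and partitioned tables raise `ValueError`, and `overwrite` with a filter other than `ALWAYS_TRUE` raises `NotImplementedError`.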
28 changes: 28 additions & 0 deletions tests/integration/test_writes.py
@@ -778,3 +778,31 @@ def test_inspect_snapshots(
continue

assert left == right, f"Difference in column {column}: {left} != {right}"


@pytest.mark.integration
def test_write_within_transaction(spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None:
Review comment (Contributor): Shall we cover overwrite in the test as well?

Reply (Collaborator, Author): Sure thing :)

identifier = "default.write_in_open_transaction"
tbl = _create_table(session_catalog, identifier, {"format-version": "1"}, [])

def get_metadata_entries_count(identifier: str) -> int:
return spark.sql(
f"""
SELECT *
FROM {identifier}.metadata_log_entries
"""
).count()

# one metadata entry from table creation
assert get_metadata_entries_count(identifier) == 1

# one more metadata entry from transaction
with tbl.transaction() as tx:
tx.set_properties({"test": "1"})
tx.append(arrow_table_with_null)
assert get_metadata_entries_count(identifier) == 2

# two more metadata entries added from two separate transactions
tbl.transaction().set_properties({"test": "2"}).commit_transaction()
tbl.append(arrow_table_with_null)
assert get_metadata_entries_count(identifier) == 4
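As a follow-up to the review comment above, a hypothetical extension covering `overwrite` inside a transaction could look like the sketch below. It is not part of this diff, and the expected count assumes each committed transaction adds exactly one metadata log entry, as the assertions above do:

```python
# Hypothetical addition: overwrite within an open transaction should likewise
# add one metadata entry once the transaction commits.
with tbl.transaction() as tx:
    tx.overwrite(arrow_table_with_null)
assert get_metadata_entries_count(identifier) == 5
```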