Add Support for Dynamic Overwrite (#931)
jqin61 authored Dec 19, 2024
1 parent 2529c81 commit 952d7c0
Showing 5 changed files with 633 additions and 29 deletions.
121 changes: 121 additions & 0 deletions mkdocs/docs/api.md
@@ -353,6 +353,127 @@ lat: [[52.371807,37.773972,53.11254],[53.21917]]
long: [[4.896029,-122.431297,6.0989],[6.56667]]
```

### Partial overwrites

When using the `overwrite` API, you can pass an `overwrite_filter` to delete the data that matches the filter before appending new data to the table.

For example, with an Iceberg table created as follows:

```python
from pyiceberg.catalog import load_catalog
from pyiceberg.schema import Schema
from pyiceberg.types import NestedField, StringType, DoubleType

catalog = load_catalog("default")

schema = Schema(
    NestedField(1, "city", StringType(), required=False),
    NestedField(2, "lat", DoubleType(), required=False),
    NestedField(3, "long", DoubleType(), required=False),
)

tbl = catalog.create_table("default.cities", schema=schema)
```

And with initial data populating the table:

```python
import pyarrow as pa

df = pa.Table.from_pylist(
    [
        {"city": "Amsterdam", "lat": 52.371807, "long": 4.896029},
        {"city": "San Francisco", "lat": 37.773972, "long": -122.431297},
        {"city": "Drachten", "lat": 53.11254, "long": 6.0989},
        {"city": "Paris", "lat": 48.864716, "long": 2.349014},
    ],
)

tbl.append(df)
```

You can overwrite the record for `Paris` with a record for `New York`:

```python
from pyiceberg.expressions import EqualTo

df = pa.Table.from_pylist(
    [
        {"city": "New York", "lat": 40.7128, "long": 74.0060},
    ]
)

tbl.overwrite(df, overwrite_filter=EqualTo('city', "Paris"))
```

This produces the following result with `tbl.scan().to_arrow()`:

```python
pyarrow.Table
city: large_string
lat: double
long: double
----
city: [["New York"],["Amsterdam","San Francisco","Drachten"]]
lat: [[40.7128],[52.371807,37.773972,53.11254]]
long: [[74.006],[4.896029,-122.431297,6.0989]]
```
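
The same `overwrite_filter` can typically also be written as a string predicate instead of an expression object; a minimal sketch, assuming the `df` and `tbl` objects from the example above:

```python
# Hypothetical equivalent of the EqualTo filter above, expressed as a string predicate.
tbl.overwrite(df, overwrite_filter="city == 'Paris'")
```

Without an `overwrite_filter`, `tbl.overwrite(df)` replaces the full contents of the table with the provided dataframe.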

If the PyIceberg table is partitioned, you can use `tbl.dynamic_partition_overwrite(df)` to replace the existing partitions with new ones provided in the dataframe. The partitions to be replaced are detected automatically from the provided Arrow table. Note that this currently requires every field in the latest partition spec to use an identity transform.
For example, with an Iceberg table partitioned on the `"city"` field:

```python
from pyiceberg.partitioning import PartitionField, PartitionSpec
from pyiceberg.schema import Schema
from pyiceberg.transforms import IdentityTransform
from pyiceberg.types import DoubleType, NestedField, StringType

schema = Schema(
    NestedField(1, "city", StringType(), required=False),
    NestedField(2, "lat", DoubleType(), required=False),
    NestedField(3, "long", DoubleType(), required=False),
)

tbl = catalog.create_table(
    "default.cities",
    schema=schema,
    partition_spec=PartitionSpec(PartitionField(source_id=1, field_id=1001, transform=IdentityTransform(), name="city_identity")),
)
```

Suppose the table is populated with initial data, including an incorrect record for `"Paris"` that we want to overwrite:

```python
import pyarrow as pa

df = pa.Table.from_pylist(
    [
        {"city": "Amsterdam", "lat": 52.371807, "long": 4.896029},
        {"city": "San Francisco", "lat": 37.773972, "long": -122.431297},
        {"city": "Drachten", "lat": 53.11254, "long": 6.0989},
        {"city": "Paris", "lat": -48.864716, "long": -2.349014},
    ],
)

tbl.append(df)
```

Then we can call `dynamic_partition_overwrite` with an Arrow table containing only the corrected record for `"Paris"`; its partition is detected automatically and replaced, while the other partitions are left untouched:

```python
df_corrected = pa.Table.from_pylist([
    {"city": "Paris", "lat": 48.864716, "long": 2.349014}
])

tbl.dynamic_partition_overwrite(df_corrected)
```

This produces the following result with `tbl.scan().to_arrow()`:

```python
pyarrow.Table
city: large_string
lat: double
long: double
----
city: [["Paris"],["Amsterdam"],["Drachten"],["San Francisco"]]
lat: [[48.864716],[52.371807],[53.11254],[37.773972]]
long: [[2.349014],[4.896029],[6.0989],[-122.431297]]
```
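
Every distinct partition value present in the dataframe is detected and replaced, so a single call can rewrite several partitions at once; a minimal sketch, assuming the same partitioned `tbl` and a hypothetical `df_two_cities` table:

```python
# The "Drachten" and "Paris" partitions are replaced; "Amsterdam" and
# "San Francisco" are left untouched.
df_two_cities = pa.Table.from_pylist(
    [
        {"city": "Drachten", "lat": 53.11254, "long": 6.0989},
        {"city": "Paris", "lat": 48.864716, "long": 2.349014},
    ]
)
tbl.dynamic_partition_overwrite(df_two_cities)
```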

## Inspecting tables

To explore the table metadata, tables can be inspected.
4 changes: 4 additions & 0 deletions pyiceberg/expressions/literals.py
@@ -311,6 +311,10 @@ def _(self, _: TimeType) -> Literal[int]:
    def _(self, _: TimestampType) -> Literal[int]:
        return TimestampLiteral(self.value)

    @to.register(TimestamptzType)
    def _(self, _: TimestamptzType) -> Literal[int]:
        return TimestampLiteral(self.value)

    @to.register(DecimalType)
    def _(self, type_var: DecimalType) -> Literal[Decimal]:
        unscaled = Decimal(self.value)
1 change: 0 additions & 1 deletion pyiceberg/io/pyarrow.py
@@ -2519,7 +2519,6 @@ def _check_pyarrow_schema_compatible(
        raise ValueError(
            f"PyArrow table contains more columns: {', '.join(sorted(additional_names))}. Update the schema first (hint, use union_by_name)."
        ) from e

    _check_schema_compatible(requested_schema, provided_schema)


135 changes: 120 additions & 15 deletions pyiceberg/table/__init__.py
@@ -44,10 +44,14 @@

import pyiceberg.expressions.parser as parser
from pyiceberg.expressions import (
    AlwaysFalse,
    AlwaysTrue,
    And,
    BooleanExpression,
    EqualTo,
    IsNull,
    Or,
    Reference,
)
from pyiceberg.expressions.visitors import (
    _InclusiveMetricsEvaluator,
@@ -117,6 +121,7 @@
    _OverwriteFiles,
)
from pyiceberg.table.update.spec import UpdateSpec
from pyiceberg.transforms import IdentityTransform
from pyiceberg.typedef import (
    EMPTY_DICT,
    IcebergBaseModel,
@@ -344,6 +349,48 @@ def _set_ref_snapshot(

        return updates, requirements

    def _build_partition_predicate(self, partition_records: Set[Record]) -> BooleanExpression:
        """Build a filter predicate matching any of the input partition records.

        Args:
            partition_records: A set of partition records to match

        Returns:
            A predicate matching any of the input partition records.
        """
        partition_spec = self.table_metadata.spec()
        schema = self.table_metadata.schema()
        partition_fields = [schema.find_field(field.source_id).name for field in partition_spec.fields]

        expr: BooleanExpression = AlwaysFalse()
        for partition_record in partition_records:
            match_partition_expression: BooleanExpression = AlwaysTrue()

            for pos, partition_field in enumerate(partition_fields):
                predicate = (
                    EqualTo(Reference(partition_field), partition_record[pos])
                    if partition_record[pos] is not None
                    else IsNull(Reference(partition_field))
                )
                match_partition_expression = And(match_partition_expression, predicate)
            expr = Or(expr, match_partition_expression)
        return expr

    def _append_snapshot_producer(self, snapshot_properties: Dict[str, str]) -> _FastAppendFiles:
        """Determine the append type based on table properties.

        Args:
            snapshot_properties: Custom properties to be added to the snapshot summary

        Returns:
            Either a fast-append or a merge-append snapshot producer.
        """
        manifest_merge_enabled = property_as_bool(
            self.table_metadata.properties,
            TableProperties.MANIFEST_MERGE_ENABLED,
            TableProperties.MANIFEST_MERGE_ENABLED_DEFAULT,
        )
        update_snapshot = self.update_snapshot(snapshot_properties=snapshot_properties)
        return update_snapshot.merge_append() if manifest_merge_enabled else update_snapshot.fast_append()

    def update_schema(self, allow_incompatible_changes: bool = False, case_sensitive: bool = True) -> UpdateSchema:
        """Create a new UpdateSchema to alter the columns of this table.
@@ -398,15 +445,7 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT)
            self.table_metadata.schema(), provided_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
        )

        manifest_merge_enabled = property_as_bool(
            self.table_metadata.properties,
            TableProperties.MANIFEST_MERGE_ENABLED,
            TableProperties.MANIFEST_MERGE_ENABLED_DEFAULT,
        )
        update_snapshot = self.update_snapshot(snapshot_properties=snapshot_properties)
        append_method = update_snapshot.merge_append if manifest_merge_enabled else update_snapshot.fast_append

        with append_method() as append_files:
        with self._append_snapshot_producer(snapshot_properties) as append_files:
            # skip writing data files if the dataframe is empty
            if df.shape[0] > 0:
                data_files = _dataframe_to_data_files(
@@ -415,6 +454,62 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT)
                for data_file in data_files:
                    append_files.append_data_file(data_file)

    def dynamic_partition_overwrite(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT) -> None:
        """
        Shorthand for overwriting existing partitions with a PyArrow table.

        The function detects partition values in the provided arrow table using the current
        partition spec, and deletes existing partitions matching these values. Finally, the
        data in the table is appended to the table.

        Args:
            df: The Arrow dataframe that will be used to overwrite the table
            snapshot_properties: Custom properties to be added to the snapshot summary
        """
        try:
            import pyarrow as pa
        except ModuleNotFoundError as e:
            raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e

        from pyiceberg.io.pyarrow import _check_pyarrow_schema_compatible, _dataframe_to_data_files

        if not isinstance(df, pa.Table):
            raise ValueError(f"Expected PyArrow table, got: {df}")

        if self.table_metadata.spec().is_unpartitioned():
            raise ValueError("Cannot apply dynamic overwrite on an unpartitioned table.")

        for field in self.table_metadata.spec().fields:
            if not isinstance(field.transform, IdentityTransform):
                raise ValueError(
                    f"For now dynamic overwrite does not support a table with non-identity-transform field in the latest partition spec: {field}"
                )

        downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
        _check_pyarrow_schema_compatible(
            self.table_metadata.schema(), provided_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
        )

        # If dataframe does not have data, there is no need to overwrite
        if df.shape[0] == 0:
            return

        append_snapshot_commit_uuid = uuid.uuid4()
        data_files: List[DataFile] = list(
            _dataframe_to_data_files(
                table_metadata=self._table.metadata, write_uuid=append_snapshot_commit_uuid, df=df, io=self._table.io
            )
        )

        partitions_to_overwrite = {data_file.partition for data_file in data_files}
        delete_filter = self._build_partition_predicate(partition_records=partitions_to_overwrite)
        self.delete(delete_filter=delete_filter, snapshot_properties=snapshot_properties)

        with self._append_snapshot_producer(snapshot_properties) as append_files:
            append_files.commit_uuid = append_snapshot_commit_uuid
            for data_file in data_files:
                append_files.append_data_file(data_file)

    def overwrite(
        self,
        df: pa.Table,
@@ -461,14 +556,14 @@ def overwrite(

        self.delete(delete_filter=overwrite_filter, case_sensitive=case_sensitive, snapshot_properties=snapshot_properties)

        with self.update_snapshot(snapshot_properties=snapshot_properties).fast_append() as update_snapshot:
        with self._append_snapshot_producer(snapshot_properties) as append_files:
            # skip writing data files if the dataframe is empty
            if df.shape[0] > 0:
                data_files = _dataframe_to_data_files(
                    table_metadata=self.table_metadata, write_uuid=update_snapshot.commit_uuid, df=df, io=self._table.io
                    table_metadata=self.table_metadata, write_uuid=append_files.commit_uuid, df=df, io=self._table.io
                )
                for data_file in data_files:
                    update_snapshot.append_data_file(data_file)
                    append_files.append_data_file(data_file)

    def delete(
        self,
@@ -552,9 +647,8 @@ def delete(
))

        if len(replaced_files) > 0:
            with self.update_snapshot(snapshot_properties=snapshot_properties).overwrite(
                commit_uuid=commit_uuid
            ) as overwrite_snapshot:
            with self.update_snapshot(snapshot_properties=snapshot_properties).overwrite() as overwrite_snapshot:
                overwrite_snapshot.commit_uuid = commit_uuid
                for original_data_file, replaced_data_files in replaced_files:
                    overwrite_snapshot.delete_data_file(original_data_file)
                    for replaced_data_file in replaced_data_files:
@@ -989,6 +1083,17 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT)
        with self.transaction() as tx:
            tx.append(df=df, snapshot_properties=snapshot_properties)

    def dynamic_partition_overwrite(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT) -> None:
        """Shorthand for dynamic overwriting the table with a PyArrow table.

        Old partitions are auto detected and replaced with data files created for input arrow table.

        Args:
            df: The Arrow dataframe that will be used to overwrite the table
            snapshot_properties: Custom properties to be added to the snapshot summary
        """
        with self.transaction() as tx:
            tx.dynamic_partition_overwrite(df=df, snapshot_properties=snapshot_properties)

    def overwrite(
        self,
        df: pa.Table,