16 changes: 16 additions & 0 deletions pyiceberg/expressions/__init__.py
@@ -351,6 +351,10 @@ def __eq__(self, other: Any) -> bool:
@abstractmethod
def as_unbound(self) -> Type[UnboundPredicate[Any]]: ...

def __hash__(self) -> int:
"""Return hash value of the BoundPredicate class."""
return hash(str(self))


class UnboundPredicate(Generic[L], Unbound[BooleanExpression], BooleanExpression, ABC):
term: UnboundTerm[Any]
@@ -369,6 +373,10 @@ def bind(self, schema: Schema, case_sensitive: bool = True) -> BooleanExpression
@abstractmethod
def as_bound(self) -> Type[BoundPredicate[L]]: ...

def __hash__(self) -> int:
"""Return hash value of the UnaryPredicate class."""
return hash(str(self))


class UnaryPredicate(UnboundPredicate[Any], ABC):
def bind(self, schema: Schema, case_sensitive: bool = True) -> BoundUnaryPredicate[Any]:
@@ -698,6 +706,10 @@ def __repr__(self) -> str:
@abstractmethod
def as_bound(self) -> Type[BoundLiteralPredicate[L]]: ...

def __hash__(self) -> int:
"""Return hash value of the LiteralPredicate class."""
return hash(str(self))


class BoundLiteralPredicate(BoundPredicate[L], ABC):
literal: Literal[L]
@@ -721,6 +733,10 @@ def __repr__(self) -> str:
@abstractmethod
def as_unbound(self) -> Type[LiteralPredicate[L]]: ...

def __hash__(self) -> int:
"""Return hash value of the BoundLiteralPredicate class."""
return hash(str(self))


class BoundEqualTo(BoundLiteralPredicate[L]):
def __invert__(self) -> BoundNotEqualTo[L]:
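
A minimal sketch (not part of the diff) of why these __hash__ implementations matter: the static overwrite validation added to pyiceberg/table/__init__.py below collects predicates into sets, which requires hashable expression objects. Since the predicates already define __eq__, hashing on str(self) makes structurally equal predicates collide as intended. Field names and values are illustrative:

from pyiceberg.expressions import EqualTo, IsNull

# Structurally equal predicates share a str() representation, hence the same
# hash, so the set deduplicates them.
preds = {EqualTo("region", "emea"), EqualTo("region", "emea"), IsNull("ds")}
assert len(preds) == 2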
5 changes: 4 additions & 1 deletion pyiceberg/io/pyarrow.py
@@ -1776,7 +1776,10 @@ def write_parquet(task: WriteTask) -> DataFile:
fo = io.new_output(file_path)
with fo.create(overwrite=True) as fos:
with pq.ParquetWriter(fos, schema=arrow_file_schema, **parquet_writer_kwargs) as writer:
writer.write(pa.Table.from_batches(task.record_batches), row_group_size=row_group_size)
arrow_table = pa.Table.from_batches(task.record_batches)
# reorder columns in case the input arrow table's column order differs from the iceberg table schema
Contributor comment: Can you provide an example of when this would happen? This only handles top-level columns.

df_to_write = arrow_table.select(arrow_file_schema.names)
writer.write_table(df_to_write, row_group_size=row_group_size)

statistics = data_file_statistics_from_parquet_metadata(
parquet_metadata=writer.writer.metadata,
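
A hedged sketch of the case the reviewer asks about (illustrative, not from the diff): a caller builds an Arrow table whose top-level columns are ordered differently from the Iceberg table schema, and pa.Table.select realigns them by name before writing:

import pyarrow as pa

arrow_table = pa.table({"b": [1, 2], "a": ["x", "y"]})  # caller's column order
file_schema_names = ["a", "b"]  # order expected by the Iceberg table schema
aligned = arrow_table.select(file_schema_names)  # reorders top-level columns by name
assert aligned.column_names == ["a", "b"]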
109 changes: 109 additions & 0 deletions pyiceberg/table/__init__.py
@@ -50,10 +50,14 @@
from pyiceberg.conversions import from_bytes
from pyiceberg.exceptions import CommitFailedException, ResolveError, ValidationError
from pyiceberg.expressions import (
AlwaysFalse,
AlwaysTrue,
And,
BooleanExpression,
BoundEqualTo,
BoundIsNull,
EqualTo,
IsNull,
Reference,
)
from pyiceberg.expressions.visitors import (
@@ -123,6 +127,7 @@
IcebergRootModel,
Identifier,
KeyDefaultDict,
L,
Properties,
Record,
TableVersion,
@@ -423,11 +428,19 @@ def overwrite(
raise ValueError("Cannot write to partitioned tables")

_check_schema_compatible(self._table.schema(), other_schema=df.schema)

# cast if the two schemas are compatible but not equal
table_arrow_schema = self._table.schema().as_arrow()
if table_arrow_schema != df.schema:
df = df.cast(table_arrow_schema)

# will be used when we support partitioned overwrite
# if not overwrite_filter == ALWAYS_TRUE:
# bound_is_null_predicates, bound_eq_to_predicates = _check_static_overwrite_filter_compatible(
# table_schema=self.schema(), overwrite_filter=overwrite_filter, spec=self.metadata.spec()
# )
# # predicates will be passed to update_snapshot().overwrite(); a mismatch error may be raised when partitions in the df do not match the partitions expected by the filter

with self.update_snapshot(snapshot_properties=snapshot_properties).overwrite() as update_snapshot:
# skip writing data files if the dataframe is empty
if df.shape[0] > 0:
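
A minimal sketch of the cast step in the hunk above (schemas and data are illustrative, not from the diff): when the incoming schema is compatible with, but not identical to, the table's Arrow schema, casting normalizes the data before it is written:

import pyarrow as pa

target = pa.schema([pa.field("id", pa.int64()), pa.field("name", pa.string())])
df = pa.table({"id": pa.array([1, 2], type=pa.int32()), "name": ["a", "b"]})
if df.schema != target:
    df = df.cast(target)  # e.g. widens int32 to int64
assert df.schema == target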
@@ -3524,3 +3537,99 @@ def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.T
table_partitions: list[TablePartition] = _get_table_partitions(arrow_table, spec, schema, slice_instructions)

return table_partitions


def _bind_and_validate_static_overwrite_filter_predicate(
unbound_expr: Union[IsNull, EqualTo[L]], table_schema: Schema, spec: PartitionSpec
) -> Union[BoundIsNull[L], BoundEqualTo[L]]:
# step 1: check that the expression can bind to the table schema.
# For example, a field that does not exist in the table would fail to bind.
bound_expr: Union[BoundIsNull[L], BoundEqualTo[L], AlwaysFalse] = unbound_expr.bind(table_schema) # type: ignore # The bind returns upcast types.

# step 2: check that a non-nullable column is not overwritten with IsNull.
if isinstance(bound_expr, AlwaysFalse):
raise ValueError(f"Static overwrite on a non-nullable partition field with null values: {unbound_expr}.")

# step 3: check that the unbound_expr refers to a field in the partition spec
# if not isinstance(bound_expr, (BoundIsNull, BoundEqualTo)):
# raise ValueError(
# f"Expecting static overwrite filter with IsNull or EqualTo concatenated by And. But get: {unbound_expr}."
# )
nested_field: NestedField = bound_expr.term.ref().field
part_fields: List[PartitionField] = spec.fields_by_source_id(nested_field.field_id)
if len(part_fields) == 0:
raise ValueError(f"Detected that the field of ({nested_field}) in static overwrite filter is not a partition field.")

# step 4: check that the unbound_expr uses an identity transform
# the loop handles an edge case where the same source field appears more than once in the partition spec
for part_field in part_fields:
if not isinstance(part_field.transform, IdentityTransform):
raise ValueError(
f"Expecting static overwrite filter only to on fields with identity transform, but get transform: ({part_fields[0].transform}) for field: ({nested_field})."
)

return bound_expr
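
A hedged sketch of step 2 above: in pyiceberg, binding IsNull against a required (non-nullable) field collapses to AlwaysFalse, which this helper surfaces as a user-facing error. The schema is illustrative:

from pyiceberg.expressions import AlwaysFalse, IsNull
from pyiceberg.schema import Schema
from pyiceberg.types import NestedField, StringType

schema = Schema(NestedField(1, "region", StringType(), required=True))
# IsNull can never hold for a required field, so binding short-circuits.
assert isinstance(IsNull("region").bind(schema), AlwaysFalse)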


def _check_static_overwrite_filter_compatible(
table_schema: Schema, overwrite_filter: BooleanExpression, spec: PartitionSpec
) -> Tuple[Set[BoundIsNull[L]], Set[BoundEqualTo[L]]]:
is_null_predicates, eq_to_predicates = _validate_static_overwrite_filter_expr_type(expr=overwrite_filter)

bound_is_null_preds = set()
bound_eq_to_preds = set()
for unbound_is_null in is_null_predicates:
bound_pred = _bind_and_validate_static_overwrite_filter_predicate(
unbound_expr=unbound_is_null, table_schema=table_schema, spec=spec
)
if not isinstance(bound_pred, BoundIsNull):
raise ValueError(f"Expecting IsNull after binding {unbound_is_null} to schema but get {bound_pred}.")
bound_is_null_preds.add(bound_pred)

for unbound_eq_to in eq_to_predicates:
bound_pred = _bind_and_validate_static_overwrite_filter_predicate(
unbound_expr=unbound_eq_to, table_schema=table_schema, spec=spec
)
if not isinstance(bound_pred, BoundEqualTo):
raise ValueError(f"Expecting IsNull after binding {unbound_eq_to} to schema but get {bound_pred}.")
bound_eq_to_preds.add(bound_pred)
return (bound_is_null_preds, bound_eq_to_preds) # type: ignore
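
A hedged usage sketch of the helper above (schema, spec, and filter are illustrative): binding and validating a conjunctive static overwrite filter against an identity-partitioned table:

from pyiceberg.expressions import And, EqualTo, IsNull
from pyiceberg.partitioning import PartitionField, PartitionSpec
from pyiceberg.schema import Schema
from pyiceberg.transforms import IdentityTransform
from pyiceberg.types import NestedField, StringType

schema = Schema(
    NestedField(1, "region", StringType(), required=True),
    NestedField(2, "ds", StringType(), required=False),
)
spec = PartitionSpec(
    PartitionField(source_id=1, field_id=1000, transform=IdentityTransform(), name="region"),
    PartitionField(source_id=2, field_id=1001, transform=IdentityTransform(), name="ds"),
)
bound_is_nulls, bound_eq_tos = _check_static_overwrite_filter_compatible(
    table_schema=schema,
    overwrite_filter=And(EqualTo("region", "emea"), IsNull("ds")),
    spec=spec,
)
# Each returned set holds the bound form of one predicate from the filter.
assert len(bound_is_nulls) == 1 and len(bound_eq_tos) == 1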


def _validate_static_overwrite_filter_expr_type(expr: BooleanExpression) -> Tuple[Set[IsNull], Set[EqualTo[L]]]:
"""Validate whether expression only has 1)And 2)IsNull and 3)EqualTo and break down the raw expression into IsNull and EqualTo."""
from collections import defaultdict

def _recursively_fetch_fields(
expr: BooleanExpression, is_null_predicates: Set[IsNull], eq_to_predicates: Set[EqualTo[L]]
) -> None:
if isinstance(expr, EqualTo):
if not isinstance(expr.term, Reference):
raise ValueError(f"Unsupported unbound term {expr.term} in {expr}, expecting a refernce.")
duplication_check[expr.term.name].add(expr)
eq_to_predicates.add(expr)
elif isinstance(expr, IsNull):
if not isinstance(expr.term, Reference):
raise ValueError(f"Unsupported unbound term {expr.term} in {expr}, expecting a refernce.")
duplication_check[expr.term.name].add(expr)
is_null_predicates.add(expr)
elif isinstance(expr, And):
_recursively_fetch_fields(expr.left, is_null_predicates, eq_to_predicates)
_recursively_fetch_fields(expr.right, is_null_predicates, eq_to_predicates)
else:
raise ValueError(
f"static overwrite partitioning filter can only be isequalto, is null, and, alwaysTrue, but get {expr=}"
)

duplication_check: Dict[str, Set[Union[IsNull, EqualTo[L]]]] = defaultdict(set)
is_null_predicates: Set[IsNull] = set()
eq_to_predicates: Set[EqualTo[L]] = set()
_recursively_fetch_fields(expr, is_null_predicates, eq_to_predicates)
for _, expr_set in duplication_check.items():
if len(expr_set) != 1:
raise ValueError(
f"static overwrite partitioning filter has more than 1 different predicates with same field {expr_set}"
)

# ensure no field is constrained by multiple conflicting predicates; this duplication check could potentially be merged with the other field checks above
return is_null_predicates, eq_to_predicates
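
A brief usage sketch (field names are illustrative): the validator decomposes a conjunctive filter into its IsNull and EqualTo parts, raising ValueError for any other expression type such as Or or Not:

from pyiceberg.expressions import And, EqualTo, IsNull

is_nulls, eq_tos = _validate_static_overwrite_filter_expr_type(
    And(EqualTo("region", "emea"), IsNull("ds"))
)
assert is_nulls == {IsNull("ds")}
assert eq_tos == {EqualTo("region", "emea")}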