Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Strict projection #539

Merged
merged 3 commits into from
Mar 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions pyiceberg/expressions/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -1433,6 +1433,30 @@ def visit_not_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> bool
return ROWS_MIGHT_MATCH


def strict_projection(
    schema: Schema, spec: PartitionSpec, case_sensitive: bool = True
) -> Callable[[BooleanExpression], BooleanExpression]:
    """Build a function that strictly projects row filters onto partition predicates.

    Args:
        schema: Table schema handed to the evaluator.
        spec: Partition spec whose fields are the projection targets.
        case_sensitive: Forwarded unchanged to the evaluator.

    Returns:
        The bound ``project`` method of a freshly created ``StrictProjection``.
    """
    evaluator = StrictProjection(schema, spec, case_sensitive)
    return evaluator.project


class StrictProjection(ProjectionEvaluator):
    """Projection evaluator that combines the strict projections of a predicate.

    A strict projection is one where any row matching the projected predicate
    is guaranteed to match the original predicate.
    """

    def visit_bound_predicate(self, predicate: BoundPredicate[Any]) -> BooleanExpression:
        """Project ``predicate`` through every partition field built on its source column.

        Args:
            predicate: The bound predicate to project.

        Returns:
            ``AlwaysFalse()`` OR-ed with each strict projection that a partition
            transform was able to produce.
        """
        source_id = predicate.term.ref().field.field_id
        combined: BooleanExpression = AlwaysFalse()

        # consider (ts > 2019-01-01T01:00:00) with day(ts) and hour(ts)
        # projections: d >= 2019-01-02 and h >= 2019-01-01-02 (note the inclusive bounds).
        # any timestamp where either projection predicate is true must match the original
        # predicate. For example, ts = 2019-01-01T03:00:00 matches the hour projection but not
        # the day, but does match the original predicate.
        for partition_field in self.spec.fields_by_source_id(source_id):
            projected = partition_field.transform.strict_project(name=partition_field.name, pred=predicate)
            if projected is None:
                # This transform offers no strict projection for the predicate.
                continue
            combined = Or(combined, projected)

        return combined


class _StrictMetricsEvaluator(_MetricsEvaluator):
struct: StructType
expr: BooleanExpression
Expand Down
140 changes: 137 additions & 3 deletions pyiceberg/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
BoundLessThan,
BoundLessThanOrEqual,
BoundLiteralPredicate,
BoundNotEqualTo,
BoundNotIn,
BoundNotStartsWith,
BoundPredicate,
Expand All @@ -43,8 +44,11 @@
BoundTerm,
BoundUnaryPredicate,
EqualTo,
GreaterThan,
GreaterThanOrEqual,
LessThan,
LessThanOrEqual,
NotEqualTo,
NotStartsWith,
Reference,
StartsWith,
Expand Down Expand Up @@ -144,6 +148,9 @@ def result_type(self, source: IcebergType) -> IcebergType: ...
# Abstract hook: concrete transforms must supply their own implementation.
# Receives a target name and a bound predicate; yields an unbound predicate, or None.
@abstractmethod
def project(self, name: str, pred: BoundPredicate[L]) -> Optional[UnboundPredicate[Any]]: ...

# Abstract hook for the strict counterpart of `project`: concrete transforms must implement it.
# Receives a target name and a bound predicate; yields an unbound predicate, or None.
@abstractmethod
def strict_project(self, name: str, pred: BoundPredicate[Any]) -> Optional[UnboundPredicate[Any]]: ...

@property
def preserves_order(self) -> bool:
# This implementation always reports False.
return False
Expand Down Expand Up @@ -216,6 +223,21 @@ def project(self, name: str, pred: BoundPredicate[L]) -> Optional[UnboundPredica
# For example, (x > 0) and (x < 3) can be turned into in({1, 2}) and projected.
return None

def strict_project(self, name: str, pred: BoundPredicate[Any]) -> Optional[UnboundPredicate[Any]]:
    """Strictly project ``pred`` onto the field called ``name``.

    Only predicates on an already-transformed term, unary predicates,
    not-equal and not-in predicates can be projected.

    Args:
        name: Name used for the reference in the resulting predicate.
        pred: The bound predicate to project.

    Returns:
        The projected unbound predicate, or ``None`` when no strict
        projection exists.
    """
    # Built up front, before any branching, exactly as the callers have always seen it.
    apply_transform = self.transform(pred.term.ref().field.field_type)

    if isinstance(pred.term, BoundTransform):
        return _project_transform_predicate(self, name, pred)

    if isinstance(pred, BoundUnaryPredicate):
        return pred.as_unbound(Reference(name))

    if isinstance(pred, BoundNotEqualTo):
        projected_literal = _transform_literal(apply_transform, pred.literal)
        return pred.as_unbound(Reference(name), projected_literal)

    if isinstance(pred, BoundNotIn):
        projected_literals = {_transform_literal(apply_transform, item) for item in pred.literals}
        return pred.as_unbound(Reference(name), projected_literals)

    # no strict projection for comparison or equality
    return None

def can_transform(self, source: IcebergType) -> bool:
return isinstance(
source,
Expand Down Expand Up @@ -306,6 +328,19 @@ def project(self, name: str, pred: BoundPredicate[L]) -> Optional[UnboundPredica
else:
return None

def strict_project(self, name: str, pred: BoundPredicate[Any]) -> Optional[UnboundPredicate[Any]]:
    """Strictly project ``pred`` onto the field called ``name``.

    Args:
        name: Name used for the reference in the resulting predicate.
        pred: The bound predicate to project.

    Returns:
        The projected unbound predicate, or ``None`` when the predicate kind
        is not handled.
    """
    # Built up front, before any branching, exactly as the callers have always seen it.
    apply_transform = self.transform(pred.term.ref().field.field_type)

    if isinstance(pred.term, BoundTransform):
        return _project_transform_predicate(self, name, pred)

    if isinstance(pred, BoundUnaryPredicate):
        return pred.as_unbound(Reference(name))

    if isinstance(pred, BoundLiteralPredicate):
        # Comparisons against a single literal go through the numeric strict helper.
        return _truncate_number_strict(name, pred, apply_transform)

    if isinstance(pred, BoundNotIn):
        return _set_apply_transform(name, pred, apply_transform)

    return None

@property
def dedup_name(self) -> str:
# Always the constant "time".
# NOTE(review): presumably a shared key for spotting redundant time-based transforms — confirm.
return "time"
Expand Down Expand Up @@ -516,10 +551,20 @@ def project(self, name: str, pred: BoundPredicate[L]) -> Optional[UnboundPredica
return pred.as_unbound(Reference(name))
elif isinstance(pred, BoundLiteralPredicate):
return pred.as_unbound(Reference(name), pred.literal)
elif isinstance(pred, (BoundIn, BoundNotIn)):
elif isinstance(pred, BoundSetPredicate):
return pred.as_unbound(Reference(name), pred.literals)
else:
return None

def strict_project(self, name: str, pred: BoundPredicate[Any]) -> Optional[UnboundPredicate[Any]]:
# Re-targets the predicate at `name`; literal values are forwarded without being transformed.
if isinstance(pred, BoundUnaryPredicate):
return pred.as_unbound(Reference(name))
elif isinstance(pred, BoundLiteralPredicate):
# The single literal is passed through unchanged.
return pred.as_unbound(Reference(name), pred.literal)
elif isinstance(pred, BoundSetPredicate):
# The literal set is passed through unchanged.
return pred.as_unbound(Reference(name), pred.literals)
else:
# NOTE(review): this diff rendering carries no +/- markers. The `raise` and the
# `return None` below look like the removed and added versions of one line rather than
# two consecutive statements (as written the `return` is unreachable) — confirm against
# the merged file before relying on either behavior.
raise ValueError(f"Could not project: {pred}")
return None

@property
def preserves_order(self) -> bool:
Expand Down Expand Up @@ -590,6 +635,47 @@ def project(self, name: str, pred: BoundPredicate[L]) -> Optional[UnboundPredica
return _truncate_array(name, pred, self.transform(field_type))
return None

def strict_project(self, name: str, pred: BoundPredicate[Any]) -> Optional[UnboundPredicate[Any]]:
    """Strictly project ``pred`` onto the truncated field called ``name``.

    Integer, long and decimal sources use the numeric strict helper. For all
    other sources, starts-with / not-starts-with predicates are compared
    against the truncation width, and the remaining literal predicates use
    the array strict helper.

    Args:
        name: Name used for the term of the resulting predicate.
        pred: The bound predicate to project.

    Returns:
        The projected unbound predicate, or ``None`` when no strict
        projection exists.
    """
    field_type = pred.term.ref().field.field_type

    if isinstance(pred.term, BoundTransform):
        return _project_transform_predicate(self, name, pred)

    if isinstance(field_type, (IntegerType, LongType, DecimalType)):
        if isinstance(pred, BoundUnaryPredicate):
            return pred.as_unbound(Reference(name))
        if isinstance(pred, BoundLiteralPredicate):
            return _truncate_number_strict(name, pred, self.transform(field_type))
        if isinstance(pred, BoundNotIn):
            return _set_apply_transform(name, pred, self.transform(field_type))
        return None

    if not isinstance(pred, BoundLiteralPredicate):
        if isinstance(pred, BoundNotIn):
            return _set_apply_transform(name, pred, self.transform(field_type))
        return None

    if isinstance(pred, BoundStartsWith):
        prefix = pred.literal.value
        prefix_len = len(prefix)
        if prefix_len < self.width:
            # Prefix shorter than the width: the same predicate holds on the truncated value.
            return pred.as_unbound(name, prefix)
        if prefix_len == self.width:
            # Prefix exactly fills the width: starts-with collapses to equality.
            return EqualTo(name, prefix)
        return None

    if isinstance(pred, BoundNotStartsWith):
        prefix = pred.literal.value
        prefix_len = len(prefix)
        if prefix_len < self.width:
            return pred.as_unbound(name, prefix)
        if prefix_len == self.width:
            return NotEqualTo(name, prefix)
        # Prefix longer than the width: compare against the truncated prefix instead.
        return pred.as_unbound(name, self.transform(field_type)(prefix))

    # ProjectionUtil.truncateArrayStrict(name, pred, this);
    return _truncate_array_strict(name, pred, self.transform(field_type))

@property
def width(self) -> int:
# Read-only accessor for the private `_width` attribute.
return self._width
Expand Down Expand Up @@ -714,6 +800,9 @@ def result_type(self, source: IcebergType) -> StringType:
def project(self, name: str, pred: BoundPredicate[L]) -> Optional[UnboundPredicate[Any]]:
# No projection is ever produced here: both arguments are ignored.
return None

def strict_project(self, name: str, pred: BoundPredicate[Any]) -> Optional[UnboundPredicate[Any]]:
# No strict projection is ever produced here: both arguments are ignored.
return None

def __repr__(self) -> str:
"""Return the string representation of the UnknownTransform class."""
# Embeds the repr() of the private `_transform` attribute.
return f"UnknownTransform(transform={repr(self._transform)})"
Expand All @@ -736,6 +825,9 @@ def result_type(self, source: IcebergType) -> IcebergType:
def project(self, name: str, pred: BoundPredicate[L]) -> Optional[UnboundPredicate[Any]]:
# No projection is ever produced here: both arguments are ignored.
return None

def strict_project(self, name: str, pred: BoundPredicate[Any]) -> Optional[UnboundPredicate[Any]]:
# No strict projection is ever produced here: both arguments are ignored.
return None

def to_human_string(self, _: IcebergType, value: Optional[S]) -> str:
# Both arguments are ignored; the literal string "null" is always returned.
return "null"

Expand Down Expand Up @@ -766,6 +858,47 @@ def _truncate_number(
return None


def _truncate_number_strict(
name: str, pred: BoundLiteralPredicate[L], transform: Callable[[Optional[L]], Optional[L]]
) -> Optional[UnboundPredicate[Any]]:
# Strict projection of a single-literal predicate: builds an unbound predicate on `name`
# whose literal is the transformed boundary of `pred`.
boundary = pred.literal

# Only long, decimal, date and timestamp literals are accepted; anything else is rejected
# before any projection is attempted.
if not isinstance(boundary, (LongLiteral, DecimalLiteral, DateLiteral, TimestampLiteral)):
raise ValueError(f"Expected a numeric literal, got: {type(boundary)}")

if isinstance(pred, BoundLessThan):
return LessThan(Reference(name), _transform_literal(transform, boundary))
elif isinstance(pred, BoundLessThanOrEqual):
# `<= b` is emitted as a strict `<` against the transform of the incremented boundary.
return LessThan(Reference(name), _transform_literal(transform, boundary.increment())) # type: ignore
elif isinstance(pred, BoundGreaterThan):
return GreaterThan(Reference(name), _transform_literal(transform, boundary))
elif isinstance(pred, BoundGreaterThanOrEqual):
# `>= b` is emitted as a strict `>` against the transform of the decremented boundary.
return GreaterThan(Reference(name), _transform_literal(transform, boundary.decrement())) # type: ignore
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi, I'm confused about decrement here. @Fokko
For example, suppose col1 is a date column and the transform is month. The predicate col1 >= 1970-01-02 is then projected to month(col1) > -1. For col1 = 1970-01-01 the original predicate is false (1970-01-01 >= 1970-01-02 does not hold), but the projected predicate is true, because
month(1970-01-01) = 0 > -1.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, you might be onto something here:

_test_projection(
transform.strict_project(name="name", pred=BoundGreaterThanOrEqual(term=bound_reference_date, literal=date)),
GreaterThan(term="name", literal=LongLiteral(-1)), # In Java this is human string 1970-01
)

https://github.com/apache/iceberg/blob/d402f83fc7b224b21242c506cf503e5bcbc8c867/api/src/test/java/org/apache/iceberg/transforms/TestDatesProjection.java#L139-L140

Luckily this code is not yet used, I'll write a patch tomorrow 👍

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey @ZENOTME Again, thanks for pointing this out. It turns out that most of these oddities are caused by a bug that was part of Iceberg ≤0.10.0. We did not port this to PyIceberg: https://github.com/apache/iceberg/blob/ac6509a4e469f808bebb8b713a5c4213f98ff4a5/api/src/main/java/org/apache/iceberg/transforms/ProjectionUtil.java#L275

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Created a PR to refactor the tests: #1422

elif isinstance(pred, BoundNotEqualTo):
return EqualTo(Reference(name), _transform_literal(transform, boundary))
elif isinstance(pred, BoundEqualTo):
# there is no predicate that guarantees equality because adjacent longs transform to the
# same value
return None
else:
return None


def _truncate_array_strict(
    name: str, pred: BoundLiteralPredicate[L], transform: Callable[[Optional[L]], Optional[L]]
) -> Optional[UnboundPredicate[Any]]:
    """Strictly project a single-literal predicate using a transformed boundary.

    Both ``<`` and ``<=`` become ``LessThan``, both ``>`` and ``>=`` become
    ``GreaterThan``, and ``!=`` stays ``NotEqualTo``.

    Args:
        name: Name used for the reference in the resulting predicate.
        pred: The bound literal predicate to project.
        transform: Callable applied to the predicate's literal.

    Returns:
        The projected unbound predicate, or ``None`` for any other predicate kind.
    """
    boundary = pred.literal

    # Pick the unbound predicate class first; the construction is the same for all of them.
    if isinstance(pred, (BoundLessThan, BoundLessThanOrEqual)):
        projected_cls = LessThan
    elif isinstance(pred, (BoundGreaterThan, BoundGreaterThanOrEqual)):
        projected_cls = GreaterThan
    elif isinstance(pred, BoundNotEqualTo):
        projected_cls = NotEqualTo
    else:
        return None

    return projected_cls(Reference(name), _transform_literal(transform, boundary))


def _truncate_array(
name: str, pred: BoundLiteralPredicate[L], transform: Callable[[Optional[L]], Optional[L]]
) -> Optional[UnboundPredicate[Any]]:
Expand Down Expand Up @@ -808,7 +941,8 @@ def _remove_transform(partition_name: str, pred: BoundPredicate[L]) -> UnboundPr
def _set_apply_transform(name: str, pred: BoundSetPredicate[L], transform: Callable[[L], L]) -> UnboundPredicate[Any]:
    """Rebuild a set predicate on ``name`` with every literal passed through ``transform``.

    Args:
        name: Name used for the reference in the resulting predicate.
        pred: The bound set predicate whose literals are transformed.
        transform: Callable applied to each literal.

    Returns:
        The unbound form of ``pred`` over the transformed literal set.

    Raises:
        ValueError: If ``pred`` is not a ``BoundSetPredicate``.
    """
    # Validate before reading `pred.literals`, so a wrong argument raises the intended
    # ValueError rather than an AttributeError.
    if not isinstance(pred, BoundSetPredicate):
        raise ValueError(f"Unknown BoundSetPredicate: {pred}")

    transformed_literals = {_transform_literal(transform, literal) for literal in pred.literals}
    return pred.as_unbound(Reference(name), transformed_literals)

Expand Down
11 changes: 11 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
NestedField,
StringType,
StructType,
UUIDType,
)
from pyiceberg.utils.datetime import datetime_to_millis

Expand Down Expand Up @@ -1928,6 +1929,16 @@ def bound_reference_str() -> BoundReference[str]:
return BoundReference(field=NestedField(1, "field", StringType(), required=False), accessor=Accessor(position=0, inner=None))


@pytest.fixture
def bound_reference_binary() -> BoundReference[bytes]:
    """Provide a bound reference to an optional ``BinaryType`` field with id 1 named "field"."""
    binary_field = NestedField(1, "field", BinaryType(), required=False)
    return BoundReference(field=binary_field, accessor=Accessor(position=0, inner=None))


# Fixture: bound reference to an optional UUIDType field with id 1 named "field".
# NOTE(review): the return annotation says BoundReference[str] although the field is a
# UUIDType; it looks copied from the string fixture — confirm and align with a UUID type.
@pytest.fixture
def bound_reference_uuid() -> BoundReference[str]:
return BoundReference(field=NestedField(1, "field", UUIDType(), required=False), accessor=Accessor(position=0, inner=None))


@pytest.fixture(scope="session")
def session_catalog() -> Catalog:
return load_catalog(
Expand Down
Loading