diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 9e9de52dee..7cb3626782 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -115,11 +115,7 @@ update_table_metadata, ) from pyiceberg.table.update.schema import UpdateSchema -from pyiceberg.table.update.snapshot import ( - ManageSnapshots, - UpdateSnapshot, - _FastAppendFiles, -) +from pyiceberg.table.update.snapshot import ExpireSnapshots, ManageSnapshots, UpdateSnapshot, _FastAppendFiles from pyiceberg.table.update.spec import UpdateSpec from pyiceberg.table.update.statistics import UpdateStatistics from pyiceberg.transforms import IdentityTransform @@ -1079,6 +1075,10 @@ def manage_snapshots(self) -> ManageSnapshots: """ return ManageSnapshots(transaction=Transaction(self, autocommit=True)) + def expire_snapshots(self) -> ExpireSnapshots: + """Shorthand to run expire snapshots by id or by a timestamp.""" + return ExpireSnapshots(transaction=Transaction(self, autocommit=True)) + def update_statistics(self) -> UpdateStatistics: """ Shorthand to run statistics management operations like add statistics and remove statistics. diff --git a/pyiceberg/table/update/snapshot.py b/pyiceberg/table/update/snapshot.py index b53c331758..5383cd5c9e 100644 --- a/pyiceberg/table/update/snapshot.py +++ b/pyiceberg/table/update/snapshot.py @@ -55,6 +55,7 @@ from pyiceberg.partitioning import ( PartitionSpec, ) +from pyiceberg.table.refs import SnapshotRefType from pyiceberg.table.snapshots import ( Operation, Snapshot, @@ -66,6 +67,7 @@ AddSnapshotUpdate, AssertRefSnapshotId, RemoveSnapshotRefUpdate, + RemoveSnapshotsUpdate, SetSnapshotRefUpdate, TableRequirement, TableUpdate, @@ -843,3 +845,103 @@ def remove_branch(self, branch_name: str) -> ManageSnapshots: This for method chaining """ return self._remove_ref_snapshot(ref_name=branch_name) + + +class ExpireSnapshots(UpdateTableMetadata["ExpireSnapshots"]): + """ + Expire snapshots by ID. + + Use table.expire_snapshots().().commit() to run a specific operation. + Use table.expire_snapshots().().().commit() to run multiple operations. + Pending changes are applied on commit. + """ + + _snapshot_ids_to_expire: Set[int] = set() + _updates: Tuple[TableUpdate, ...] = () + _requirements: Tuple[TableRequirement, ...] = () + + def _commit(self) -> UpdatesAndRequirements: + """ + Commit the staged updates and requirements. + + This will remove the snapshots with the given IDs, but will always skip protected snapshots (branch/tag heads). + + Returns: + Tuple of updates and requirements to be committed, + as required by the calling parent apply functions. + """ + # Remove any protected snapshot IDs from the set to expire, just in case + protected_ids = self._get_protected_snapshot_ids() + self._snapshot_ids_to_expire -= protected_ids + update = RemoveSnapshotsUpdate(snapshot_ids=self._snapshot_ids_to_expire) + self._updates += (update,) + return self._updates, self._requirements + + def _get_protected_snapshot_ids(self) -> Set[int]: + """ + Get the IDs of protected snapshots. + + These are the HEAD snapshots of all branches and all tagged snapshots. These ids are to be excluded from expiration. + + Returns: + Set of protected snapshot IDs to exclude from expiration. + """ + protected_ids: Set[int] = set() + + for ref in self._transaction.table_metadata.refs.values(): + if ref.snapshot_ref_type in [SnapshotRefType.TAG, SnapshotRefType.BRANCH]: + protected_ids.add(ref.snapshot_id) + + return protected_ids + + def expire_snapshot_by_id(self, snapshot_id: int) -> ExpireSnapshots: + """ + Expire a snapshot by its ID. + + This will mark the snapshot for expiration. + + Args: + snapshot_id (int): The ID of the snapshot to expire. + Returns: + This for method chaining. + """ + if self._transaction.table_metadata.snapshot_by_id(snapshot_id) is None: + raise ValueError(f"Snapshot with ID {snapshot_id} does not exist.") + + if snapshot_id in self._get_protected_snapshot_ids(): + raise ValueError(f"Snapshot with ID {snapshot_id} is protected and cannot be expired.") + + self._snapshot_ids_to_expire.add(snapshot_id) + + return self + + def expire_snapshots_by_ids(self, snapshot_ids: List[int]) -> "ExpireSnapshots": + """ + Expire multiple snapshots by their IDs. + + This will mark the snapshots for expiration. + + Args: + snapshot_ids (List[int]): List of snapshot IDs to expire. + Returns: + This for method chaining. + """ + for snapshot_id in snapshot_ids: + self.expire_snapshot_by_id(snapshot_id) + return self + + def expire_snapshots_older_than(self, timestamp_ms: int) -> "ExpireSnapshots": + """ + Expire all unprotected snapshots with a timestamp older than a given value. + + Args: + timestamp_ms (int): Only snapshots with timestamp_ms < this value will be expired. + + Returns: + This for method chaining. + """ + protected_ids = self._get_protected_snapshot_ids() + for snapshot in self._transaction.table_metadata.snapshots: + if snapshot.timestamp_ms < timestamp_ms and snapshot.snapshot_id not in protected_ids: + self._snapshot_ids_to_expire.add(snapshot.snapshot_id) + return self diff --git a/tests/table/test_expire_snapshots.py b/tests/table/test_expire_snapshots.py new file mode 100644 index 0000000000..82ecb9e493 --- /dev/null +++ b/tests/table/test_expire_snapshots.py @@ -0,0 +1,224 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from unittest.mock import MagicMock +from uuid import uuid4 + +import pytest + +from pyiceberg.table import CommitTableResponse, Table + + +def test_cannot_expire_protected_head_snapshot(table_v2: Table) -> None: + """Test that a HEAD (branch) snapshot cannot be expired.""" + HEAD_SNAPSHOT = 3051729675574597004 + KEEP_SNAPSHOT = 3055729675574597004 + + # Mock the catalog's commit_table method + table_v2.catalog = MagicMock() + # Simulate refs protecting HEAD_SNAPSHOT as a branch + table_v2.metadata = table_v2.metadata.model_copy( + update={ + "refs": { + "main": MagicMock(snapshot_id=HEAD_SNAPSHOT, snapshot_ref_type="branch"), + "tag1": MagicMock(snapshot_id=KEEP_SNAPSHOT, snapshot_ref_type="tag"), + } + } + ) + # Assert fixture data + assert any(ref.snapshot_id == HEAD_SNAPSHOT for ref in table_v2.metadata.refs.values()) + + # Attempt to expire the HEAD snapshot and expect a ValueError + with pytest.raises(ValueError, match=f"Snapshot with ID {HEAD_SNAPSHOT} is protected and cannot be expired."): + table_v2.expire_snapshots().expire_snapshot_by_id(HEAD_SNAPSHOT).commit() + + table_v2.catalog.commit_table.assert_not_called() + + +def test_cannot_expire_tagged_snapshot(table_v2: Table) -> None: + """Test that a tagged snapshot cannot be expired.""" + TAGGED_SNAPSHOT = 3051729675574597004 + KEEP_SNAPSHOT = 3055729675574597004 + + table_v2.catalog = MagicMock() + # Simulate refs protecting TAGGED_SNAPSHOT as a tag + table_v2.metadata = table_v2.metadata.model_copy( + update={ + "refs": { + "tag1": MagicMock(snapshot_id=TAGGED_SNAPSHOT, snapshot_ref_type="tag"), + "main": MagicMock(snapshot_id=KEEP_SNAPSHOT, snapshot_ref_type="branch"), + } + } + ) + assert any(ref.snapshot_id == TAGGED_SNAPSHOT for ref in table_v2.metadata.refs.values()) + + with pytest.raises(ValueError, match=f"Snapshot with ID {TAGGED_SNAPSHOT} is protected and cannot be expired."): + table_v2.expire_snapshots().expire_snapshot_by_id(TAGGED_SNAPSHOT).commit() + + table_v2.catalog.commit_table.assert_not_called() + + +def test_expire_unprotected_snapshot(table_v2: Table) -> None: + """Test that an unprotected snapshot can be expired.""" + EXPIRE_SNAPSHOT = 3051729675574597004 + KEEP_SNAPSHOT = 3055729675574597004 + + mock_response = CommitTableResponse( + metadata=table_v2.metadata.model_copy(update={"snapshots": [KEEP_SNAPSHOT]}), + metadata_location="mock://metadata/location", + uuid=uuid4(), + ) + table_v2.catalog = MagicMock() + table_v2.catalog.commit_table.return_value = mock_response + + # Remove any refs that protect the snapshot to be expired + table_v2.metadata = table_v2.metadata.model_copy( + update={ + "refs": { + "main": MagicMock(snapshot_id=KEEP_SNAPSHOT, snapshot_ref_type="branch"), + "tag1": MagicMock(snapshot_id=KEEP_SNAPSHOT, snapshot_ref_type="tag"), + } + } + ) + + # Assert fixture data + assert all(ref.snapshot_id != EXPIRE_SNAPSHOT for ref in table_v2.metadata.refs.values()) + + # Expire the snapshot + table_v2.expire_snapshots().expire_snapshot_by_id(EXPIRE_SNAPSHOT).commit() + + table_v2.catalog.commit_table.assert_called_once() + remaining_snapshots = table_v2.metadata.snapshots + assert EXPIRE_SNAPSHOT not in remaining_snapshots + assert len(table_v2.metadata.snapshots) == 1 + + +def test_expire_nonexistent_snapshot_raises(table_v2: Table) -> None: + """Test that trying to expire a non-existent snapshot raises an error.""" + NONEXISTENT_SNAPSHOT = 9999999999999999999 + + table_v2.catalog = MagicMock() + table_v2.metadata = table_v2.metadata.model_copy(update={"refs": {}}) + + with pytest.raises(ValueError, match=f"Snapshot with ID {NONEXISTENT_SNAPSHOT} does not exist."): + table_v2.expire_snapshots().expire_snapshot_by_id(NONEXISTENT_SNAPSHOT).commit() + + table_v2.catalog.commit_table.assert_not_called() + + +def test_expire_snapshots_by_timestamp_skips_protected(table_v2: Table) -> None: + # Setup: two snapshots; both are old, but one is head/tag protected + HEAD_SNAPSHOT = 3051729675574597004 + TAGGED_SNAPSHOT = 3055729675574597004 + + # Add snapshots to metadata for timestamp/protected test + from types import SimpleNamespace + + table_v2.metadata = table_v2.metadata.model_copy( + update={ + "refs": { + "main": MagicMock(snapshot_id=HEAD_SNAPSHOT, snapshot_ref_type="branch"), + "mytag": MagicMock(snapshot_id=TAGGED_SNAPSHOT, snapshot_ref_type="tag"), + }, + "snapshots": [ + SimpleNamespace(snapshot_id=HEAD_SNAPSHOT, timestamp_ms=1, parent_snapshot_id=None), + SimpleNamespace(snapshot_id=TAGGED_SNAPSHOT, timestamp_ms=1, parent_snapshot_id=None), + ], + } + ) + table_v2.catalog = MagicMock() + + # Attempt to expire all snapshots before a future timestamp (so both are candidates) + future_timestamp = 9999999999999 # Far in the future, after any real snapshot + + # Mock the catalog's commit_table to return the current metadata (simulate no change) + mock_response = CommitTableResponse( + metadata=table_v2.metadata, # protected snapshots remain + metadata_location="mock://metadata/location", + uuid=uuid4(), + ) + table_v2.catalog.commit_table.return_value = mock_response + + table_v2.expire_snapshots().expire_snapshots_older_than(future_timestamp).commit() + # Update metadata to reflect the commit (as in other tests) + table_v2.metadata = mock_response.metadata + + # Both protected snapshots should remain + remaining_ids = {s.snapshot_id for s in table_v2.metadata.snapshots} + assert HEAD_SNAPSHOT in remaining_ids + assert TAGGED_SNAPSHOT in remaining_ids + + # No snapshots should have been expired (commit_table called, but with empty snapshot_ids) + args, kwargs = table_v2.catalog.commit_table.call_args + updates = args[2] if len(args) > 2 else () + # Find RemoveSnapshotsUpdate in updates + remove_update = next((u for u in updates if getattr(u, "action", None) == "remove-snapshots"), None) + assert remove_update is not None + assert remove_update.snapshot_ids == [] + + +def test_expire_snapshots_by_ids(table_v2: Table) -> None: + """Test that multiple unprotected snapshots can be expired by IDs.""" + EXPIRE_SNAPSHOT_1 = 3051729675574597004 + EXPIRE_SNAPSHOT_2 = 3051729675574597005 + KEEP_SNAPSHOT = 3055729675574597004 + + mock_response = CommitTableResponse( + metadata=table_v2.metadata.model_copy(update={"snapshots": [KEEP_SNAPSHOT]}), + metadata_location="mock://metadata/location", + uuid=uuid4(), + ) + table_v2.catalog = MagicMock() + table_v2.catalog.commit_table.return_value = mock_response + + # Remove any refs that protect the snapshots to be expired + table_v2.metadata = table_v2.metadata.model_copy( + update={ + "refs": { + "main": MagicMock(snapshot_id=KEEP_SNAPSHOT, snapshot_ref_type="branch"), + "tag1": MagicMock(snapshot_id=KEEP_SNAPSHOT, snapshot_ref_type="tag"), + } + } + ) + + # Add snapshots to metadata for multi-id test + from types import SimpleNamespace + + table_v2.metadata = table_v2.metadata.model_copy( + update={ + "refs": { + "main": MagicMock(snapshot_id=KEEP_SNAPSHOT, snapshot_ref_type="branch"), + "tag1": MagicMock(snapshot_id=KEEP_SNAPSHOT, snapshot_ref_type="tag"), + }, + "snapshots": [ + SimpleNamespace(snapshot_id=EXPIRE_SNAPSHOT_1, timestamp_ms=1, parent_snapshot_id=None), + SimpleNamespace(snapshot_id=EXPIRE_SNAPSHOT_2, timestamp_ms=1, parent_snapshot_id=None), + SimpleNamespace(snapshot_id=KEEP_SNAPSHOT, timestamp_ms=2, parent_snapshot_id=None), + ], + } + ) + + # Assert fixture data + assert all(ref.snapshot_id not in (EXPIRE_SNAPSHOT_1, EXPIRE_SNAPSHOT_2) for ref in table_v2.metadata.refs.values()) + + # Expire the snapshots + table_v2.expire_snapshots().expire_snapshots_by_ids([EXPIRE_SNAPSHOT_1, EXPIRE_SNAPSHOT_2]).commit() + + table_v2.catalog.commit_table.assert_called_once() + remaining_snapshots = table_v2.metadata.snapshots + assert EXPIRE_SNAPSHOT_1 not in remaining_snapshots + assert EXPIRE_SNAPSHOT_2 not in remaining_snapshots + assert len(table_v2.metadata.snapshots) == 1