diff --git a/pyiceberg/table/update/__init__.py b/pyiceberg/table/update/__init__.py index 935a105047..fd8c7cac1c 100644 --- a/pyiceberg/table/update/__init__.py +++ b/pyiceberg/table/update/__init__.py @@ -466,6 +466,22 @@ def _(update: SetSnapshotRefUpdate, base_metadata: TableMetadata, context: _Tabl return base_metadata.model_copy(update=metadata_updates) +@_apply_table_update.register(RemoveSnapshotRefUpdate) +def _(update: RemoveSnapshotRefUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata: + if update.ref_name not in base_metadata.refs: + return base_metadata + + existing_ref = base_metadata.refs[update.ref_name] + if base_metadata.snapshot_by_id(existing_ref.snapshot_id) is None: + raise ValueError(f"Cannot remove {update.ref_name} ref with unknown snapshot {existing_ref.snapshot_id}") + + current_snapshot_id = None if update.ref_name == MAIN_BRANCH else base_metadata.current_snapshot_id + + metadata_refs = {ref_name: ref for ref_name, ref in base_metadata.refs.items() if ref_name != update.ref_name} + context.add_update(update) + return base_metadata.model_copy(update={"refs": metadata_refs, "current_snapshot_id": current_snapshot_id}) + + @_apply_table_update.register(AddSortOrderUpdate) def _(update: AddSortOrderUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata: context.add_update(update) diff --git a/pyiceberg/table/update/snapshot.py b/pyiceberg/table/update/snapshot.py index c0d0056e7c..c1bf46566f 100644 --- a/pyiceberg/table/update/snapshot.py +++ b/pyiceberg/table/update/snapshot.py @@ -65,6 +65,7 @@ from pyiceberg.table.update import ( AddSnapshotUpdate, AssertRefSnapshotId, + RemoveSnapshotRefUpdate, SetSnapshotRefUpdate, TableRequirement, TableUpdate, @@ -749,6 +750,28 @@ def _commit(self) -> UpdatesAndRequirements: """Apply the pending changes and commit.""" return self._updates, self._requirements + def _remove_ref_snapshot(self, ref_name: str) -> ManageSnapshots: + """Remove a snapshot ref. + + Args: + ref_name: branch / tag name to remove + Stages the updates and requirements for the remove-snapshot-ref. + Returns + This method for chaining + """ + updates = (RemoveSnapshotRefUpdate(ref_name=ref_name),) + requirements = ( + AssertRefSnapshotId( + snapshot_id=self._transaction.table_metadata.refs[ref_name].snapshot_id + if ref_name in self._transaction.table_metadata.refs + else None, + ref=ref_name, + ), + ) + self._updates += updates + self._requirements += requirements + return self + def create_tag(self, snapshot_id: int, tag_name: str, max_ref_age_ms: Optional[int] = None) -> ManageSnapshots: """ Create a new tag pointing to the given snapshot id. @@ -771,6 +794,17 @@ def create_tag(self, snapshot_id: int, tag_name: str, max_ref_age_ms: Optional[i self._requirements += requirement return self + def remove_tag(self, tag_name: str) -> ManageSnapshots: + """ + Remove a tag. + + Args: + tag_name (str): name of tag to remove + Returns: + This for method chaining + """ + return self._remove_ref_snapshot(ref_name=tag_name) + def create_branch( self, snapshot_id: int, @@ -802,3 +836,14 @@ def create_branch( self._updates += update self._requirements += requirement return self + + def remove_branch(self, branch_name: str) -> ManageSnapshots: + """ + Remove a branch. + + Args: + branch_name (str): name of branch to remove + Returns: + This for method chaining + """ + return self._remove_ref_snapshot(ref_name=branch_name) diff --git a/tests/integration/test_snapshot_operations.py b/tests/integration/test_snapshot_operations.py index 639193383e..1b7f2d3a29 100644 --- a/tests/integration/test_snapshot_operations.py +++ b/tests/integration/test_snapshot_operations.py @@ -40,3 +40,35 @@ def test_create_branch(catalog: Catalog) -> None: branch_snapshot_id = tbl.history()[-2].snapshot_id tbl.manage_snapshots().create_branch(snapshot_id=branch_snapshot_id, branch_name="branch123").commit() assert tbl.metadata.refs["branch123"] == SnapshotRef(snapshot_id=branch_snapshot_id, snapshot_ref_type="branch") + + +@pytest.mark.integration +@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")]) +def test_remove_tag(catalog: Catalog) -> None: + identifier = "default.test_table_snapshot_operations" + tbl = catalog.load_table(identifier) + assert len(tbl.history()) > 3 + # first, create the tag to remove + tag_name = "tag_to_remove" + tag_snapshot_id = tbl.history()[-3].snapshot_id + tbl.manage_snapshots().create_tag(snapshot_id=tag_snapshot_id, tag_name=tag_name).commit() + assert tbl.metadata.refs[tag_name] == SnapshotRef(snapshot_id=tag_snapshot_id, snapshot_ref_type="tag") + # now, remove the tag + tbl.manage_snapshots().remove_tag(tag_name=tag_name).commit() + assert tbl.metadata.refs.get(tag_name, None) is None + + +@pytest.mark.integration +@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")]) +def test_remove_branch(catalog: Catalog) -> None: + identifier = "default.test_table_snapshot_operations" + tbl = catalog.load_table(identifier) + assert len(tbl.history()) > 2 + # first, create the branch to remove + branch_name = "branch_to_remove" + branch_snapshot_id = tbl.history()[-2].snapshot_id + tbl.manage_snapshots().create_branch(snapshot_id=branch_snapshot_id, branch_name=branch_name).commit() + assert tbl.metadata.refs[branch_name] == SnapshotRef(snapshot_id=branch_snapshot_id, snapshot_ref_type="branch") + # now, remove the branch + tbl.manage_snapshots().remove_branch(branch_name=branch_name).commit() + assert tbl.metadata.refs.get(branch_name, None) is None diff --git a/tests/table/test_init.py b/tests/table/test_init.py index 521cc5e46f..4836c7bbad 100644 --- a/tests/table/test_init.py +++ b/tests/table/test_init.py @@ -78,6 +78,7 @@ AssertRefSnapshotId, AssertTableUUID, RemovePropertiesUpdate, + RemoveSnapshotRefUpdate, RemoveStatisticsUpdate, SetDefaultSortOrderUpdate, SetPropertiesUpdate, @@ -793,6 +794,15 @@ def test_update_metadata_set_snapshot_ref(table_v2: Table) -> None: ) +def test_update_remove_snapshots(table_v2: Table) -> None: + # assert fixture data to easily understand the test assumptions + assert len(table_v2.metadata.refs) == 2 + update = RemoveSnapshotRefUpdate(ref_name="test") + new_metadata = update_table_metadata(table_v2.metadata, (update,)) + assert len(new_metadata.refs) == 1 + assert new_metadata.refs["main"].snapshot_id == 3055729675574597004 + + def test_update_metadata_add_update_sort_order(table_v2: Table) -> None: new_sort_order = SortOrder(order_id=table_v2.sort_order().order_id + 1) new_metadata = update_table_metadata(