Skip to content

Commit c5cbfe4

Browse files
committed
cherry-pick: Support Remove Branch or Tag APIs (#822)
1 parent 9b31690 commit c5cbfe4

File tree

3 files changed

+91
-0
lines changed

3 files changed

+91
-0
lines changed

pyiceberg/table/update/__init__.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,23 @@ def _(update: RemoveSnapshotsUpdate, base_metadata: TableMetadata, context: _Tab
472472
return base_metadata.model_copy(update={"snapshots": snapshots})
473473

474474

475+
@_apply_table_update.register(RemoveSnapshotRefUpdate)
476+
def _(update: RemoveSnapshotRefUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata:
477+
if (existing_ref := base_metadata.refs.get(update.ref_name, None)) is None:
478+
return base_metadata
479+
480+
if base_metadata.snapshot_by_id(existing_ref.snapshot_id) is None:
481+
raise ValueError(f"Cannot remove {update.ref_name} ref with unknown snapshot {existing_ref.snapshot_id}")
482+
483+
if update.ref_name == MAIN_BRANCH:
484+
raise ValueError("Cannot remove main branch")
485+
486+
metadata_refs = {**base_metadata.refs}
487+
metadata_refs.pop(update.ref_name, None)
488+
context.add_update(update)
489+
return base_metadata.model_copy(update={"refs": metadata_refs})
490+
491+
475492
@_apply_table_update.register(AddSortOrderUpdate)
476493
def _(update: AddSortOrderUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata:
477494
context.add_update(update)

pyiceberg/table/update/snapshot.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
from pyiceberg.table.update import (
6666
AddSnapshotUpdate,
6767
AssertRefSnapshotId,
68+
RemoveSnapshotRefUpdate,
6869
SetSnapshotRefUpdate,
6970
TableRequirement,
7071
TableUpdate,
@@ -749,6 +750,27 @@ def _commit(self) -> UpdatesAndRequirements:
749750
"""Apply the pending changes and commit."""
750751
return self._updates, self._requirements
751752

753+
def _remove_ref_snapshot(self, ref_name: str) -> ManageSnapshots:
754+
"""Remove a snapshot ref.
755+
Args:
756+
ref_name: branch / tag name to remove
757+
Stages the updates and requirements for the remove-snapshot-ref.
758+
Returns
759+
This method for chaining
760+
"""
761+
updates = (RemoveSnapshotRefUpdate(ref_name=ref_name),)
762+
requirements = (
763+
AssertRefSnapshotId(
764+
snapshot_id=self._transaction.table_metadata.refs[ref_name].snapshot_id
765+
if ref_name in self._transaction.table_metadata.refs
766+
else None,
767+
ref=ref_name,
768+
),
769+
)
770+
self._updates += updates
771+
self._requirements += requirements
772+
return self
773+
752774
def create_tag(self, snapshot_id: int, tag_name: str, max_ref_age_ms: Optional[int] = None) -> ManageSnapshots:
753775
"""
754776
Create a new tag pointing to the given snapshot id.
@@ -771,6 +793,16 @@ def create_tag(self, snapshot_id: int, tag_name: str, max_ref_age_ms: Optional[i
771793
self._requirements += requirement
772794
return self
773795

796+
def remove_tag(self, tag_name: str) -> ManageSnapshots:
797+
"""
798+
Remove a tag.
799+
Args:
800+
tag_name (str): name of tag to remove
801+
Returns:
802+
This for method chaining
803+
"""
804+
return self._remove_ref_snapshot(ref_name=tag_name)
805+
774806
def create_branch(
775807
self,
776808
snapshot_id: int,
@@ -802,3 +834,13 @@ def create_branch(
802834
self._updates += update
803835
self._requirements += requirement
804836
return self
837+
838+
def remove_branch(self, branch_name: str) -> ManageSnapshots:
839+
"""
840+
Remove a branch.
841+
Args:
842+
branch_name (str): name of branch to remove
843+
Returns:
844+
This for method chaining
845+
"""
846+
return self._remove_ref_snapshot(ref_name=branch_name)

tests/integration/test_snapshot_operations.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,35 @@ def test_create_branch(catalog: Catalog) -> None:
4040
branch_snapshot_id = tbl.history()[-2].snapshot_id
4141
tbl.manage_snapshots().create_branch(snapshot_id=branch_snapshot_id, branch_name="branch123").commit()
4242
assert tbl.metadata.refs["branch123"] == SnapshotRef(snapshot_id=branch_snapshot_id, snapshot_ref_type="branch")
43+
44+
45+
@pytest.mark.integration
46+
@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
47+
def test_remove_tag(catalog: Catalog) -> None:
48+
identifier = "default.test_table_snapshot_operations"
49+
tbl = catalog.load_table(identifier)
50+
assert len(tbl.history()) > 3
51+
# first, create the tag to remove
52+
tag_name = "tag_to_remove"
53+
tag_snapshot_id = tbl.history()[-3].snapshot_id
54+
tbl.manage_snapshots().create_tag(snapshot_id=tag_snapshot_id, tag_name=tag_name).commit()
55+
assert tbl.metadata.refs[tag_name] == SnapshotRef(snapshot_id=tag_snapshot_id, snapshot_ref_type="tag")
56+
# now, remove the tag
57+
tbl.manage_snapshots().remove_tag(tag_name=tag_name).commit()
58+
assert tbl.metadata.refs.get(tag_name, None) is None
59+
60+
61+
@pytest.mark.integration
62+
@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
63+
def test_remove_branch(catalog: Catalog) -> None:
64+
identifier = "default.test_table_snapshot_operations"
65+
tbl = catalog.load_table(identifier)
66+
assert len(tbl.history()) > 2
67+
# first, create the branch to remove
68+
branch_name = "branch_to_remove"
69+
branch_snapshot_id = tbl.history()[-2].snapshot_id
70+
tbl.manage_snapshots().create_branch(snapshot_id=branch_snapshot_id, branch_name=branch_name).commit()
71+
assert tbl.metadata.refs[branch_name] == SnapshotRef(snapshot_id=branch_snapshot_id, snapshot_ref_type="branch")
72+
# now, remove the branch
73+
tbl.manage_snapshots().remove_branch(branch_name=branch_name).commit()
74+
assert tbl.metadata.refs.get(branch_name, None) is None

0 commit comments

Comments
 (0)