diff --git a/pyiceberg/table/update/__init__.py b/pyiceberg/table/update/__init__.py index 3cf2db630d..f4570f05cd 100644 --- a/pyiceberg/table/update/__init__.py +++ b/pyiceberg/table/update/__init__.py @@ -177,7 +177,6 @@ class RemovePropertiesUpdate(IcebergBaseModel): class SetStatisticsUpdate(IcebergBaseModel): action: Literal["set-statistics"] = Field(default="set-statistics") - snapshot_id: int = Field(alias="snapshot-id") statistics: StatisticsFile @@ -491,10 +490,7 @@ def _( @_apply_table_update.register(SetStatisticsUpdate) def _(update: SetStatisticsUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata: - if update.snapshot_id != update.statistics.snapshot_id: - raise ValueError("Snapshot id in statistics does not match the snapshot id in the update") - - statistics = filter_statistics_by_snapshot_id(base_metadata.statistics, update.snapshot_id) + statistics = filter_statistics_by_snapshot_id(base_metadata.statistics, update.statistics.snapshot_id) context.add_update(update) return base_metadata.model_copy(update={"statistics": statistics + [update.statistics]}) diff --git a/pyiceberg/table/update/statistics.py b/pyiceberg/table/update/statistics.py index e31025453b..f5604a6ce7 100644 --- a/pyiceberg/table/update/statistics.py +++ b/pyiceberg/table/update/statistics.py @@ -52,10 +52,9 @@ class UpdateStatistics(UpdateTableMetadata["UpdateStatistics"]): def __init__(self, transaction: "Transaction") -> None: super().__init__(transaction) - def set_statistics(self, snapshot_id: int, statistics_file: StatisticsFile) -> "UpdateStatistics": + def set_statistics(self, statistics_file: StatisticsFile) -> "UpdateStatistics": self._updates += ( SetStatisticsUpdate( - snapshot_id=snapshot_id, statistics=statistics_file, ), ) diff --git a/tests/integration/test_statistics_operations.py b/tests/integration/test_statistics_operations.py index 361bfebb63..a7b4e38802 100644 --- a/tests/integration/test_statistics_operations.py +++ b/tests/integration/test_statistics_operations.py @@ -73,8 +73,8 @@ def create_statistics_file(snapshot_id: int, type_name: str) -> StatisticsFile: statistics_file_snap_2 = create_statistics_file(add_snapshot_id_2, "deletion-vector-v1") with tbl.update_statistics() as update: - update.set_statistics(add_snapshot_id_1, statistics_file_snap_1) - update.set_statistics(add_snapshot_id_2, statistics_file_snap_2) + update.set_statistics(statistics_file_snap_1) + update.set_statistics(statistics_file_snap_2) assert len(tbl.metadata.statistics) == 2 diff --git a/tests/table/test_init.py b/tests/table/test_init.py index e1f2ccc876..521cc5e46f 100644 --- a/tests/table/test_init.py +++ b/tests/table/test_init.py @@ -1310,20 +1310,6 @@ def test_set_statistics_update(table_v2_with_statistics: Table) -> None: assert len(updated_statistics) == 1 assert json.loads(updated_statistics[0].model_dump_json()) == json.loads(expected) - update = SetStatisticsUpdate( - snapshot_id=123456789, - statistics=statistics_file, - ) - - with pytest.raises( - ValueError, - match="Snapshot id in statistics does not match the snapshot id in the update", - ): - update_table_metadata( - table_v2_with_statistics.metadata, - (update,), - ) - def test_remove_statistics_update(table_v2_with_statistics: Table) -> None: update = RemoveStatisticsUpdate(