From 96c6e73a49bc7f3cebff2479451404c8537f72b7 Mon Sep 17 00:00:00 2001 From: Alex Stephen Date: Wed, 18 Jun 2025 13:37:25 -0700 Subject: [PATCH 1/2] encryption key --- pyiceberg/table/encryption.py | 27 +++++++++++++ pyiceberg/table/metadata.py | 5 +++ pyiceberg/table/snapshots.py | 1 + pyiceberg/table/update/__init__.py | 18 +++++++++ tests/conftest.py | 13 +++++- tests/table/test_init.py | 63 ++++++++++++++++++++++++++++++ 6 files changed, 126 insertions(+), 1 deletion(-) create mode 100644 pyiceberg/table/encryption.py diff --git a/pyiceberg/table/encryption.py b/pyiceberg/table/encryption.py new file mode 100644 index 0000000000..ecad16567c --- /dev/null +++ b/pyiceberg/table/encryption.py @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Optional +from pydantic import Field +from pyiceberg.typedef import IcebergBaseModel + + +class EncryptedKey(IcebergBaseModel): + key_id: str = Field(alias="key-id", description="ID of the encryption key") + encrypted_key_metadata: bytes = Field(alias="encrypted-key-metadata", description="Encrypted key and metadata, base64 encoded") + encrypted_by_id: Optional[str] = Field(alias="encrypted-by-id", description="Optional ID of the key used to encrypt or wrap `key-metadata`", default=None) + properties: Optional[dict[str, str]] = Field(alias="properties", description="A string to string map of additional metadata used by the table's encryption scheme", default=None) \ No newline at end of file diff --git a/pyiceberg/table/metadata.py b/pyiceberg/table/metadata.py index f248700c02..96e8bd80d3 100644 --- a/pyiceberg/table/metadata.py +++ b/pyiceberg/table/metadata.py @@ -27,6 +27,7 @@ from pyiceberg.exceptions import ValidationError from pyiceberg.partitioning import PARTITION_FIELD_ID_START, PartitionSpec, assign_fresh_partition_spec_ids from pyiceberg.schema import Schema, assign_fresh_schema_ids +from pyiceberg.table.encryption import EncryptedKey from pyiceberg.table.name_mapping import NameMapping, parse_mapping_from_json from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRef, SnapshotRefType from pyiceberg.table.snapshots import MetadataLogEntry, Snapshot, SnapshotLogEntry @@ -516,6 +517,7 @@ class TableMetadataV3(TableMetadataCommonFields, IcebergBaseModel): - Multi-argument transforms for partitioning and sorting - Row Lineage tracking - Binary deletion vectors + - Encryption Keys For more information: https://iceberg.apache.org/spec/?column-projection#version-3-extended-types-and-capabilities @@ -552,6 +554,9 @@ def construct_refs(cls, table_metadata: TableMetadata) -> TableMetadata: next_row_id: Optional[int] = Field(alias="next-row-id", default=None) """A long higher than all assigned row IDs; the next snapshot's `first-row-id`.""" + encryption_keys: List[EncryptedKey] = Field(alias="encryption-keys", default=[]) + """The list of encryption keys for this table.""" + def model_dump_json( self, exclude_none: bool = True, exclude: Optional[Any] = None, by_alias: bool = True, **kwargs: Any ) -> str: diff --git a/pyiceberg/table/snapshots.py b/pyiceberg/table/snapshots.py index 8d1a24c420..53c1d10571 100644 --- a/pyiceberg/table/snapshots.py +++ b/pyiceberg/table/snapshots.py @@ -243,6 +243,7 @@ class Snapshot(IcebergBaseModel): manifest_list: str = Field(alias="manifest-list", description="Location of the snapshot's manifest list file") summary: Optional[Summary] = Field(default=None) schema_id: Optional[int] = Field(alias="schema-id", default=None) + key_id: Optional[str] = Field(alias="key-id", default=None, description="The id of the encryption key") def __str__(self) -> str: """Return the string representation of the Snapshot class.""" diff --git a/pyiceberg/table/update/__init__.py b/pyiceberg/table/update/__init__.py index 4905c31bfb..71a66c3b61 100644 --- a/pyiceberg/table/update/__init__.py +++ b/pyiceberg/table/update/__init__.py @@ -28,6 +28,7 @@ from pyiceberg.exceptions import CommitFailedException from pyiceberg.partitioning import PARTITION_FIELD_ID_START, PartitionSpec from pyiceberg.schema import Schema +from pyiceberg.table.encryption import EncryptedKey from pyiceberg.table.metadata import SUPPORTED_TABLE_FORMAT_VERSION, TableMetadata, TableMetadataUtil from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRef from pyiceberg.table.snapshots import ( @@ -84,6 +85,13 @@ class UpgradeFormatVersionUpdate(IcebergBaseModel): action: Literal["upgrade-format-version"] = Field(default="upgrade-format-version") format_version: int = Field(alias="format-version") +class AddEncryptedKeyUpdate(IcebergBaseModel): + action: Literal["add-encryption-key"] = Field(default="add-encryption-key") + key: EncryptedKey = Field(alias="key") + +class RemoveEncryptedKeyUpdate(IcebergBaseModel): + action: Literal["remove-encryption-key"] = Field(default="remove-encryption-key") + key_id: str = Field(alias="key-id") class AddSchemaUpdate(IcebergBaseModel): action: Literal["add-schema"] = Field(default="add-schema") @@ -217,6 +225,8 @@ class RemoveStatisticsUpdate(IcebergBaseModel): RemovePropertiesUpdate, SetStatisticsUpdate, RemoveStatisticsUpdate, + AddEncryptedKeyUpdate, + RemoveEncryptedKeyUpdate, ], Field(discriminator="action"), ] @@ -581,6 +591,14 @@ def _(update: RemoveStatisticsUpdate, base_metadata: TableMetadata, context: _Ta return base_metadata.model_copy(update={"statistics": statistics}) +@_apply_table_update.register(AddEncryptedKeyUpdate) +def _(update: AddEncryptedKeyUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata: + context.add_update(update) + + if base_metadata.format_version <= 2: + raise ValueError("Cannot add encryption keys to Iceberg v1 or v2 tables") + + return base_metadata.model_copy(update={"encryption_keys": base_metadata.encryption_keys + [update.key]}) def update_table_metadata( base_metadata: TableMetadata, diff --git a/tests/conftest.py b/tests/conftest.py index 729e29cb0c..e36218ef4e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -64,7 +64,7 @@ from pyiceberg.schema import Accessor, Schema from pyiceberg.serializers import ToOutputFile from pyiceberg.table import FileScanTask, Table -from pyiceberg.table.metadata import TableMetadataV1, TableMetadataV2 +from pyiceberg.table.metadata import TableMetadataV1, TableMetadataV2, TableMetadataV3 from pyiceberg.types import ( BinaryType, BooleanType, @@ -2341,6 +2341,17 @@ def table_v2(example_table_metadata_v2: Dict[str, Any]) -> Table: catalog=NoopCatalog("NoopCatalog"), ) +@pytest.fixture +def table_v3(example_table_metadata_v3: Dict[str, Any]) -> Table: + table_metadata = TableMetadataV3(**example_table_metadata_v3) + return Table( + identifier=("database", "table"), + metadata=table_metadata, + metadata_location=f"{table_metadata.location}/uuid.metadata.json", + io=load_file_io(), + catalog=NoopCatalog("NoopCatalog"), + ) + @pytest.fixture def table_v2_with_fixed_and_decimal_types( diff --git a/tests/table/test_init.py b/tests/table/test_init.py index 6165dadec4..b4461983bb 100644 --- a/tests/table/test_init.py +++ b/tests/table/test_init.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. # pylint:disable=redefined-outer-name +import base64 import json import uuid from copy import copy @@ -49,6 +50,7 @@ TableIdentifier, _match_deletes_to_data_file, ) +from pyiceberg.table.encryption import EncryptedKey from pyiceberg.table.metadata import INITIAL_SEQUENCE_NUMBER, TableMetadataUtil, TableMetadataV2, _generate_snapshot_id from pyiceberg.table.refs import SnapshotRef from pyiceberg.table.snapshots import ( @@ -66,6 +68,7 @@ ) from pyiceberg.table.statistics import BlobMetadata, StatisticsFile from pyiceberg.table.update import ( + AddEncryptedKeyUpdate, AddSnapshotUpdate, AddSortOrderUpdate, AssertCreate, @@ -1345,3 +1348,63 @@ def test_remove_statistics_update(table_v2_with_statistics: Table) -> None: table_v2_with_statistics.metadata, (RemoveStatisticsUpdate(snapshot_id=123456789),), ) + +def test_add_encryption_key(table_v3: Table) -> None: + update = AddEncryptedKeyUpdate( + key=EncryptedKey( + key_id="test", + encrypted_key_metadata=base64.b64encode("hello".encode('utf-8')) + ) + ) + + expected = """ + { + "key-id": "test", + "encrypted-key-metadata": "aGVsbG8=" + }""" + + assert table_v3.metadata.encryption_keys == [] + add_metadata = update_table_metadata(table_v3.metadata, (update,)) + assert len(add_metadata.encryption_keys) == 1 + + assert json.loads(add_metadata.encryption_keys[0].model_dump_json()) == json.loads(expected) + +def test_remove_encryption_key(table_v3: Table) -> None: + update_add = AddEncryptedKeyUpdate( + key=EncryptedKey( + key_id="test", + encrypted_key_metadata=base64.b64encode("hello".encode('utf-8')) + ) + ) + add_metadata = update_table_metadata(table_v3.metadata, (update_add,)) + assert len(add_metadata.encryption_keys) == 1 + + update_remove = RemoveEncryptedKeyUpdate(key_id="test") + remove_metadata = update_table_metadata(add_metadata, (update_remove,)) + assert len(remove_metadata.encryption_keys) == 0 + + +def test_remove_non_existent_encryption_key(table_v3: Table) -> None: + update_add = AddEncryptedKeyUpdate( + key=EncryptedKey( + key_id="test", + encrypted_key_metadata=base64.b64encode("hello".encode('utf-8')) + ) + ) + add_metadata = update_table_metadata(table_v3.metadata, (update_add,)) + assert len(add_metadata.encryption_keys) == 1 + + update_remove = RemoveEncryptedKeyUpdate(key_id="non_existent_key") + remove_metadata = update_table_metadata(add_metadata, (update_remove,)) + assert len(remove_metadata.encryption_keys) == 1 # Should be a no-op + + +def test_add_remove_encryption_key_v2_table(table_v2: Table) -> None: + update_add = AddEncryptedKeyUpdate( + key=EncryptedKey( + key_id="test_v2", + encrypted_key_metadata=base64.b64encode("hello_v2".encode('utf-8')) + ) + ) + with pytest.raises(ValueError, match=r"Cannot add encryption keys from Iceberg v1 or v2 table"): + update_table_metadata(table_v2.metadata, (update_add,)) From 86543cb1cfcf579eb518e7475226254a1827000a Mon Sep 17 00:00:00 2001 From: Alex Stephen Date: Wed, 18 Jun 2025 13:39:40 -0700 Subject: [PATCH 2/2] it works --- pyiceberg/table/encryption.py | 16 ++++++++++++--- pyiceberg/table/update/__init__.py | 5 +++++ tests/conftest.py | 1 + tests/table/test_init.py | 31 +++++++----------------------- 4 files changed, 26 insertions(+), 27 deletions(-) diff --git a/pyiceberg/table/encryption.py b/pyiceberg/table/encryption.py index ecad16567c..4cb1c67cdd 100644 --- a/pyiceberg/table/encryption.py +++ b/pyiceberg/table/encryption.py @@ -16,12 +16,22 @@ # under the License. from typing import Optional + from pydantic import Field + from pyiceberg.typedef import IcebergBaseModel class EncryptedKey(IcebergBaseModel): key_id: str = Field(alias="key-id", description="ID of the encryption key") - encrypted_key_metadata: bytes = Field(alias="encrypted-key-metadata", description="Encrypted key and metadata, base64 encoded") - encrypted_by_id: Optional[str] = Field(alias="encrypted-by-id", description="Optional ID of the key used to encrypt or wrap `key-metadata`", default=None) - properties: Optional[dict[str, str]] = Field(alias="properties", description="A string to string map of additional metadata used by the table's encryption scheme", default=None) \ No newline at end of file + encrypted_key_metadata: bytes = Field( + alias="encrypted-key-metadata", description="Encrypted key and metadata, base64 encoded" + ) + encrypted_by_id: Optional[str] = Field( + alias="encrypted-by-id", description="Optional ID of the key used to encrypt or wrap `key-metadata`", default=None + ) + properties: Optional[dict[str, str]] = Field( + alias="properties", + description="A string to string map of additional metadata used by the table's encryption scheme", + default=None, + ) diff --git a/pyiceberg/table/update/__init__.py b/pyiceberg/table/update/__init__.py index 71a66c3b61..6d918bc5c1 100644 --- a/pyiceberg/table/update/__init__.py +++ b/pyiceberg/table/update/__init__.py @@ -85,14 +85,17 @@ class UpgradeFormatVersionUpdate(IcebergBaseModel): action: Literal["upgrade-format-version"] = Field(default="upgrade-format-version") format_version: int = Field(alias="format-version") + class AddEncryptedKeyUpdate(IcebergBaseModel): action: Literal["add-encryption-key"] = Field(default="add-encryption-key") key: EncryptedKey = Field(alias="key") + class RemoveEncryptedKeyUpdate(IcebergBaseModel): action: Literal["remove-encryption-key"] = Field(default="remove-encryption-key") key_id: str = Field(alias="key-id") + class AddSchemaUpdate(IcebergBaseModel): action: Literal["add-schema"] = Field(default="add-schema") schema_: Schema = Field(alias="schema") @@ -591,6 +594,7 @@ def _(update: RemoveStatisticsUpdate, base_metadata: TableMetadata, context: _Ta return base_metadata.model_copy(update={"statistics": statistics}) + @_apply_table_update.register(AddEncryptedKeyUpdate) def _(update: AddEncryptedKeyUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata: context.add_update(update) @@ -600,6 +604,7 @@ def _(update: AddEncryptedKeyUpdate, base_metadata: TableMetadata, context: _Tab return base_metadata.model_copy(update={"encryption_keys": base_metadata.encryption_keys + [update.key]}) + def update_table_metadata( base_metadata: TableMetadata, updates: Tuple[TableUpdate, ...], diff --git a/tests/conftest.py b/tests/conftest.py index e36218ef4e..67ee836a4b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2341,6 +2341,7 @@ def table_v2(example_table_metadata_v2: Dict[str, Any]) -> Table: catalog=NoopCatalog("NoopCatalog"), ) + @pytest.fixture def table_v3(example_table_metadata_v3: Dict[str, Any]) -> Table: table_metadata = TableMetadataV3(**example_table_metadata_v3) diff --git a/tests/table/test_init.py b/tests/table/test_init.py index b4461983bb..dee752cabd 100644 --- a/tests/table/test_init.py +++ b/tests/table/test_init.py @@ -79,6 +79,7 @@ AssertLastAssignedPartitionId, AssertRefSnapshotId, AssertTableUUID, + RemoveEncryptedKeyUpdate, RemovePropertiesUpdate, RemoveSnapshotRefUpdate, RemoveSnapshotsUpdate, @@ -1349,13 +1350,9 @@ def test_remove_statistics_update(table_v2_with_statistics: Table) -> None: (RemoveStatisticsUpdate(snapshot_id=123456789),), ) + def test_add_encryption_key(table_v3: Table) -> None: - update = AddEncryptedKeyUpdate( - key=EncryptedKey( - key_id="test", - encrypted_key_metadata=base64.b64encode("hello".encode('utf-8')) - ) - ) + update = AddEncryptedKeyUpdate(key=EncryptedKey(key_id="test", encrypted_key_metadata=base64.b64encode(b"hello"))) expected = """ { @@ -1369,13 +1366,9 @@ def test_add_encryption_key(table_v3: Table) -> None: assert json.loads(add_metadata.encryption_keys[0].model_dump_json()) == json.loads(expected) + def test_remove_encryption_key(table_v3: Table) -> None: - update_add = AddEncryptedKeyUpdate( - key=EncryptedKey( - key_id="test", - encrypted_key_metadata=base64.b64encode("hello".encode('utf-8')) - ) - ) + update_add = AddEncryptedKeyUpdate(key=EncryptedKey(key_id="test", encrypted_key_metadata=base64.b64encode(b"hello"))) add_metadata = update_table_metadata(table_v3.metadata, (update_add,)) assert len(add_metadata.encryption_keys) == 1 @@ -1385,12 +1378,7 @@ def test_remove_encryption_key(table_v3: Table) -> None: def test_remove_non_existent_encryption_key(table_v3: Table) -> None: - update_add = AddEncryptedKeyUpdate( - key=EncryptedKey( - key_id="test", - encrypted_key_metadata=base64.b64encode("hello".encode('utf-8')) - ) - ) + update_add = AddEncryptedKeyUpdate(key=EncryptedKey(key_id="test", encrypted_key_metadata=base64.b64encode(b"hello"))) add_metadata = update_table_metadata(table_v3.metadata, (update_add,)) assert len(add_metadata.encryption_keys) == 1 @@ -1400,11 +1388,6 @@ def test_remove_non_existent_encryption_key(table_v3: Table) -> None: def test_add_remove_encryption_key_v2_table(table_v2: Table) -> None: - update_add = AddEncryptedKeyUpdate( - key=EncryptedKey( - key_id="test_v2", - encrypted_key_metadata=base64.b64encode("hello_v2".encode('utf-8')) - ) - ) + update_add = AddEncryptedKeyUpdate(key=EncryptedKey(key_id="test_v2", encrypted_key_metadata=base64.b64encode(b"hello_v2"))) with pytest.raises(ValueError, match=r"Cannot add encryption keys from Iceberg v1 or v2 table"): update_table_metadata(table_v2.metadata, (update_add,))