From c2492a4d1d2278eab7f569a2910cfa87d8780c69 Mon Sep 17 00:00:00 2001 From: Kieran Higgins Date: Thu, 14 Mar 2024 16:51:39 +0000 Subject: [PATCH] have requirements inherit from ValidatableTableRequirement --- pyiceberg/table/__init__.py | 40 +++++++++++++++++++++---------------- 1 file changed, 23 insertions(+), 17 deletions(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 3bb384d7b..8f8270eea 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -761,13 +761,9 @@ def update_table_metadata(base_metadata: TableMetadata, updates: Tuple[TableUpda return new_metadata.model_copy(deep=True) -class TableRequirement(IcebergBaseModel): +class ValidatableTableRequirement(IcebergBaseModel): type: str - @classmethod - def get_subclasses(cls) -> tuple[str]: - return tuple(cls.__subclasses__()) - @abstractmethod def validate(self, base_metadata: Optional[TableMetadata]) -> None: """Validate the requirement against the base metadata. @@ -781,7 +777,7 @@ def validate(self, base_metadata: Optional[TableMetadata]) -> None: ... -class AssertCreate(TableRequirement): +class AssertCreate(ValidatableTableRequirement): """The table must not already exist; used for create transactions.""" type: Literal["assert-create"] = Field(default="assert-create") @@ -791,7 +787,7 @@ def validate(self, base_metadata: Optional[TableMetadata]) -> None: raise CommitFailedException("Table already exists") -class AssertTableUUID(TableRequirement): +class AssertTableUUID(ValidatableTableRequirement): """The table UUID must match the requirement's `uuid`.""" type: Literal["assert-table-uuid"] = Field(default="assert-table-uuid") @@ -804,7 +800,7 @@ def validate(self, base_metadata: Optional[TableMetadata]) -> None: raise CommitFailedException(f"Table UUID does not match: {self.uuid} != {base_metadata.table_uuid}") -class AssertRefSnapshotId(TableRequirement): +class AssertRefSnapshotId(ValidatableTableRequirement): """The table branch or tag identified by the requirement's `ref` must reference the requirement's `snapshot-id`. if `snapshot-id` is `null` or missing, the ref must not already exist. @@ -829,7 +825,7 @@ def validate(self, base_metadata: Optional[TableMetadata]) -> None: raise CommitFailedException(f"Requirement failed: branch or tag {self.ref} is missing, expected {self.snapshot_id}") -class AssertLastAssignedFieldId(TableRequirement): +class AssertLastAssignedFieldId(ValidatableTableRequirement): """The table's last assigned column id must match the requirement's `last-assigned-field-id`.""" type: Literal["assert-last-assigned-field-id"] = Field(default="assert-last-assigned-field-id") @@ -844,7 +840,7 @@ def validate(self, base_metadata: Optional[TableMetadata]) -> None: ) -class AssertCurrentSchemaId(TableRequirement): +class AssertCurrentSchemaId(ValidatableTableRequirement): """The table's current schema id must match the requirement's `current-schema-id`.""" type: Literal["assert-current-schema-id"] = Field(default="assert-current-schema-id") @@ -859,7 +855,7 @@ def validate(self, base_metadata: Optional[TableMetadata]) -> None: ) -class AssertLastAssignedPartitionId(TableRequirement): +class AssertLastAssignedPartitionId(ValidatableTableRequirement): """The table's last assigned partition id must match the requirement's `last-assigned-partition-id`.""" type: Literal["assert-last-assigned-partition-id"] = Field(default="assert-last-assigned-partition-id") @@ -874,7 +870,7 @@ def validate(self, base_metadata: Optional[TableMetadata]) -> None: ) -class AssertDefaultSpecId(TableRequirement): +class AssertDefaultSpecId(ValidatableTableRequirement): """The table's default spec id must match the requirement's `default-spec-id`.""" type: Literal["assert-default-spec-id"] = Field(default="assert-default-spec-id") @@ -889,7 +885,7 @@ def validate(self, base_metadata: Optional[TableMetadata]) -> None: ) -class AssertDefaultSortOrderId(TableRequirement): +class AssertDefaultSortOrderId(ValidatableTableRequirement): """The table's default sort order id must match the requirement's `default-sort-order-id`.""" type: Literal["assert-default-sort-order-id"] = Field(default="assert-default-sort-order-id") @@ -903,10 +899,22 @@ def validate(self, base_metadata: Optional[TableMetadata]) -> None: f"Requirement failed: default sort order id has changed: expected {self.default_sort_order_id}, found {base_metadata.default_sort_order_id}" ) +TableRequirement = Annotated[ + Union[ + AssertCreate, + AssertTableUUID, + AssertRefSnapshotId, + AssertLastAssignedFieldId, + AssertCurrentSchemaId, + AssertLastAssignedPartitionId, + AssertDefaultSpecId, + AssertDefaultSortOrderId, + ], + Field(discriminator='type'), +] UpdatesAndRequirements = Tuple[Tuple[TableUpdate, ...], Tuple[TableRequirement, ...]] - class Namespace(IcebergRootModel[List[str]]): """Reference to one or more levels of a namespace.""" @@ -925,9 +933,7 @@ class TableIdentifier(IcebergBaseModel): class CommitTableRequest(IcebergBaseModel): identifier: TableIdentifier = Field() - requirements: Tuple[Annotated[Union[TableRequirement.get_subclasses()], Field(discriminator='type')], ...] = Field( - default_factory=tuple - ) + requirements: Tuple[TableRequirement, ...] = Field(default_factory=tuple) updates: Tuple[TableUpdate, ...] = Field(default_factory=tuple)