Skip to content

Commit

Permalink
have requirements inherit from ValidatableTableRequirement
Browse files Browse the repository at this point in the history
  • Loading branch information
Kieran Higgins committed Mar 14, 2024
1 parent 9cf8bd7 commit c2492a4
Showing 1 changed file with 23 additions and 17 deletions.
40 changes: 23 additions & 17 deletions pyiceberg/table/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -761,13 +761,9 @@ def update_table_metadata(base_metadata: TableMetadata, updates: Tuple[TableUpda
return new_metadata.model_copy(deep=True)


class TableRequirement(IcebergBaseModel):
class ValidatableTableRequirement(IcebergBaseModel):
type: str

@classmethod
def get_subclasses(cls) -> tuple[str]:
return tuple(cls.__subclasses__())

@abstractmethod
def validate(self, base_metadata: Optional[TableMetadata]) -> None:
"""Validate the requirement against the base metadata.
Expand All @@ -781,7 +777,7 @@ def validate(self, base_metadata: Optional[TableMetadata]) -> None:
...


class AssertCreate(TableRequirement):
class AssertCreate(ValidatableTableRequirement):
"""The table must not already exist; used for create transactions."""

type: Literal["assert-create"] = Field(default="assert-create")
Expand All @@ -791,7 +787,7 @@ def validate(self, base_metadata: Optional[TableMetadata]) -> None:
raise CommitFailedException("Table already exists")


class AssertTableUUID(TableRequirement):
class AssertTableUUID(ValidatableTableRequirement):
"""The table UUID must match the requirement's `uuid`."""

type: Literal["assert-table-uuid"] = Field(default="assert-table-uuid")
Expand All @@ -804,7 +800,7 @@ def validate(self, base_metadata: Optional[TableMetadata]) -> None:
raise CommitFailedException(f"Table UUID does not match: {self.uuid} != {base_metadata.table_uuid}")


class AssertRefSnapshotId(TableRequirement):
class AssertRefSnapshotId(ValidatableTableRequirement):
"""The table branch or tag identified by the requirement's `ref` must reference the requirement's `snapshot-id`.
if `snapshot-id` is `null` or missing, the ref must not already exist.
Expand All @@ -829,7 +825,7 @@ def validate(self, base_metadata: Optional[TableMetadata]) -> None:
raise CommitFailedException(f"Requirement failed: branch or tag {self.ref} is missing, expected {self.snapshot_id}")


class AssertLastAssignedFieldId(TableRequirement):
class AssertLastAssignedFieldId(ValidatableTableRequirement):
"""The table's last assigned column id must match the requirement's `last-assigned-field-id`."""

type: Literal["assert-last-assigned-field-id"] = Field(default="assert-last-assigned-field-id")
Expand All @@ -844,7 +840,7 @@ def validate(self, base_metadata: Optional[TableMetadata]) -> None:
)


class AssertCurrentSchemaId(TableRequirement):
class AssertCurrentSchemaId(ValidatableTableRequirement):
"""The table's current schema id must match the requirement's `current-schema-id`."""

type: Literal["assert-current-schema-id"] = Field(default="assert-current-schema-id")
Expand All @@ -859,7 +855,7 @@ def validate(self, base_metadata: Optional[TableMetadata]) -> None:
)


class AssertLastAssignedPartitionId(TableRequirement):
class AssertLastAssignedPartitionId(ValidatableTableRequirement):
"""The table's last assigned partition id must match the requirement's `last-assigned-partition-id`."""

type: Literal["assert-last-assigned-partition-id"] = Field(default="assert-last-assigned-partition-id")
Expand All @@ -874,7 +870,7 @@ def validate(self, base_metadata: Optional[TableMetadata]) -> None:
)


class AssertDefaultSpecId(TableRequirement):
class AssertDefaultSpecId(ValidatableTableRequirement):
"""The table's default spec id must match the requirement's `default-spec-id`."""

type: Literal["assert-default-spec-id"] = Field(default="assert-default-spec-id")
Expand All @@ -889,7 +885,7 @@ def validate(self, base_metadata: Optional[TableMetadata]) -> None:
)


class AssertDefaultSortOrderId(TableRequirement):
class AssertDefaultSortOrderId(ValidatableTableRequirement):
"""The table's default sort order id must match the requirement's `default-sort-order-id`."""

type: Literal["assert-default-sort-order-id"] = Field(default="assert-default-sort-order-id")
Expand All @@ -903,10 +899,22 @@ def validate(self, base_metadata: Optional[TableMetadata]) -> None:
f"Requirement failed: default sort order id has changed: expected {self.default_sort_order_id}, found {base_metadata.default_sort_order_id}"
)

TableRequirement = Annotated[
Union[
AssertCreate,
AssertTableUUID,
AssertRefSnapshotId,
AssertLastAssignedFieldId,
AssertCurrentSchemaId,
AssertLastAssignedPartitionId,
AssertDefaultSpecId,
AssertDefaultSortOrderId,
],
Field(discriminator='type'),
]

UpdatesAndRequirements = Tuple[Tuple[TableUpdate, ...], Tuple[TableRequirement, ...]]


class Namespace(IcebergRootModel[List[str]]):
"""Reference to one or more levels of a namespace."""

Expand All @@ -925,9 +933,7 @@ class TableIdentifier(IcebergBaseModel):

class CommitTableRequest(IcebergBaseModel):
identifier: TableIdentifier = Field()
requirements: Tuple[Annotated[Union[TableRequirement.get_subclasses()], Field(discriminator='type')], ...] = Field(
default_factory=tuple
)
requirements: Tuple[TableRequirement, ...] = Field(default_factory=tuple)
updates: Tuple[TableUpdate, ...] = Field(default_factory=tuple)


Expand Down

0 comments on commit c2492a4

Please sign in to comment.