Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add `TableVersion` type alias for the table format version #566

Merged
merged 7 commits into from
Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions pyiceberg/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from pyiceberg.io import FileIO, InputFile, OutputFile
from pyiceberg.partitioning import PartitionSpec
from pyiceberg.schema import Schema
from pyiceberg.typedef import EMPTY_DICT, Record
from pyiceberg.typedef import EMPTY_DICT, Record, TableVersion
from pyiceberg.types import (
BinaryType,
BooleanType,
Expand Down Expand Up @@ -302,7 +302,7 @@ def _(partition_field_type: PrimitiveType) -> PrimitiveType:
return partition_field_type


def data_file_with_partition(partition_type: StructType, format_version: Literal[1, 2]) -> StructType:
def data_file_with_partition(partition_type: StructType, format_version: TableVersion) -> StructType:
data_file_partition_type = StructType(*[
NestedField(
field_id=field.field_id,
Expand Down Expand Up @@ -372,7 +372,7 @@ def __setattr__(self, name: str, value: Any) -> None:
value = FileFormat[value]
super().__setattr__(name, value)

def __init__(self, format_version: Literal[1, 2] = DEFAULT_READ_VERSION, *data: Any, **named_data: Any) -> None:
def __init__(self, format_version: TableVersion = DEFAULT_READ_VERSION, *data: Any, **named_data: Any) -> None:
super().__init__(
*data,
**{"struct": DATA_FILE_TYPE[format_version], **named_data},
Expand Down Expand Up @@ -408,7 +408,7 @@ def __eq__(self, other: Any) -> bool:
MANIFEST_ENTRY_SCHEMAS_STRUCT = {format_version: schema.as_struct() for format_version, schema in MANIFEST_ENTRY_SCHEMAS.items()}


def manifest_entry_schema_with_data_file(format_version: Literal[1, 2], data_file: StructType) -> Schema:
def manifest_entry_schema_with_data_file(format_version: TableVersion, data_file: StructType) -> Schema:
return Schema(*[
NestedField(2, "data_file", data_file, required=True) if field.field_id == 2 else field
for field in MANIFEST_ENTRY_SCHEMAS[format_version].fields
Expand Down Expand Up @@ -719,9 +719,9 @@ def content(self) -> ManifestContent: ...

@property
@abstractmethod
def version(self) -> Literal[1, 2]: ...
def version(self) -> TableVersion: ...

def _with_partition(self, format_version: Literal[1, 2]) -> Schema:
def _with_partition(self, format_version: TableVersion) -> Schema:
data_file_type = data_file_with_partition(
format_version=format_version, partition_type=self._spec.partition_type(self._schema)
)
Expand Down Expand Up @@ -807,7 +807,7 @@ def content(self) -> ManifestContent:
return ManifestContent.DATA

@property
def version(self) -> Literal[1, 2]:
def version(self) -> TableVersion:
return 1

def prepare_entry(self, entry: ManifestEntry) -> ManifestEntry:
Expand All @@ -834,7 +834,7 @@ def content(self) -> ManifestContent:
return ManifestContent.DATA

@property
def version(self) -> Literal[1, 2]:
def version(self) -> TableVersion:
return 2

def prepare_entry(self, entry: ManifestEntry) -> ManifestEntry:
Expand All @@ -847,7 +847,7 @@ def prepare_entry(self, entry: ManifestEntry) -> ManifestEntry:


def write_manifest(
format_version: Literal[1, 2], spec: PartitionSpec, schema: Schema, output_file: OutputFile, snapshot_id: int
format_version: TableVersion, spec: PartitionSpec, schema: Schema, output_file: OutputFile, snapshot_id: int
) -> ManifestWriter:
if format_version == 1:
return ManifestWriterV1(spec, schema, output_file, snapshot_id)
Expand All @@ -858,14 +858,14 @@ def write_manifest(


class ManifestListWriter(ABC):
_format_version: Literal[1, 2]
_format_version: TableVersion
_output_file: OutputFile
_meta: Dict[str, str]
_manifest_files: List[ManifestFile]
_commit_snapshot_id: int
_writer: AvroOutputFile[ManifestFile]

def __init__(self, format_version: Literal[1, 2], output_file: OutputFile, meta: Dict[str, Any]):
def __init__(self, format_version: TableVersion, output_file: OutputFile, meta: Dict[str, Any]):
self._format_version = format_version
self._output_file = output_file
self._meta = meta
Expand Down Expand Up @@ -957,7 +957,7 @@ def prepare_manifest(self, manifest_file: ManifestFile) -> ManifestFile:


def write_manifest_list(
format_version: Literal[1, 2],
format_version: TableVersion,
output_file: OutputFile,
snapshot_id: int,
parent_snapshot_id: Optional[int],
Expand Down
5 changes: 3 additions & 2 deletions pyiceberg/table/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@
KeyDefaultDict,
Properties,
Record,
TableVersion,
)
from pyiceberg.types import (
IcebergType,
Expand Down Expand Up @@ -288,7 +289,7 @@ def _apply(self, updates: Tuple[TableUpdate, ...], requirements: Tuple[TableRequ

return self

def upgrade_table_version(self, format_version: Literal[1, 2]) -> Transaction:
def upgrade_table_version(self, format_version: TableVersion) -> Transaction:
"""Set the table to a certain version.

Args:
Expand Down Expand Up @@ -1018,7 +1019,7 @@ def scan(
)

@property
def format_version(self) -> Literal[1, 2]:
def format_version(self) -> TableVersion:
return self.metadata.format_version

def schema(self) -> Schema:
Expand Down
5 changes: 5 additions & 0 deletions pyiceberg/typedef.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
Dict,
Generic,
List,
Literal,
Optional,
Protocol,
Set,
Expand All @@ -37,6 +38,7 @@
from uuid import UUID

from pydantic import BaseModel, ConfigDict, RootModel
from typing_extensions import TypeAlias

if TYPE_CHECKING:
from pyiceberg.types import StructType
Expand Down Expand Up @@ -199,3 +201,6 @@ def __repr__(self) -> str:
def record_fields(self) -> List[str]:
"""Return values of all the fields of the Record class except those specified in skip_fields."""
return [self.__getattribute__(v) if hasattr(self, v) else None for v in self._position_to_field_name]


TableVersion: TypeAlias = Literal[1, 2]
8 changes: 4 additions & 4 deletions tests/utils/test_manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.
# pylint: disable=redefined-outer-name,arguments-renamed,fixme
from tempfile import TemporaryDirectory
from typing import Dict, Literal
from typing import Dict

import fastavro
import pytest
Expand All @@ -39,7 +39,7 @@
from pyiceberg.schema import Schema
from pyiceberg.table.snapshots import Operation, Snapshot, Summary
from pyiceberg.transforms import IdentityTransform
from pyiceberg.typedef import Record
from pyiceberg.typedef import Record, TableVersion
from pyiceberg.types import IntegerType, NestedField


Expand Down Expand Up @@ -308,7 +308,7 @@ def test_read_manifest_v2(generated_manifest_file_file_v2: str) -> None:

@pytest.mark.parametrize("format_version", [1, 2])
def test_write_manifest(
generated_manifest_file_file_v1: str, generated_manifest_file_file_v2: str, format_version: Literal[1, 2]
generated_manifest_file_file_v1: str, generated_manifest_file_file_v2: str, format_version: TableVersion
) -> None:
io = load_file_io()
snapshot = Snapshot(
Expand Down Expand Up @@ -478,7 +478,7 @@ def test_write_manifest(

@pytest.mark.parametrize("format_version", [1, 2])
def test_write_manifest_list(
generated_manifest_file_file_v1: str, generated_manifest_file_file_v2: str, format_version: Literal[1, 2]
generated_manifest_file_file_v1: str, generated_manifest_file_file_v2: str, format_version: TableVersion
) -> None:
io = load_file_io()

Expand Down