Skip to content

feat: add RollingManifestWriter #650

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
9 changes: 9 additions & 0 deletions pyiceberg/avro/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ class AvroOutputFile(Generic[D]):
encoder: BinaryEncoder
sync_bytes: bytes
writer: Writer
closed: bool

def __init__(
self,
Expand All @@ -247,6 +248,7 @@ def __init__(
else resolve_writer(record_schema=record_schema, file_schema=self.file_schema)
)
self.metadata = metadata
self.closed = False

def __enter__(self) -> AvroOutputFile[D]:
"""
Expand All @@ -267,6 +269,7 @@ def __exit__(
) -> None:
"""Perform cleanup when exiting the scope of a 'with' statement."""
self.output_stream.close()
self.closed = True

def _write_header(self) -> None:
json_schema = json.dumps(AvroSchemaConversion().iceberg_to_avro(self.file_schema, schema_name=self.schema_name))
Expand All @@ -285,3 +288,9 @@ def write_block(self, objects: List[D]) -> None:
self.encoder.write_int(len(block_content))
self.encoder.write(block_content)
self.encoder.write(self.sync_bytes)

def __len__(self) -> int:
"""Return the total number number of bytes written."""
if self.closed:
return len(self.output_file)
return self.output_stream.tell()
4 changes: 4 additions & 0 deletions pyiceberg/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,10 @@ def __exit__(
) -> None:
"""Perform cleanup when exiting the scope of a 'with' statement."""

@abstractmethod
def tell(self) -> int:
"""Return the total number number of bytes written to the stream."""


class InputFile(ABC):
"""A base class for InputFile implementations.
Expand Down
75 changes: 75 additions & 0 deletions pyiceberg/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -840,6 +840,81 @@ def existing(self, entry: ManifestEntry) -> ManifestWriter:
return self


def __len__(self) -> int:
"""Return the total number number of bytes written."""
return len(self._writer)


class RollingManifestWriter:
closed: bool
_supplier: Generator[ManifestWriter, None, None]
_manifest_files: list[ManifestFile]
_target_file_size_in_bytes: int
_target_number_of_rows: int
_current_writer: Optional[ManifestWriter]
_current_file_rows: int

def __init__(
self, supplier: Generator[ManifestWriter, None, None], target_file_size_in_bytes, target_number_of_rows
) -> None:
self._closed = False
self._manifest_files = []
self._supplier = supplier
self._target_file_size_in_bytes = target_file_size_in_bytes
self._target_number_of_rows = target_number_of_rows
self._current_writer = None
self._current_file_rows = 0

def __enter__(self) -> RollingManifestWriter:
self._get_current_writer().__enter__()
return self

def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
self._close_current_writer()
self._closed = True

def _get_current_writer(self) -> ManifestWriter:
if self._should_roll_to_new_file():
self._close_current_writer()
if not self._current_writer:
self._current_writer = next(self._supplier)
self._current_writer.__enter__()
return self._current_writer
return self._current_writer

def _should_roll_to_new_file(self) -> bool:
if not self._current_writer:
return False
return (
self._current_file_rows >= self._target_number_of_rows or len(self._current_writer) >= self._target_file_size_in_bytes
)

def _close_current_writer(self):
if self._current_writer:
self._current_writer.__exit__(None, None, None)
current_file = self._current_writer.to_manifest_file()
self._manifest_files.append(current_file)
self._current_writer = None
self._current_file_rows = 0

def to_manifest_files(self) -> list[ManifestFile]:
if not self._closed:
raise RuntimeError("Cannot create manifest files from unclosed writer")
return self._manifest_files

def add_entry(self, entry: ManifestEntry) -> RollingManifestWriter:
if self._closed:
raise RuntimeError("Cannot add entry to closed manifest writer")
self._get_current_writer().add_entry(entry)
self._current_file_rows += entry.data_file.record_count
return self


class ManifestWriterV1(ManifestWriter):
def __init__(self, spec: PartitionSpec, schema: Schema, output_file: OutputFile, snapshot_id: int):
super().__init__(
Expand Down
75 changes: 74 additions & 1 deletion tests/utils/test_manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.
# pylint: disable=redefined-outer-name,arguments-renamed,fixme
from tempfile import TemporaryDirectory
from typing import Dict
from typing import Dict, Generator

import fastavro
import pytest
Expand All @@ -30,7 +30,9 @@
ManifestContent,
ManifestEntryStatus,
ManifestFile,
ManifestWriter,
PartitionFieldSummary,
RollingManifestWriter,
read_manifest_list,
write_manifest,
write_manifest_list,
Expand Down Expand Up @@ -493,6 +495,75 @@ def test_write_manifest(
assert data_file.sort_order_id == 0


@pytest.mark.parametrize("format_version", [1, 2])
@pytest.mark.parametrize(
"target_number_of_rows,target_file_size_in_bytes,expected_number_of_files",
[
(19514, 388873, 1), # should not roll over
(19513, 388873, 2), # should roll over due to target_rows
(4000, 388872, 2), # should roll over due target_bytes
(4000, 388872, 2), # should roll over due to target_rows and target_bytes
],
)
def test_rolling_manifest_writer(
generated_manifest_file_file_v1: str,
generated_manifest_file_file_v2: str,
format_version: TableVersion,
target_number_of_rows: int,
target_file_size_in_bytes: int,
expected_number_of_files: int,
) -> None:
io = load_file_io()
snapshot = Snapshot(
snapshot_id=25,
parent_snapshot_id=19,
timestamp_ms=1602638573590,
manifest_list=generated_manifest_file_file_v1 if format_version == 1 else generated_manifest_file_file_v2,
summary=Summary(Operation.APPEND),
schema_id=3,
)
demo_manifest_file = snapshot.manifests(io)[0]
manifest_entries = demo_manifest_file.fetch_manifest_entry(io)
test_schema = Schema(
NestedField(1, "VendorID", IntegerType(), False), NestedField(2, "tpep_pickup_datetime", IntegerType(), False)
)
test_spec = PartitionSpec(
PartitionField(source_id=1, field_id=1, transform=IdentityTransform(), name="VendorID"),
PartitionField(source_id=2, field_id=2, transform=IdentityTransform(), name="tpep_pickup_datetime"),
spec_id=demo_manifest_file.partition_spec_id,
)

with TemporaryDirectory() as tmpdir:

def supplier() -> Generator[ManifestWriter, None, None]:
i = 0
while True:
tmp_avro_file = tmpdir + f"/test_write_manifest-{i}.avro"
output = io.new_output(tmp_avro_file)
yield write_manifest(
format_version=format_version,
spec=test_spec,
schema=test_schema,
output_file=output,
snapshot_id=8744736658442914487,
)
i += 1

with RollingManifestWriter(
supplier=supplier(),
target_file_size_in_bytes=target_file_size_in_bytes,
target_number_of_rows=target_number_of_rows,
) as writer:
for entry in manifest_entries:
writer.add_entry(entry)

manifest_files = writer.to_manifest_files()
assert len(manifest_files) == expected_number_of_files
with pytest.raises(RuntimeError):
# It is already closed
writer.add_entry(manifest_entries[0])


@pytest.mark.parametrize("format_version", [1, 2])
def test_write_manifest_list(
generated_manifest_file_file_v1: str, generated_manifest_file_file_v2: str, format_version: TableVersion
Expand Down Expand Up @@ -560,3 +631,5 @@ def test_write_manifest_list(
assert entry.file_sequence_number == 0 if format_version == 1 else 3
assert entry.snapshot_id == 8744736658442914487
assert entry.status == ManifestEntryStatus.ADDED