diff --git a/pyiceberg/avro/file.py b/pyiceberg/avro/file.py index d0da7651b7..1e3dbfd3df 100644 --- a/pyiceberg/avro/file.py +++ b/pyiceberg/avro/file.py @@ -228,6 +228,7 @@ class AvroOutputFile(Generic[D]): encoder: BinaryEncoder sync_bytes: bytes writer: Writer + closed: bool def __init__( self, @@ -247,6 +248,7 @@ def __init__( else resolve_writer(record_schema=record_schema, file_schema=self.file_schema) ) self.metadata = metadata + self.closed = False def __enter__(self) -> AvroOutputFile[D]: """ @@ -267,6 +269,7 @@ def __exit__( ) -> None: """Perform cleanup when exiting the scope of a 'with' statement.""" self.output_stream.close() + self.closed = True def _write_header(self) -> None: json_schema = json.dumps(AvroSchemaConversion().iceberg_to_avro(self.file_schema, schema_name=self.schema_name)) @@ -285,3 +288,9 @@ def write_block(self, objects: List[D]) -> None: self.encoder.write_int(len(block_content)) self.encoder.write(block_content) self.encoder.write(self.sync_bytes) + + def __len__(self) -> int: + """Return the total number of bytes written.""" + if self.closed: + return len(self.output_file) + return self.output_stream.tell() diff --git a/pyiceberg/io/__init__.py b/pyiceberg/io/__init__.py index d200874741..d8a264ab29 100644 --- a/pyiceberg/io/__init__.py +++ b/pyiceberg/io/__init__.py @@ -136,6 +136,10 @@ def __exit__( ) -> None: """Perform cleanup when exiting the scope of a 'with' statement.""" + @abstractmethod + def tell(self) -> int: + """Return the total number of bytes written to the stream.""" + class InputFile(ABC): """A base class for InputFile implementations. 
diff --git a/pyiceberg/manifest.py b/pyiceberg/manifest.py index 960952d02d..f125221ef6 100644 --- a/pyiceberg/manifest.py +++ b/pyiceberg/manifest.py @@ -840,6 +840,88 @@ def existing(self, entry: ManifestEntry) -> ManifestWriter: return self + def __len__(self) -> int: + """Return the total number of bytes written.""" + return len(self._writer) + + +class RollingManifestWriter: + """Write entries through successive ManifestWriters, rolling to a new file at a row or byte target.""" + + _closed: bool + _supplier: Generator[ManifestWriter, None, None] + _manifest_files: list[ManifestFile] + _target_file_size_in_bytes: int + _target_number_of_rows: int + _current_writer: Optional[ManifestWriter] + _current_file_rows: int + + def __init__( + self, supplier: Generator[ManifestWriter, None, None], target_file_size_in_bytes: int, target_number_of_rows: int + ) -> None: + self._closed = False + self._manifest_files = [] + self._supplier = supplier + self._target_file_size_in_bytes = target_file_size_in_bytes + self._target_number_of_rows = target_number_of_rows + self._current_writer = None + self._current_file_rows = 0 + + def __enter__(self) -> RollingManifestWriter: + """Open the first underlying writer.""" + self._get_current_writer() # entered where it is created; do not enter it a second time + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: + """Close the current writer and mark this rolling writer as closed.""" + self._close_current_writer() + self._closed = True + + def _get_current_writer(self) -> ManifestWriter: + """Return the active writer, first rolling over to a new one from the supplier if a target was reached.""" + if self._should_roll_to_new_file(): + self._close_current_writer() + if not self._current_writer: + self._current_writer = next(self._supplier) + self._current_writer.__enter__() + return self._current_writer + return self._current_writer + + def _should_roll_to_new_file(self) -> bool: + """Return True when the current file has reached the row target or the byte target.""" + if not self._current_writer: + return False + return ( + self._current_file_rows >= self._target_number_of_rows or len(self._current_writer) >= self._target_file_size_in_bytes + ) + + def _close_current_writer(self) -> None: + """Close the active writer, if any, and record the manifest file it produced.""" + if self._current_writer: + self._current_writer.__exit__(None, None, None) + current_file = 
self._current_writer.to_manifest_file() + self._manifest_files.append(current_file) + self._current_writer = None + self._current_file_rows = 0 + + def to_manifest_files(self) -> list[ManifestFile]: + if not self._closed: + raise RuntimeError("Cannot create manifest files from unclosed writer") + return self._manifest_files + + def add_entry(self, entry: ManifestEntry) -> RollingManifestWriter: + if self._closed: + raise RuntimeError("Cannot add entry to closed manifest writer") + self._get_current_writer().add_entry(entry) + self._current_file_rows += entry.data_file.record_count + return self + + class ManifestWriterV1(ManifestWriter): def __init__(self, spec: PartitionSpec, schema: Schema, output_file: OutputFile, snapshot_id: int): super().__init__( diff --git a/tests/utils/test_manifest.py b/tests/utils/test_manifest.py index ef33b16b00..82750fe871 100644 --- a/tests/utils/test_manifest.py +++ b/tests/utils/test_manifest.py @@ -16,7 +16,7 @@ # under the License. # pylint: disable=redefined-outer-name,arguments-renamed,fixme from tempfile import TemporaryDirectory -from typing import Dict +from typing import Dict, Generator import fastavro import pytest @@ -30,7 +30,9 @@ ManifestContent, ManifestEntryStatus, ManifestFile, + ManifestWriter, PartitionFieldSummary, + RollingManifestWriter, read_manifest_list, write_manifest, write_manifest_list, @@ -493,6 +495,75 @@ def test_write_manifest( assert data_file.sort_order_id == 0 +@pytest.mark.parametrize("format_version", [1, 2]) +@pytest.mark.parametrize( + "target_number_of_rows,target_file_size_in_bytes,expected_number_of_files", + [ + (19514, 388873, 1), # should not roll over + (19513, 388873, 2), # should roll over due to target_rows + (4000, 388872, 2), # should roll over due to target_bytes; NOTE(review): same values as the next case - confirm intended + (4000, 388872, 2), # should roll over due to target_rows and target_bytes + ], +) +def test_rolling_manifest_writer( + generated_manifest_file_file_v1: str, + generated_manifest_file_file_v2: str, + format_version: 
TableVersion, + target_number_of_rows: int, + target_file_size_in_bytes: int, + expected_number_of_files: int, +) -> None: + io = load_file_io() + snapshot = Snapshot( + snapshot_id=25, + parent_snapshot_id=19, + timestamp_ms=1602638573590, + manifest_list=generated_manifest_file_file_v1 if format_version == 1 else generated_manifest_file_file_v2, + summary=Summary(Operation.APPEND), + schema_id=3, + ) + demo_manifest_file = snapshot.manifests(io)[0] + manifest_entries = demo_manifest_file.fetch_manifest_entry(io) + test_schema = Schema( + NestedField(1, "VendorID", IntegerType(), False), NestedField(2, "tpep_pickup_datetime", IntegerType(), False) + ) + test_spec = PartitionSpec( + PartitionField(source_id=1, field_id=1, transform=IdentityTransform(), name="VendorID"), + PartitionField(source_id=2, field_id=2, transform=IdentityTransform(), name="tpep_pickup_datetime"), + spec_id=demo_manifest_file.partition_spec_id, + ) + + with TemporaryDirectory() as tmpdir: + + def supplier() -> Generator[ManifestWriter, None, None]: + i = 0 + while True: + tmp_avro_file = tmpdir + f"/test_write_manifest-{i}.avro" + output = io.new_output(tmp_avro_file) + yield write_manifest( + format_version=format_version, + spec=test_spec, + schema=test_schema, + output_file=output, + snapshot_id=8744736658442914487, + ) + i += 1 + + with RollingManifestWriter( + supplier=supplier(), + target_file_size_in_bytes=target_file_size_in_bytes, + target_number_of_rows=target_number_of_rows, + ) as writer: + for entry in manifest_entries: + writer.add_entry(entry) + + manifest_files = writer.to_manifest_files() + assert len(manifest_files) == expected_number_of_files + with pytest.raises(RuntimeError): + # The rolling writer is closed at this point, so add_entry must raise + writer.add_entry(manifest_entries[0]) + + @pytest.mark.parametrize("format_version", [1, 2]) def test_write_manifest_list( generated_manifest_file_file_v1: str, generated_manifest_file_file_v2: str, format_version: TableVersion @@ -560,3 +631,5 @@ def 
test_write_manifest_list( assert entry.file_sequence_number == 0 if format_version == 1 else 3 assert entry.snapshot_id == 8744736658442914487 assert entry.status == ManifestEntryStatus.ADDED + +