
Commit

WIP
Fokko committed Apr 9, 2024
1 parent a5e988a commit 1723819
Showing 8 changed files with 195 additions and 112 deletions.
66 changes: 63 additions & 3 deletions pyiceberg/io/pyarrow.py
@@ -31,6 +31,7 @@
import logging
import os
import re
import uuid
from abc import ABC, abstractmethod
from concurrent.futures import Future
from copy import copy
@@ -126,7 +127,6 @@
visit,
visit_with_partner,
)
from pyiceberg.table import PropertyUtil, TableProperties, WriteTask
from pyiceberg.table.metadata import TableMetadata
from pyiceberg.table.name_mapping import NameMapping
from pyiceberg.transforms import TruncateTransform
@@ -159,7 +159,7 @@
from pyiceberg.utils.truncate import truncate_upper_bound_binary_string, truncate_upper_bound_text_string

if TYPE_CHECKING:
from pyiceberg.table import FileScanTask
from pyiceberg.table import FileScanTask, WriteTask

logger = logging.getLogger(__name__)

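Note: the TYPE_CHECKING import plus the quoted annotation Iterable["WriteTask"] used further down is the standard way to keep a type-only dependency out of the runtime import graph and so break an import cycle. A minimal sketch of the pattern, using hypothetical modules a.py and b.py rather than the real PyIceberg modules:

# b.py (hypothetical): needs WriteTask only for type hints, while a.py imports b.py at runtime.
from typing import TYPE_CHECKING, Iterable

if TYPE_CHECKING:  # evaluated by type checkers only, never at runtime
    from a import WriteTask


def count_tasks(tasks: Iterable["WriteTask"]) -> int:
    # The quoted annotation is not evaluated at runtime, so no import of a.py is needed here.
    return sum(1 for _ in tasks)
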
@@ -1443,6 +1443,8 @@ class PyArrowStatisticsCollector(PreOrderSchemaVisitor[List[StatisticsCollector]
_default_mode: str

def __init__(self, schema: Schema, properties: Dict[str, str]):
from pyiceberg.table import TableProperties

self._schema = schema
self._properties = properties
self._default_mode = self._properties.get(
@@ -1478,6 +1480,8 @@ def map(
return k + v

def primitive(self, primitive: PrimitiveType) -> List[StatisticsCollector]:
from pyiceberg.table import TableProperties

column_name = self._schema.find_column_name(self._field_id)
if column_name is None:
return []
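
Note: the deferred import of TableProperties inside __init__ and primitive applies the same idea at runtime: postponing the import to call time breaks the module-level cycle between pyiceberg.io.pyarrow and pyiceberg.table. A minimal sketch under hypothetical names (a hypothetical settings.py defines COMPRESSION_KEY and COMPRESSION_DEFAULT and imports writer.py at module level, so writer.py must defer its import of settings.py):

# writer.py (hypothetical)
def default_compression(properties: dict) -> str:
    # Deferred import: executed on the first call, after both modules have finished
    # loading, instead of at module import time where it would complete the cycle.
    import settings

    return properties.get(settings.COMPRESSION_KEY, settings.COMPRESSION_DEFAULT)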
@@ -1774,7 +1778,9 @@ def data_file_statistics_from_parquet_metadata(
)


def write_file(io: FileIO, table_metadata: TableMetadata, tasks: Iterator[WriteTask]) -> Iterator[DataFile]:
def write_file(io: FileIO, table_metadata: TableMetadata, tasks: Iterable["WriteTask"]) -> Iterator[DataFile]:
from pyiceberg.table import PropertyUtil, TableProperties

schema = table_metadata.schema()
arrow_file_schema = schema.as_arrow()
parquet_writer_kwargs = _get_parquet_writer_kwargs(table_metadata.properties)
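
Note: relaxing tasks from Iterator[WriteTask] to Iterable["WriteTask"] lets callers pass any iterable, for example a generator expression, without wrapping it in iter(). A hedged caller sketch: the demo catalog, the db.events table, and the 512 MB target size are assumptions, and the Arrow data is assumed to match the table schema; the WriteTask keyword arguments mirror the ones used in _dataframe_to_data_files below:

import uuid

import pyarrow as pa

from pyiceberg.catalog import load_catalog
from pyiceberg.io.pyarrow import bin_pack_arrow_table, write_file
from pyiceberg.table import WriteTask

tbl = load_catalog("demo").load_table("db.events")  # assumed catalog and table
df = pa.table({"id": pa.array([1, 2, 3], type=pa.int64())})

write_uuid = uuid.uuid4()
tasks = (  # a plain generator is accepted now that the parameter is an Iterable
    WriteTask(write_uuid=write_uuid, task_id=i, record_batches=batches, schema=tbl.schema())
    for i, batches in enumerate(bin_pack_arrow_table(df, 512 * 1024 * 1024))
)
data_files = list(write_file(io=tbl.io, table_metadata=tbl.metadata, tasks=tasks))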
@@ -1875,6 +1881,8 @@ def parquet_files_to_data_files(io: FileIO, table_metadata: TableMetadata, file_


def _get_parquet_writer_kwargs(table_properties: Properties) -> Dict[str, Any]:
from pyiceberg.table import PropertyUtil, TableProperties

for key_pattern in [
TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES,
TableProperties.PARQUET_PAGE_ROW_LIMIT,
@@ -1912,3 +1920,55 @@ def _get_parquet_writer_kwargs(table_properties: Properties) -> Dict[str, Any]:
default=TableProperties.PARQUET_PAGE_ROW_LIMIT_DEFAULT,
),
}


def _dataframe_to_data_files(
table_metadata: TableMetadata,
df: pa.Table,
io: FileIO,
write_uuid: Optional[uuid.UUID] = None,
counter: Optional[itertools.count[int]] = None,
) -> Iterable[DataFile]:
"""Convert a PyArrow table into a DataFile.
Returns:
An iterable that supplies datafiles that represent the table.
"""
from pyiceberg.table import PropertyUtil, TableProperties, WriteTask

counter = counter or itertools.count(0)
write_uuid = write_uuid or uuid.uuid4()
target_file_size: int = PropertyUtil.property_as_int( # type: ignore # The property is set with non-None value.
properties=table_metadata.properties,
property_name=TableProperties.WRITE_TARGET_FILE_SIZE_BYTES,
default=TableProperties.WRITE_TARGET_FILE_SIZE_BYTES_DEFAULT,
)

if table_metadata.spec().is_unpartitioned():
yield from write_file(
io=io,
table_metadata=table_metadata,
tasks=iter([
WriteTask(write_uuid=write_uuid, task_id=next(counter), record_batches=batches, schema=table_metadata.schema())
for batches in bin_pack_arrow_table(df, target_file_size)
]),
)
else:
from pyiceberg.table import determine_partitions

partitions = determine_partitions(spec=table_metadata.spec(), schema=table_metadata.schema(), arrow_table=df)
yield from write_file(
io=io,
table_metadata=table_metadata,
tasks=iter([
WriteTask(
write_uuid=write_uuid,
task_id=next(counter),
record_batches=batches,
partition_key=partition.partition_key,
schema=table_metadata.schema(),
)
for partition in partitions
for batches in bin_pack_arrow_table(partition.arrow_table_partition, target_file_size)
]),
)
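
A hedged usage sketch for the _dataframe_to_data_files helper added above: the demo catalog and db.events table are assumptions, and the Arrow table is assumed to match the Iceberg schema. The helper only writes Parquet files and yields DataFile objects; committing them to a snapshot is a separate step:

import pyarrow as pa

from pyiceberg.catalog import load_catalog
from pyiceberg.io.pyarrow import _dataframe_to_data_files

tbl = load_catalog("demo").load_table("db.events")  # assumed catalog and table
df = pa.table({"id": pa.array([1, 2, 3], type=pa.int64())})

# The generator writes one or more Parquet files as it is consumed; each DataFile
# records the file path and statistics such as the row count.
for data_file in _dataframe_to_data_files(table_metadata=tbl.metadata, df=df, io=tbl.io):
    print(data_file.file_path, data_file.record_count)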