add_files
sungwy committed Mar 8, 2024
1 parent 65e28fa commit e250ffc
Showing 3 changed files with 145 additions and 74 deletions.
53 changes: 27 additions & 26 deletions pyiceberg/io/pyarrow.py
@@ -124,7 +124,7 @@
    visit,
    visit_with_partner,
)
from pyiceberg.table import PropertyUtil, TableProperties, WriteTask
from pyiceberg.table import AddFileTask, PropertyUtil, TableProperties, WriteTask
from pyiceberg.table.metadata import TableMetadata
from pyiceberg.table.name_mapping import NameMapping
from pyiceberg.transforms import TruncateTransform
@@ -1772,31 +1772,32 @@ def write_file(io: FileIO, table_metadata: TableMetadata, tasks: Iterator[WriteT
    return iter([data_file])


def parquet_file_to_data_file(io: FileIO, table_metadata: TableMetadata, file_path: str) -> DataFile:
    input_file = io.new_input(file_path)
    with input_file.open() as input_stream:
        parquet_metadata = pq.read_metadata(input_stream)

    schema = table_metadata.schema()
    data_file = DataFile(
        content=DataFileContent.DATA,
        file_path=file_path,
        file_format=FileFormat.PARQUET,
        partition=Record(),
        record_count=parquet_metadata.num_rows,
        file_size_in_bytes=len(input_file),
        sort_order_id=None,
        spec_id=table_metadata.default_spec_id,
        equality_ids=None,
        key_metadata=None,
    )
    fill_parquet_file_metadata(
        data_file=data_file,
        parquet_metadata=parquet_metadata,
        stats_columns=compute_statistics_plan(schema, table_metadata.properties),
        parquet_column_mapping=parquet_path_to_id_mapping(schema),
    )
    return data_file
def parquet_files_to_data_files(io: FileIO, table_metadata: TableMetadata, tasks: Iterator[AddFileTask]) -> Iterator[DataFile]:
    for task in tasks:
        input_file = io.new_input(task.file_path)
        with input_file.open() as input_stream:
            parquet_metadata = pq.read_metadata(input_stream)

        schema = table_metadata.schema()
        data_file = DataFile(
            content=DataFileContent.DATA,
            file_path=task.file_path,
            file_format=FileFormat.PARQUET,
            partition=task.partition_field_value,
            record_count=parquet_metadata.num_rows,
            file_size_in_bytes=len(input_file),
            sort_order_id=None,
            spec_id=table_metadata.default_spec_id,
            equality_ids=None,
            key_metadata=None,
        )
        fill_parquet_file_metadata(
            data_file=data_file,
            parquet_metadata=parquet_metadata,
            stats_columns=compute_statistics_plan(schema, table_metadata.properties),
            parquet_column_mapping=parquet_path_to_id_mapping(schema),
        )
        yield data_file
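
A minimal usage sketch, not part of this commit (the catalog name, table identifier, and file path below are assumptions), showing how the new generator can be driven together with add_file_tasks_from_file_paths from pyiceberg.table; the generator is lazy, so each Parquet footer is read only as the caller iterates:

from pyiceberg.catalog import load_catalog
from pyiceberg.io.pyarrow import parquet_files_to_data_files
from pyiceberg.table import add_file_tasks_from_file_paths

# Assumed setup: a configured catalog and an existing Iceberg table.
catalog = load_catalog("default")
tbl = catalog.load_table("default.unpartitioned_table")

# Build one AddFileTask per path, then materialize a DataFile per task,
# with statistics filled in from each Parquet footer.
tasks = add_file_tasks_from_file_paths(["s3://warehouse/default/test-0.parquet"], tbl.metadata)
data_files = list(parquet_files_to_data_files(io=tbl.io, table_metadata=tbl.metadata, tasks=tasks))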


ICEBERG_UNCOMPRESSED_CODEC = "uncompressed"
42 changes: 27 additions & 15 deletions pyiceberg/table/__init__.py
@@ -33,6 +33,7 @@
    Dict,
    Generic,
    Iterable,
    Iterator,
    List,
    Literal,
    Optional,
@@ -56,6 +57,7 @@
    parser,
    visitors,
)
from pyiceberg.expressions.literals import StringLiteral
from pyiceberg.expressions.visitors import _InclusiveMetricsEvaluator, inclusive_projection
from pyiceberg.io import FileIO, load_file_io
from pyiceberg.manifest import (
@@ -117,6 +119,7 @@
    Identifier,
    KeyDefaultDict,
    Properties,
    Record,
)
from pyiceberg.types import (
    IcebergType,
@@ -1140,7 +1143,7 @@ def add_files(self, file_paths: List[str]) -> None:
        Shorthand API for adding files as data files to the table.

        Args:
            files: The list of full file paths to be added as data files to the table
            file_paths: The list of full file paths to be added as data files to the table
        """
        if self.name_mapping() is None:
            with self.transaction() as tx:
@@ -2449,17 +2452,8 @@ def generate_data_file_filename(self, extension: str) -> str:

@dataclass(frozen=True)
class AddFileTask:
    write_uuid: uuid.UUID
    task_id: int
    df: pa.Table
    sort_order_id: Optional[int] = None

    # Later to be extended with partition information

    def generate_data_file_filename(self, extension: str) -> str:
        # Mimics the behavior in the Java API:
        # https://github.com/apache/iceberg/blob/a582968975dd30ff4917fbbe999f1be903efac02/core/src/main/java/org/apache/iceberg/io/OutputFileFactory.java#L92-L101
        return f"00000-{self.task_id}-{self.write_uuid}.{extension}"
    file_path: str
    partition_field_value: Record
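
For illustration (the paths and values below are assumed, not taken from this commit), the slimmed-down task now only names a file and its parsed partition values; an unpartitioned file carries an empty Record:

# Hypothetical examples of the new task shape.
unpartitioned = AddFileTask(file_path="s3://warehouse/default/test-0.parquet", partition_field_value=Record())
partitioned = AddFileTask(
    file_path="s3://warehouse/default/baz=123/test-1.parquet",
    partition_field_value=Record(baz=123),
)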


def _new_manifest_path(location: str, num: int, commit_uuid: uuid.UUID) -> str:
@@ -2493,16 +2487,34 @@ def _dataframe_to_data_files(
    yield from write_file(io=io, table_metadata=table_metadata, tasks=iter([WriteTask(write_uuid, next(counter), df)]))


def add_file_tasks_from_file_paths(file_paths: List[str], table_metadata: TableMetadata) -> Iterator[AddFileTask]:
    partition_spec = table_metadata.spec()
    partition_struct = partition_spec.partition_type(table_metadata.schema())

    for file_path in file_paths:
        # file_path = 's3://warehouse/default/part1=2024-03-04/part2=ABCD'
        # ['part1=2024-03-04', 'part2=ABCD']
        parts = [part for part in file_path.split("/") if "=" in part]

        partition_field_values = {}
        for part in parts:
            partition_name, string_value = part.split("=")
            if partition_field := partition_struct.field_by_name(partition_name):
                partition_field_values[partition_name] = StringLiteral(string_value).to(partition_field.field_type).value

        yield AddFileTask(file_path=file_path, partition_field_value=Record(**partition_field_values))
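
A worked example under assumed values, to make the Hive-style path parsing above concrete; only path segments containing "=" are considered, and names that are not in the table's partition spec are skipped:

file_path = "s3://warehouse/default/baz=123/qux=2024-03-07/test-0.parquet"
parts = [part for part in file_path.split("/") if "=" in part]
assert parts == ["baz=123", "qux=2024-03-07"]
# For an IntegerType partition field "baz", StringLiteral("123").to(IntegerType())
# coerces the string to the int 123, so the resulting task carries Record(baz=123, ...).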


def _parquet_files_to_data_files(table_metadata: TableMetadata, file_paths: List[str], io: FileIO) -> Iterable[DataFile]:
    """Convert a list of files into DataFiles.

    Returns:
        An iterable that supplies DataFiles that describe the parquet files.
    """
    from pyiceberg.io.pyarrow import parquet_file_to_data_file
    from pyiceberg.io.pyarrow import parquet_files_to_data_files

    for file_path in file_paths:
        yield parquet_file_to_data_file(io=io, table_metadata=table_metadata, file_path=file_path)
    tasks = add_file_tasks_from_file_paths(file_paths, table_metadata)
    yield from parquet_files_to_data_files(io=io, table_metadata=table_metadata, tasks=tasks)


class _MergingSnapshotProducer(UpdateTableMetadata["_MergingSnapshotProducer"]):
124 changes: 91 additions & 33 deletions tests/integration/test_add_files.py
@@ -15,75 +15,92 @@
# specific language governing permissions and limitations
# under the License.
# pylint:disable=redefined-outer-name

from datetime import date
from typing import Optional

import pyarrow as pa
import pyarrow.parquet as pq
import pytest
from pathlib import Path
from pyspark.sql import SparkSession

from pyiceberg.catalog import Catalog, Properties, Table
from pyiceberg.io.pyarrow import schema_to_pyarrow
from pyiceberg.catalog import Catalog, Table
from pyiceberg.exceptions import NoSuchTableError
from pyiceberg.partitioning import PartitionField, PartitionSpec
from pyiceberg.schema import Schema
from pyiceberg.transforms import IdentityTransform, MonthTransform
from pyiceberg.types import (
    BooleanType,
    DateType,
    IntegerType,
    NestedField,
    StringType,
)
from pyiceberg.exceptions import NoSuchTableError
from pyspark.sql import SparkSession

TABLE_SCHEMA = Schema(
    NestedField(field_id=1, name="foo", field_type=BooleanType(), required=False),
    NestedField(field_id=2, name="bar", field_type=StringType(), required=False),
    NestedField(field_id=4, name="baz", field_type=IntegerType(), required=False),
    NestedField(field_id=10, name="qux", field_type=DateType(), required=False),
)

ARROW_SCHEMA = pa.schema([
("foo", pa.bool_()),
("bar", pa.string()),
("baz", pa.int32()),
("qux", pa.date32()),
])

ARROW_TABLE = pa.Table.from_pylist(
    [
        {
            "foo": True,
            "bar": "bar_string",
            "baz": 123,
        }
    ],
    schema=schema_to_pyarrow(TABLE_SCHEMA),
)

def _create_table(session_catalog: Catalog, identifier: str) -> Table:
    [
        {
            "foo": True,
            "bar": "bar_string",
            "baz": 123,
            "qux": date(2024, 3, 7),
        }
    ],
    schema=ARROW_SCHEMA,
)

PARTITION_SPEC = PartitionSpec(
    PartitionField(source_id=4, field_id=1000, transform=IdentityTransform(), name="baz"),
    PartitionField(source_id=10, field_id=1001, transform=IdentityTransform(), name="qux"),
    spec_id=0,
)


def _create_table(session_catalog: Catalog, identifier: str, partition_spec: Optional[PartitionSpec] = None) -> Table:
    try:
        session_catalog.drop_table(identifier=identifier)
    except NoSuchTableError:
        pass

    tbl = session_catalog.create_table(identifier=identifier, schema=TABLE_SCHEMA)
    tbl = session_catalog.create_table(
        identifier=identifier, schema=TABLE_SCHEMA, partition_spec=partition_spec if partition_spec else PartitionSpec()
    )

    return tbl


@pytest.mark.integration
def test_add_files_to_unpartitioned_table(spark: SparkSession, session_catalog: Catalog, warehouse: Path) -> None:
def test_add_files_to_unpartitioned_table(spark: SparkSession, session_catalog: Catalog) -> None:
identifier = "default.unpartitioned_table"
tbl = _create_table(session_catalog, identifier)
# rows = spark.sql(
# f"""
# SELECT added_data_files_count, existing_data_files_count, deleted_data_files_count
# FROM {identifier}.all_manifests
# """
# ).collect()

# assert [row.added_data_files_count for row in rows] == []
# assert [row.existing_data_files_count for row in rows] == []
# assert [row.deleted_data_files_count for row in rows] == []

file_paths = [f"/{warehouse}/test-{i}.parquet" for i in range(5)]

file_paths = [f"s3://warehouse/default/unpartitioned/test-{i}.parquet" for i in range(5)]
# write parquet files
for file_path in file_paths:
fo = tbl.io.new_output(file_path)
with fo.create(overwrite=True) as fos:
with pq.ParquetWriter(fos, schema=ARROW_TABLE.schema) as writer:
with pq.ParquetWriter(fos, schema=ARROW_SCHEMA) as writer:
writer.write_table(ARROW_TABLE)

# add the parquet files as data files
tbl.add_files(file_paths)
tbl.add_files(file_paths=file_paths)

# NameMapping must have been set to enable reads
assert tbl.name_mapping() is not None

rows = spark.sql(
f"""
@@ -94,4 +111,45 @@ def test_add_files_to_unpartitioned_table(spark: SparkSession, session_catalog:

    assert [row.added_data_files_count for row in rows] == [5]
    assert [row.existing_data_files_count for row in rows] == [0]
    assert [row.deleted_data_files_count for row in rows] == [0]

    df = spark.table(identifier)
    assert df.count() == 5, "Expected 5 rows"
    for col in df.columns:
        assert df.filter(df[col].isNotNull()).count() == 5, "Expected all 5 rows to be non-null"


@pytest.mark.integration
def test_add_files_to_partitioned_table(spark: SparkSession, session_catalog: Catalog) -> None:
identifier = "default.partitioned_table"
tbl = _create_table(session_catalog, identifier, PARTITION_SPEC)

file_paths = [f"s3://warehouse/default/baz=123/qux=2024-03-07/test-{i}.parquet" for i in range(5)]
# write parquet files
for file_path in file_paths:
fo = tbl.io.new_output(file_path)
with fo.create(overwrite=True) as fos:
with pq.ParquetWriter(fos, schema=ARROW_SCHEMA) as writer:
writer.write_table(ARROW_TABLE)

# add the parquet files as data files
tbl.add_files(file_paths=file_paths)

# NameMapping must have been set to enable reads
assert tbl.name_mapping() is not None

rows = spark.sql(
f"""
SELECT added_data_files_count, existing_data_files_count, deleted_data_files_count
FROM {identifier}.all_manifests
"""
).collect()

    assert [row.added_data_files_count for row in rows] == [5]
    assert [row.existing_data_files_count for row in rows] == [0]
    assert [row.deleted_data_files_count for row in rows] == [0]

    df = spark.table(identifier)
    assert df.count() == 5, "Expected 5 rows"
    for col in df.columns:
        assert df.filter(df[col].isNotNull()).count() == 5, "Expected all 5 rows to be non-null"
