Use write.parquet.compression-{codec,level} #358

Merged · 9 commits · Feb 5, 2024
Changes from 3 commits
4 changes: 4 additions & 0 deletions pyiceberg/catalog/rest.py
@@ -450,6 +450,10 @@ def create_table(
iceberg_schema = self._convert_schema_if_needed(schema)
iceberg_schema = assign_fresh_schema_ids(iceberg_schema)

properties = properties.copy()
for copy_key in ["write.parquet.compression-codec", "write.parquet.compression-level"]:
if copy_key in self.properties:
properties[copy_key] = self.properties[copy_key]
namespace_and_table = self._split_identifier_for_path(identifier)
request = CreateTableRequest(
name=namespace_and_table["table"],
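The effect of this change, as a minimal usage sketch (the URI, credentials, table name, and `table_schema` below are placeholders, not part of this PR): compression properties configured on the REST catalog are copied onto every table created through it.

```python
from pyiceberg.catalog import load_catalog

# Placeholder endpoint and credentials, mirroring the integration-test setup further below.
catalog = load_catalog(
    "local",
    **{
        "type": "rest",
        "uri": "http://localhost:8181",
        "write.parquet.compression-codec": "zstd",
        "write.parquet.compression-level": "3",
    },
)

# create_table copies the two write.parquet.compression-* keys from the catalog
# configuration into the new table's properties, so writers pick them up later.
tbl = catalog.create_table(identifier="default.events", schema=table_schema)  # table_schema assumed defined
assert tbl.properties["write.parquet.compression-codec"] == "zstd"
```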
20 changes: 13 additions & 7 deletions pyiceberg/io/pyarrow.py
@@ -1720,13 +1720,23 @@ def write_file(table: Table, tasks: Iterator[WriteTask]) -> Iterator[DataFile]:
except StopIteration:
pass

compression_codec = table.properties.get("write.parquet.compression-codec")
compression_level = table.properties.get("write.parquet.compression-level")
compression_options: Dict[str, Any]
if compression_codec == "uncompressed":
compression_options = {"compression": "none"}
else:
compression_options = {
"compression": compression_codec,
"compression_level": None if compression_level is None else int(compression_level),
}

file_path = f'{table.location()}/data/{task.generate_data_file_filename("parquet")}'
file_schema = schema_to_pyarrow(table.schema())

collected_metrics: List[pq.FileMetaData] = []
fo = table.io.new_output(file_path)
with fo.create(overwrite=True) as fos:
- with pq.ParquetWriter(fos, schema=file_schema, version="1.0", metadata_collector=collected_metrics) as writer:
+ with pq.ParquetWriter(fos, schema=file_schema, version="1.0", **compression_options) as writer:
writer.write_table(task.df)

data_file = DataFile(
@@ -1745,13 +1755,9 @@ def write_file(table: Table, tasks: Iterator[WriteTask]) -> Iterator[DataFile]:
key_metadata=None,
)

- if len(collected_metrics) != 1:
-     # One file has been written
-     raise ValueError(f"Expected 1 entry, got: {collected_metrics}")
-
fill_parquet_file_metadata(
data_file=data_file,
- parquet_metadata=collected_metrics[0],
+ parquet_metadata=writer.writer.metadata,
Contributor:
I checked this through the debugger, and this looks good. Nice change @jonashaag 👍

Contributor (Author):
You can also tell from the PyArrow code that it's identical :)

stats_columns=compute_statistics_plan(table.schema(), table.properties),
parquet_column_mapping=parquet_path_to_id_mapping(table.schema()),
)
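A standalone sketch of the equivalence discussed in the review thread above (the file path and data are made up): once a `ParquetWriter` is closed, the `FileMetaData` exposed by its underlying writer matches what a `metadata_collector` captures, which is why the collector and the one-entry check can be dropped.

```python
import pyarrow as pa
import pyarrow.parquet as pq

# Made-up data and path, just to illustrate the equivalence.
data = pa.table({"x": [1, 2, 3]})
collected = []  # receives one FileMetaData per written file
with pq.ParquetWriter("/tmp/example.parquet", data.schema, metadata_collector=collected) as writer:
    writer.write_table(data)

# Exactly one file was written, and both routes expose the same metadata.
assert len(collected) == 1
assert collected[0].num_rows == writer.writer.metadata.num_rows == 3
```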
6 changes: 5 additions & 1 deletion tests/integration/test_reads.py
@@ -224,7 +224,11 @@ def test_ray_all_types(catalog: Catalog) -> None:
@pytest.mark.parametrize('catalog', [pytest.lazy_fixture('catalog_hive'), pytest.lazy_fixture('catalog_rest')])
def test_pyarrow_to_iceberg_all_types(catalog: Catalog) -> None:
table_test_all_types = catalog.load_table("default.test_all_types")
- fs = S3FileSystem(endpoint_override="http://localhost:9000", access_key="admin", secret_key="password")
+ fs = S3FileSystem(
+     endpoint_override=catalog.properties["s3.endpoint"],
+     access_key=catalog.properties["s3.access-key-id"],
+     secret_key=catalog.properties["s3.secret-access-key"],
+ )
data_file_paths = [task.file.file_path for task in table_test_all_types.scan().plan_files()]
for data_file_path in data_file_paths:
uri = urlparse(data_file_path)
55 changes: 55 additions & 0 deletions tests/integration/test_writes.py
@@ -17,9 +17,12 @@
# pylint:disable=redefined-outer-name
import uuid
from datetime import date, datetime
from urllib.parse import urlparse

import pyarrow as pa
import pyarrow.parquet as pq
import pytest
from pyarrow.fs import S3FileSystem
from pyspark.sql import SparkSession

from pyiceberg.catalog import Catalog, load_catalog
@@ -489,6 +492,58 @@ def test_data_files(spark: SparkSession, session_catalog: Catalog, arrow_table_w
assert [row.deleted_data_files_count for row in rows] == [0, 0, 1, 0, 0]


@pytest.mark.integration
@pytest.mark.parametrize(
"compression",
# List of (compression_properties, expected_compression_name)
[
# REST catalog uses Zstandard by default: https://github.com/apache/iceberg/pull/8593
({}, "ZSTD"),
({"write.parquet.compression-codec": "uncompressed"}, "UNCOMPRESSED"),
({"write.parquet.compression-codec": "gzip", "write.parquet.compression-level": "1"}, "GZIP"),
({"write.parquet.compression-codec": "zstd", "write.parquet.compression-level": "1"}, "ZSTD"),
({"write.parquet.compression-codec": "snappy"}, "SNAPPY"),
],
)
def test_parquet_compression(spark: SparkSession, arrow_table_with_null: pa.Table, compression) -> None:
compression_properties, expected_compression_name = compression

catalog = load_catalog(
"local",
**{
"type": "rest",
"uri": "http://localhost:8181",
"s3.endpoint": "http://localhost:9000",
"s3.access-key-id": "admin",
"s3.secret-access-key": "password",
**compression_properties,
},
)
identifier = "default.arrow_data_files"

try:
catalog.drop_table(identifier=identifier)
except NoSuchTableError:
pass
tbl = catalog.create_table(identifier=identifier, schema=TABLE_SCHEMA, properties={'format-version': '1'})

tbl.overwrite(arrow_table_with_null)

data_file_paths = [task.file.file_path for task in tbl.scan().plan_files()]

fs = S3FileSystem(
endpoint_override=catalog.properties["s3.endpoint"],
access_key=catalog.properties["s3.access-key-id"],
secret_key=catalog.properties["s3.secret-access-key"],
)
uri = urlparse(data_file_paths[0])
with fs.open_input_file(f"{uri.netloc}{uri.path}") as f:
parquet_metadata = pq.read_metadata(f)
compression = parquet_metadata.row_group(0).column(0).compression

assert compression == expected_compression_name


@pytest.mark.integration
def test_invalid_arguments(spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None:
identifier = "default.arrow_data_files"