Skip to content

Commit

Permalink
add more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinjqliu committed Mar 6, 2024
1 parent 6cf617d commit 7123b9f
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 11 deletions.
1 change: 1 addition & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1965,6 +1965,7 @@ def spark() -> SparkSession:
.config("spark.sql.extensions", "org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions")
.config("spark.sql.catalog.integration", "org.apache.iceberg.spark.SparkCatalog")
.config("spark.sql.catalog.integration.catalog-impl", "org.apache.iceberg.rest.RESTCatalog")
.config("spark.sql.catalog.integration.cache-enabled", "false")
.config("spark.sql.catalog.integration.uri", "http://localhost:8181")
.config("spark.sql.catalog.integration.io-impl", "org.apache.iceberg.aws.s3.S3FileIO")
.config("spark.sql.catalog.integration.warehouse", "s3://warehouse/wh/")
Expand Down
41 changes: 30 additions & 11 deletions tests/integration/test_writes.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from pyiceberg.catalog.sql import SqlCatalog
from pyiceberg.exceptions import NoSuchTableError
from pyiceberg.schema import Schema
from pyiceberg.table import TableProperties, _dataframe_to_data_files
from pyiceberg.table import SetPropertiesUpdate, TableProperties, _dataframe_to_data_files
from pyiceberg.types import (
BinaryType,
BooleanType,
Expand Down Expand Up @@ -356,31 +356,50 @@ def test_data_files(spark: SparkSession, session_catalog: Catalog, arrow_table_w


@pytest.mark.integration
@pytest.mark.parametrize("format_version", [1, 2])
def test_write_multiple_data_files(
spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int
) -> None:
identifier = "default.write_multiple_arrow_data_files"
tbl = _create_table(session_catalog, identifier, {"format-version": format_version}, [])
def test_write_bin_pack_data_files(spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None:
identifier = "default.write_bin_pack_data_files"
tbl = _create_table(session_catalog, identifier, {"format-version": "1"}, [])

def get_data_files_count(identifier: str) -> int:
return spark.sql(
f"""
SELECT *
FROM {identifier}.all_data_files
FROM {identifier}.files
"""
).count()

# writes to 1 data file since the table is small
def set_table_properties(tbl: Table, properties: Properties) -> Table:
with tbl.transaction() as transaction:
transaction._apply((SetPropertiesUpdate(updates=properties),))
return tbl

# writes 1 data file since the table is smaller than default target file size
assert arrow_table_with_null.nbytes < TableProperties.WRITE_TARGET_FILE_SIZE_BYTES_DEFAULT
tbl.overwrite(arrow_table_with_null)
assert get_data_files_count(identifier) == 1

# writes to 1 data file as long as table is smaller than default target file size
# writes 1 data file as long as table is smaller than default target file size
bigger_arrow_tbl = pa.concat_tables([arrow_table_with_null] * 10)
tbl.overwrite(bigger_arrow_tbl)
assert bigger_arrow_tbl.nbytes < TableProperties.WRITE_TARGET_FILE_SIZE_BYTES_DEFAULT
tbl.overwrite(bigger_arrow_tbl)
assert get_data_files_count(identifier) == 1

# writes multiple data files once target file size is overridden
target_file_size = arrow_table_with_null.nbytes
tbl = set_table_properties(tbl, {TableProperties.WRITE_TARGET_FILE_SIZE_BYTES: str(target_file_size)})
assert str(target_file_size) == tbl.properties.get(TableProperties.WRITE_TARGET_FILE_SIZE_BYTES)
assert target_file_size < bigger_arrow_tbl.nbytes
tbl.overwrite(bigger_arrow_tbl)
assert get_data_files_count(identifier) == 10

# writes half the number of data files when target file size doubles
target_file_size = arrow_table_with_null.nbytes * 2
tbl = set_table_properties(tbl, {TableProperties.WRITE_TARGET_FILE_SIZE_BYTES: str(target_file_size)})
assert str(target_file_size) == tbl.properties.get(TableProperties.WRITE_TARGET_FILE_SIZE_BYTES)
assert target_file_size < bigger_arrow_tbl.nbytes
tbl.overwrite(bigger_arrow_tbl)
assert get_data_files_count(identifier) == 5


@pytest.mark.integration
@pytest.mark.parametrize("format_version", [1, 2])
Expand Down

0 comments on commit 7123b9f

Please sign in to comment.