Skip to content

Commit

Permalink
[bug] fix reading with to_arrow_batch_reader and limit (#1042)
Browse files Browse the repository at this point in the history
* fix project_batches with limit

* add test

* lint + readability
  • Loading branch information
kevinjqliu authored Aug 12, 2024
1 parent 2e73a41 commit f05b1ae
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 2 deletions.
8 changes: 6 additions & 2 deletions pyiceberg/io/pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1455,6 +1455,9 @@ def project_batches(
total_row_count = 0

for task in tasks:
# stop early if limit is satisfied
if limit is not None and total_row_count >= limit:
break
batches = _task_to_record_batches(
fs,
task,
Expand All @@ -1468,9 +1471,10 @@ def project_batches(
)
for batch in batches:
if limit is not None:
if total_row_count + len(batch) >= limit:
yield batch.slice(0, limit - total_row_count)
if total_row_count >= limit:
break
elif total_row_count + len(batch) >= limit:
batch = batch.slice(0, limit - total_row_count)
yield batch
total_row_count += len(batch)

Expand Down
48 changes: 48 additions & 0 deletions tests/integration/test_reads.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,54 @@ def test_pyarrow_limit(catalog: Catalog) -> None:
full_result = table_test_limit.scan(selected_fields=("idx",), limit=999).to_arrow()
assert len(full_result) == 10

# test `to_arrow_batch_reader`
limited_result = table_test_limit.scan(selected_fields=("idx",), limit=1).to_arrow_batch_reader().read_all()
assert len(limited_result) == 1

empty_result = table_test_limit.scan(selected_fields=("idx",), limit=0).to_arrow_batch_reader().read_all()
assert len(empty_result) == 0

full_result = table_test_limit.scan(selected_fields=("idx",), limit=999).to_arrow_batch_reader().read_all()
assert len(full_result) == 10


@pytest.mark.integration
@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
def test_pyarrow_limit_with_multiple_files(catalog: Catalog) -> None:
table_name = "default.test_pyarrow_limit_with_multiple_files"
try:
catalog.drop_table(table_name)
except NoSuchTableError:
pass
reference_table = catalog.load_table("default.test_limit")
data = reference_table.scan().to_arrow()
table_test_limit = catalog.create_table(table_name, schema=reference_table.schema())

n_files = 2
for _ in range(n_files):
table_test_limit.append(data)
assert len(table_test_limit.inspect.files()) == n_files

# test with multiple files
limited_result = table_test_limit.scan(selected_fields=("idx",), limit=1).to_arrow()
assert len(limited_result) == 1

empty_result = table_test_limit.scan(selected_fields=("idx",), limit=0).to_arrow()
assert len(empty_result) == 0

full_result = table_test_limit.scan(selected_fields=("idx",), limit=999).to_arrow()
assert len(full_result) == 10 * n_files

# test `to_arrow_batch_reader`
limited_result = table_test_limit.scan(selected_fields=("idx",), limit=1).to_arrow_batch_reader().read_all()
assert len(limited_result) == 1

empty_result = table_test_limit.scan(selected_fields=("idx",), limit=0).to_arrow_batch_reader().read_all()
assert len(empty_result) == 0

full_result = table_test_limit.scan(selected_fields=("idx",), limit=999).to_arrow_batch_reader().read_all()
assert len(full_result) == 10 * n_files


@pytest.mark.integration
@pytest.mark.filterwarnings("ignore")
Expand Down

0 comments on commit f05b1ae

Please sign in to comment.