Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix race condition on Table.scan with limit #545

Merged
merged 1 commit (source and target branch names not captured in this extract) on
Mar 25, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 4 additions & 14 deletions pyiceberg/io/pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -946,13 +946,9 @@ def _task_to_table(
projected_field_ids: Set[int],
positional_deletes: Optional[List[ChunkedArray]],
case_sensitive: bool,
row_counts: List[int],
limit: Optional[int] = None,
name_mapping: Optional[NameMapping] = None,
) -> Optional[pa.Table]:
if limit and sum(row_counts) >= limit:
return None

_, _, path = PyArrowFileIO.parse_location(task.file.file_path)
arrow_format = ds.ParquetFileFormat(pre_buffer=True, buffer_size=(ONE_MEGABYTE * 8))
with fs.open_input_file(path) as fin:
Expand Down Expand Up @@ -1015,11 +1011,6 @@ def _task_to_table(
if len(arrow_table) < 1:
return None

if limit is not None and sum(row_counts) >= limit:
return None

row_counts.append(len(arrow_table))

return to_requested_schema(projected_schema, file_project_schema, arrow_table)


Expand Down Expand Up @@ -1085,7 +1076,6 @@ def project_table(
id for id in projected_schema.field_ids if not isinstance(projected_schema.find_type(id), (MapType, ListType))
}.union(extract_field_ids(bound_row_filter))

row_counts: List[int] = []
deletes_per_file = _read_all_delete_files(fs, tasks)
executor = ExecutorFactory.get_or_create()
futures = [
Expand All @@ -1098,21 +1088,21 @@ def project_table(
projected_field_ids,
deletes_per_file.get(task.file.file_path),
case_sensitive,
row_counts,
limit,
table.name_mapping(),
)
for task in tasks
]

total_row_count = 0
# for consistent ordering, we need to maintain future order
futures_index = {f: i for i, f in enumerate(futures)}
completed_futures: SortedList[Future[pa.Table]] = SortedList(iterable=[], key=lambda f: futures_index[f])
for future in concurrent.futures.as_completed(futures):
completed_futures.add(future)

if table_result := future.result():
total_row_count += len(table_result)
# stop early if limit is satisfied
if limit is not None and sum(row_counts) >= limit:
if limit is not None and total_row_count >= limit:
break

# by now, we've either completed all tasks or satisfied the limit
Expand Down