Skip to content

Commit

Permalink
Fix race condition on Table.scan with limit (apache#545)
Browse files Browse the repository at this point in the history
(cherry picked from commit 0d12cf4)
  • Loading branch information
kevinjqliu authored and HonahX committed Mar 28, 2024
1 parent 9f8c746 commit c069831
Showing 1 changed file with 4 additions and 14 deletions.
18 changes: 4 additions & 14 deletions pyiceberg/io/pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -938,13 +938,9 @@ def _task_to_table(
projected_field_ids: Set[int],
positional_deletes: Optional[List[ChunkedArray]],
case_sensitive: bool,
row_counts: List[int],
limit: Optional[int] = None,
name_mapping: Optional[NameMapping] = None,
) -> Optional[pa.Table]:
if limit and sum(row_counts) >= limit:
return None

_, _, path = PyArrowFileIO.parse_location(task.file.file_path)
arrow_format = ds.ParquetFileFormat(pre_buffer=True, buffer_size=(ONE_MEGABYTE * 8))
with fs.open_input_file(path) as fin:
Expand Down Expand Up @@ -1007,11 +1003,6 @@ def _task_to_table(
if len(arrow_table) < 1:
return None

if limit is not None and sum(row_counts) >= limit:
return None

row_counts.append(len(arrow_table))

return to_requested_schema(projected_schema, file_project_schema, arrow_table)


Expand Down Expand Up @@ -1077,7 +1068,6 @@ def project_table(
id for id in projected_schema.field_ids if not isinstance(projected_schema.find_type(id), (MapType, ListType))
}.union(extract_field_ids(bound_row_filter))

row_counts: List[int] = []
deletes_per_file = _read_all_delete_files(fs, tasks)
executor = ExecutorFactory.get_or_create()
futures = [
Expand All @@ -1090,21 +1080,21 @@ def project_table(
projected_field_ids,
deletes_per_file.get(task.file.file_path),
case_sensitive,
row_counts,
limit,
table.name_mapping(),
)
for task in tasks
]

total_row_count = 0
# for consistent ordering, we need to maintain future order
futures_index = {f: i for i, f in enumerate(futures)}
completed_futures: SortedList[Future[pa.Table]] = SortedList(iterable=[], key=lambda f: futures_index[f])
for future in concurrent.futures.as_completed(futures):
completed_futures.add(future)

if table_result := future.result():
total_row_count += len(table_result)
# stop early if limit is satisfied
if limit is not None and sum(row_counts) >= limit:
if limit is not None and total_row_count >= limit:
break

# by now, we've either completed all tasks or satisfied the limit
Expand Down

0 comments on commit c069831

Please sign in to comment.