Merged
242 changes: 78 additions & 164 deletions libs/libcommon/src/libcommon/parquet_utils.py
@@ -138,11 +138,30 @@ def is_list_pa_type(parquet_file_path: Path, feature_name: str) -> bool:
return is_list


def truncate_binary_columns(table: pa.Table, max_binary_length: int, features: Features) -> tuple[pa.Table, list[str]]:
@kszucs (Member Author) commented on Nov 4, 2025:
I applied this change in the libviewer PR as well, but I'm pulling it out here to make the review process easier.

# truncate binary columns in the Arrow table to the specified maximum length
# return a new Arrow table and the list of truncated columns
if max_binary_length < 0:
return table, []

columns: dict[str, pa.Array] = {}
truncated_column_names: list[str] = []
for field_idx, field in enumerate(table.schema): # noqa: F402
if features[field.name] == Value("binary") and table[field_idx].nbytes > max_binary_length:
truncated_array = pc.binary_slice(table[field_idx], 0, max_binary_length // len(table))
columns[field.name] = truncated_array
truncated_column_names.append(field.name)
else:
columns[field.name] = table[field_idx]

return pa.table(columns), truncated_column_names
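
A minimal usage sketch of the new helper, not part of the diff: the column names, row count, and byte budget below are hypothetical. With a 100-byte budget over a 2-row table, each binary value is cut to 100 // len(table) == 50 bytes.

import pyarrow as pa
from datasets import Features, Value

example_table = pa.table({
    "blob": pa.array([b"x" * 1000, b"y" * 1000], type=pa.binary()),
    "idx": pa.array([0, 1]),
})
example_features = Features({"blob": Value("binary"), "idx": Value("int64")})

truncated, names = truncate_binary_columns(example_table, max_binary_length=100, features=example_features)
assert names == ["blob"]                        # only the binary column was touched
assert len(truncated["blob"][0].as_py()) == 50  # each value sliced to the per-row budget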


@dataclass
class RowGroupReader:
parquet_file: pq.ParquetFile
group_id: int
features: Features
schema: pa.Schema

def read(self, columns: list[str]) -> pa.Table:
if not set(self.parquet_file.schema_arrow.names) <= set(columns):
@@ -151,18 +170,7 @@ def read(self, columns: list[str]) -> pa.Table:
)
pa_table = self.parquet_file.read_row_group(i=self.group_id, columns=columns)
# cast_table_to_schema adds null values to missing columns
return cast_table_to_schema(pa_table, self.features.arrow_schema)

def read_truncated_binary(self, columns: list[str], max_binary_length: int) -> tuple[pa.Table, list[str]]:
pa_table = self.parquet_file.read_row_group(i=self.group_id, columns=columns)
truncated_columns: list[str] = []
if max_binary_length:
for field_idx, field in enumerate(pa_table.schema):
if self.features[field.name] == Value("binary") and pa_table[field_idx].nbytes > max_binary_length:
truncated_array = pc.binary_slice(pa_table[field_idx], 0, max_binary_length // len(pa_table))
pa_table = pa_table.set_column(field_idx, field, truncated_array)
truncated_columns.append(field.name)
return cast_table_to_schema(pa_table, self.features.arrow_schema), truncated_columns
return cast_table_to_schema(pa_table, self.schema)

def read_size(self, columns: Optional[Iterable[str]] = None) -> int:
if columns is None:
@@ -179,32 +187,33 @@ def read_size(self, columns: Optional[Iterable[str]] = None) -> int:

@dataclass
class ParquetIndexWithMetadata:
files: list[ParquetFileMetadataItem]
Member Author comment:
Storing the list of ParquetFileMetadataItem here directly, since the previously extracted field lists didn't provide any advantage, only noise.
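
A rough sketch of what this field holds and how the derived offsets in __post_init__ below are computed; the two items and their values are hypothetical, and only the keys used by this class are shown.

import numpy as np

files = [
    {"filename": "0000.parquet", "url": "https://example.com/0000.parquet",
     "size": 123_456, "num_rows": 100, "parquet_metadata_subpath": "default/train/0000.parquet"},
    {"filename": "0001.parquet", "url": "https://example.com/0001.parquet",
     "size": 98_765, "num_rows": 50, "parquet_metadata_subpath": "default/train/0001.parquet"},
]
file_offsets = np.cumsum([f["num_rows"] for f in files])  # array([100, 150])
num_rows_total = int(file_offsets[-1])                    # 150
# query() then calls np.searchsorted(file_offsets, [first_row, last_row], side="right")
# to find which parquet files contain the requested rows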

features: Features
parquet_files_urls: list[str]
metadata_paths: list[str]
num_bytes: list[int]
num_rows: list[int]
httpfs: HTTPFileSystem
max_arrow_data_in_memory: int
partial: bool
metadata_dir: Path

file_offsets: np.ndarray = field(init=False)
num_rows_total: int = field(init=False)

def __post_init__(self) -> None:
if self.httpfs._session is None:
self.httpfs_session = asyncio.run(self.httpfs.set_session())
else:
self.httpfs_session = self.httpfs._session
self.num_rows_total = sum(self.num_rows)

def query_truncated_binary(self, offset: int, length: int) -> tuple[pa.Table, list[str]]:
num_rows = np.array([f["num_rows"] for f in self.files])
self.file_offsets = np.cumsum(num_rows)
self.num_rows_total = np.sum(num_rows)

def query(self, offset: int, length: int) -> tuple[pa.Table, list[str]]:
"""Query the parquet files

Note that this implementation will always read at least one row group, to get the list of columns and always
have the same schema, even if the requested rows are invalid (out of range).

If binary columns are present, then:
- it computes a maximum size to allocate to binary data in step "parquet_index_with_metadata.row_groups_size_check_truncated_binary"
- it truncates oversized binary values with `truncate_binary_columns()` and returns the names of the truncated columns.

@@ -219,27 +228,19 @@ def query_truncated_binary(self, offset: int, length: int) -> tuple[pa.Table, li
`pa.Table`: The requested rows.
`list[str]`: List of truncated columns.
"""
all_columns = set(self.features)
binary_columns = set(column for column, feature in self.features.items() if feature == Value("binary"))
if not binary_columns:
return self.query(offset=offset, length=length), []
with StepProfiler(
method="parquet_index_with_metadata.query", step="get the parquet files than contain the requested rows"
):
parquet_file_offsets = np.cumsum(self.num_rows)

last_row_in_parquet = parquet_file_offsets[-1] - 1
last_row_in_parquet = self.file_offsets[-1] - 1
first_row = min(offset, last_row_in_parquet)
last_row = min(offset + length - 1, last_row_in_parquet)
first_parquet_file_id, last_parquet_file_id = np.searchsorted(
parquet_file_offsets, [first_row, last_row], side="right"
self.file_offsets, [first_row, last_row], side="right"
)
parquet_offset = (
offset - parquet_file_offsets[first_parquet_file_id - 1] if first_parquet_file_id > 0 else offset
offset - self.file_offsets[first_parquet_file_id - 1] if first_parquet_file_id > 0 else offset
)
urls = self.parquet_files_urls[first_parquet_file_id : last_parquet_file_id + 1] # noqa: E203
metadata_paths = self.metadata_paths[first_parquet_file_id : last_parquet_file_id + 1] # noqa: E203
num_bytes = self.num_bytes[first_parquet_file_id : last_parquet_file_id + 1] # noqa: E203
files_to_scan = self.files[first_parquet_file_id : last_parquet_file_id + 1] # noqa: E203

with StepProfiler(
method="parquet_index_with_metadata.query", step="load the remote parquet files using metadata from disk"
@@ -248,17 +249,17 @@ def query_truncated_binary(self, offset: int, length: int) -> tuple[pa.Table, li
pq.ParquetFile(
Member Author comment:
pq.ParquetFile can accept a filesystem argument, so I think this could be further simplified, and the mocking in the tests could be simpler by passing a local filesystem rather than a network filesystem, but I'm leaving it as is since hopefully we can remove it entirely.
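
A hedged sketch of the simplification hinted at above, not what this PR does: pq.ParquetFile also accepts a filesystem argument, so the HTTPFile wrapper could in principle be dropped. The URL and metadata path are hypothetical.

import pyarrow.parquet as pq
from fsspec.implementations.http import HTTPFileSystem

httpfs = HTTPFileSystem()
url = "https://example.com/0000.parquet"              # hypothetical remote parquet file
metadata_path = "/tmp/parquet-metadata/0000.parquet"  # hypothetical cached metadata file

parquet_file = pq.ParquetFile(
    url,
    filesystem=httpfs,  # pyarrow wraps fsspec-compatible filesystems internally
    metadata=pq.read_metadata(metadata_path),
    pre_buffer=True,
)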

HTTPFile(
self.httpfs,
url,
f["url"],
session=self.httpfs_session,
size=size,
size=f["size"],
loop=self.httpfs.loop,
cache_type=None,
**self.httpfs.kwargs,
),
metadata=pq.read_metadata(metadata_path),
metadata=pq.read_metadata(self.metadata_dir / f["parquet_metadata_subpath"]),
pre_buffer=True,
)
for url, metadata_path, size in zip(urls, metadata_paths, num_bytes)
for f in files_to_scan
]

with StepProfiler(
@@ -272,7 +273,7 @@ def query_truncated_binary(self, offset: int, length: int) -> tuple[pa.Table, li
]
)
row_group_readers = [
RowGroupReader(parquet_file=parquet_file, group_id=group_id, features=self.features)
RowGroupReader(parquet_file=parquet_file, group_id=group_id, schema=self.features.arrow_schema)
for parquet_file in parquet_files
for group_id in range(parquet_file.metadata.num_row_groups)
]
@@ -290,6 +291,28 @@ def query_truncated_binary(self, offset: int, length: int) -> tuple[pa.Table, li
row_group_offsets, [first_row, last_row], side="right"
)

all_columns = set(self.features)
binary_columns = set(column for column, feature in self.features.items() if feature == Value("binary"))
if binary_columns:
pa_table, truncated_columns = self._read_with_binary(
row_group_readers, first_row_group_id, last_row_group_id, all_columns, binary_columns
)
else:
pa_table, truncated_columns = self._read_without_binary(
Member Author comment:
Theoretically we could remove this branch as well, since at most there would be negligible performance overhead, but I'm trying to keep the old behavior.

row_group_readers, first_row_group_id, last_row_group_id
)

first_row_in_pa_table = row_group_offsets[first_row_group_id - 1] if first_row_group_id > 0 else 0
return pa_table.slice(parquet_offset - first_row_in_pa_table, length), truncated_columns
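
A worked example of the offset arithmetic used above, with hypothetical numbers: one parquet file containing three 10-row row groups, querying rows 12 through 17.

import numpy as np

row_group_offsets = np.cumsum([10, 10, 10])  # array([10, 20, 30])
parquet_offset, length = 12, 6
last_row_in_parquet = row_group_offsets[-1] - 1                                 # 29
first_row = min(parquet_offset, last_row_in_parquet)                            # 12
last_row = min(parquet_offset + length - 1, last_row_in_parquet)                # 17
first_rg, last_rg = np.searchsorted(row_group_offsets, [first_row, last_row], side="right")  # 1, 1
first_row_in_pa_table = row_group_offsets[first_rg - 1] if first_rg > 0 else 0               # 10
# only row group 1 is read; the resulting table is sliced at 12 - 10 = 2 for 6 rows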

def _read_with_binary(
self,
row_group_readers: list[RowGroupReader],
first_row_group_id: int,
last_row_group_id: int,
all_columns: set[str],
binary_columns: set[str],
) -> tuple[pa.Table, list[str]]:
with StepProfiler(
method="parquet_index_with_metadata.row_groups_size_check_truncated_binary",
step="check if the rows can fit in memory",
@@ -329,100 +352,21 @@ def query_truncated_binary(self, offset: int, length: int) -> tuple[pa.Table, li
columns = list(self.features.keys())
truncated_columns: set[str] = set()
for i in range(first_row_group_id, last_row_group_id + 1):
rg_pa_table, rg_truncated_columns = row_group_readers[i].read_truncated_binary(
columns, max_binary_length=max_binary_length
rg_pa_table = row_group_readers[i].read(columns)
rg_pa_table, rg_truncated_columns = truncate_binary_columns(
rg_pa_table, max_binary_length, self.features
)
pa_tables.append(rg_pa_table)
truncated_columns |= set(rg_truncated_columns)
pa_table = pa.concat_tables(pa_tables)
except ArrowInvalid as err:
raise SchemaMismatchError("Parquet files have different schema.", err)
first_row_in_pa_table = row_group_offsets[first_row_group_id - 1] if first_row_group_id > 0 else 0
return pa_table.slice(parquet_offset - first_row_in_pa_table, length), list(truncated_columns)

def query(self, offset: int, length: int) -> pa.Table:
"""Query the parquet files

Note that this implementation will always read at least one row group, to get the list of columns and always
have the same schema, even if the requested rows are invalid (out of range).

Args:
offset (`int`): The first row to read.
length (`int`): The number of rows to read.

Raises:
[`TooBigRows`]: if the arrow data from the parquet row groups is bigger than max_arrow_data_in_memory

Returns:
`pa.Table`: The requested rows.
"""
with StepProfiler(
method="parquet_index_with_metadata.query", step="get the parquet files than contain the requested rows"
):
parquet_file_offsets = np.cumsum(self.num_rows)

last_row_in_parquet = parquet_file_offsets[-1] - 1
first_row = min(offset, last_row_in_parquet)
last_row = min(offset + length - 1, last_row_in_parquet)
first_parquet_file_id, last_parquet_file_id = np.searchsorted(
parquet_file_offsets, [first_row, last_row], side="right"
)
parquet_offset = (
offset - parquet_file_offsets[first_parquet_file_id - 1] if first_parquet_file_id > 0 else offset
)
urls = self.parquet_files_urls[first_parquet_file_id : last_parquet_file_id + 1] # noqa: E203
metadata_paths = self.metadata_paths[first_parquet_file_id : last_parquet_file_id + 1] # noqa: E203
num_bytes = self.num_bytes[first_parquet_file_id : last_parquet_file_id + 1] # noqa: E203

with StepProfiler(
method="parquet_index_with_metadata.query", step="load the remote parquet files using metadata from disk"
):
parquet_files = [
pq.ParquetFile(
HTTPFile(
self.httpfs,
url,
session=self.httpfs_session,
size=size,
loop=self.httpfs.loop,
cache_type=None,
**self.httpfs.kwargs,
),
metadata=pq.read_metadata(metadata_path),
pre_buffer=True,
)
for url, metadata_path, size in zip(urls, metadata_paths, num_bytes)
]

with StepProfiler(
method="parquet_index_with_metadata.query", step="get the row groups than contain the requested rows"
):
row_group_offsets = np.cumsum(
[
parquet_file.metadata.row_group(group_id).num_rows
for parquet_file in parquet_files
for group_id in range(parquet_file.metadata.num_row_groups)
]
)
row_group_readers = [
RowGroupReader(parquet_file=parquet_file, group_id=group_id, features=self.features)
for parquet_file in parquet_files
for group_id in range(parquet_file.metadata.num_row_groups)
]

if len(row_group_offsets) == 0 or row_group_offsets[-1] == 0: # if the dataset is empty
if offset < 0:
raise IndexError("Offset must be non-negative")
return cast_table_to_schema(parquet_files[0].read(), self.features.arrow_schema)

last_row_in_parquet = row_group_offsets[-1] - 1
first_row = min(parquet_offset, last_row_in_parquet)
last_row = min(parquet_offset + length - 1, last_row_in_parquet)

first_row_group_id, last_row_group_id = np.searchsorted(
row_group_offsets, [first_row, last_row], side="right"
)
return pa_table, list(truncated_columns)

def _read_without_binary(
self, row_group_readers: list[RowGroupReader], first_row_group_id: int, last_row_group_id: int
) -> tuple[pa.Table, list[str]]:
with StepProfiler(
method="parquet_index_with_metadata.row_groups_size_check", step="check if the rows can fit in memory"
):
@@ -443,8 +387,8 @@ def query(self, offset: int, length: int) -> pa.Table:
)
except ArrowInvalid as err:
raise SchemaMismatchError("Parquet files have different schema.", err)
first_row_in_pa_table = row_group_offsets[first_row_group_id - 1] if first_row_group_id > 0 else 0
return pa_table.slice(parquet_offset - first_row_in_pa_table, length)

return pa_table, []

@staticmethod
def from_parquet_metadata_items(
@@ -458,40 +402,31 @@ def from_parquet_metadata_items(
raise EmptyParquetMetadataError("No parquet files found.")

partial = parquet_export_is_partial(parquet_file_metadata_items[0]["url"])
metadata_dir = Path(parquet_metadata_directory)

with StepProfiler(
method="parquet_index_with_metadata.from_parquet_metadata_items",
step="get the index from parquet metadata",
):
try:
parquet_files_metadata = sorted(
parquet_file_metadata_items, key=lambda parquet_file_metadata: parquet_file_metadata["filename"]
)
parquet_files_urls = [parquet_file_metadata["url"] for parquet_file_metadata in parquet_files_metadata]
metadata_paths = [
os.path.join(parquet_metadata_directory, parquet_file_metadata["parquet_metadata_subpath"])
for parquet_file_metadata in parquet_files_metadata
]
num_bytes = [parquet_file_metadata["size"] for parquet_file_metadata in parquet_files_metadata]
num_rows = [parquet_file_metadata["num_rows"] for parquet_file_metadata in parquet_files_metadata]
files = sorted(parquet_file_metadata_items, key=lambda f: f["filename"])
except Exception as e:
raise ParquetResponseFormatError(f"Could not parse the list of parquet files: {e}") from e

with StepProfiler(
method="parquet_index_with_metadata.from_parquet_metadata_items", step="get the dataset's features"
):
if features is None: # config-parquet version<6 didn't have features
features = Features.from_arrow_schema(pq.read_schema(metadata_paths[0]))
first_arrow_schema = pq.read_schema(metadata_dir / files[0]["parquet_metadata_subpath"])
features = Features.from_arrow_schema(first_arrow_schema)

return ParquetIndexWithMetadata(
files=files,
features=features,
parquet_files_urls=parquet_files_urls,
metadata_paths=metadata_paths,
num_bytes=num_bytes,
num_rows=num_rows,
httpfs=httpfs,
max_arrow_data_in_memory=max_arrow_data_in_memory,
partial=partial,
metadata_dir=metadata_dir,
)


@@ -551,28 +486,7 @@ def _init_parquet_index(

# note that this cache size is global for the class, not per instance
@lru_cache(maxsize=1)
def query(self, offset: int, length: int) -> pa.Table:
"""Query the parquet files

Note that this implementation will always read at least one row group, to get the list of columns and always
have the same schema, even if the requested rows are invalid (out of range).

Args:
offset (`int`): The first row to read.
length (`int`): The number of rows to read.

Returns:
`pa.Table`: The requested rows.
"""
logging.info(
f"Query {type(self.parquet_index).__name__} for dataset={self.dataset}, config={self.config},"
f" split={self.split}, offset={offset}, length={length}"
)
return self.parquet_index.query(offset=offset, length=length)

# note that this cache size is global for the class, not per instance
@lru_cache(maxsize=1)
def query_truncated_binary(self, offset: int, length: int) -> tuple[pa.Table, list[str]]:
def query(self, offset: int, length: int) -> tuple[pa.Table, list[str]]:
"""Query the parquet files

Note that this implementation will always read at least one row group, to get the list of columns and always
@@ -590,4 +504,4 @@ def query_truncated_binary(self, offset: int, length: int) -> tuple[pa.Table, li
f"Query {type(self.parquet_index).__name__} for dataset={self.dataset}, config={self.config},"
f" split={self.split}, offset={offset}, length={length}, with truncated binary"
)
return self.parquet_index.query_truncated_binary(offset=offset, length=length)
return self.parquet_index.query(offset=offset, length=length)
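
A hedged usage sketch of the unified API after this change; the rows_index instance is hypothetical. The single cached query() now returns both the rows and the names of any truncated binary columns.

pa_table, truncated_columns = rows_index.query(offset=0, length=100)
print(pa_table.num_rows, truncated_columns)  # e.g. 100 and ["blob"] if a binary column was cut
# an identical second call is served from the lru_cache(maxsize=1) wrapper above
pa_table_again, _ = rows_index.query(offset=0, length=100)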