refactor(libcommon): consolidate query() and query_truncated_binary() methods (#3253)
Changes from all commits: 9eccc2c, f0368ea, 0713677, d03d13f, cdd2a25, eee8d64
@@ -138,11 +138,30 @@ def is_list_pa_type(parquet_file_path: Path, feature_name: str) -> bool:

```python
    return is_list


def truncate_binary_columns(table: pa.Table, max_binary_length: int, features: Features) -> tuple[pa.Table, list[str]]:
    # truncate binary columns in the Arrow table to the specified maximum length
    # return a new Arrow table and the list of truncated columns
    if max_binary_length < 0:
        return table, []

    columns: dict[str, pa.Array] = {}
    truncated_column_names: list[str] = []
    for field_idx, field in enumerate(table.schema):  # noqa: F402
        if features[field.name] == Value("binary") and table[field_idx].nbytes > max_binary_length:
            truncated_array = pc.binary_slice(table[field_idx], 0, max_binary_length // len(table))
            columns[field.name] = truncated_array
            truncated_column_names.append(field.name)
        else:
            columns[field.name] = table[field_idx]

    return pa.table(columns), truncated_column_names


@dataclass
class RowGroupReader:
    parquet_file: pq.ParquetFile
    group_id: int
    features: Features
    schema: pa.Schema

    def read(self, columns: list[str]) -> pa.Table:
        if not set(self.parquet_file.schema_arrow.names) <= set(columns):
```
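To make the truncation arithmetic above concrete, here is a minimal standalone sketch (not part of the PR): `pc.binary_slice` caps each value of a binary column, and the per-row byte budget is `max_binary_length // len(table)`. The column name and values are invented.

```python
# Standalone sketch of per-row binary truncation with pyarrow (invented data).
import pyarrow as pa
import pyarrow.compute as pc

table = pa.table({"blob": pa.array([b"a" * 100, b"b" * 5, b"c" * 50], type=pa.binary())})
max_binary_length = 30                            # total byte budget for the column
per_row_budget = max_binary_length // len(table)  # 10 bytes per row
truncated = pc.binary_slice(table["blob"], 0, per_row_budget)
print([len(value.as_py()) for value in truncated])  # [10, 5, 10]
```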
@@ -151,18 +170,7 @@ def read(self, columns: list[str]) -> pa.Table:

```python
            )
        pa_table = self.parquet_file.read_row_group(i=self.group_id, columns=columns)
        # cast_table_to_schema adds null values to missing columns
        return cast_table_to_schema(pa_table, self.features.arrow_schema)

    def read_truncated_binary(self, columns: list[str], max_binary_length: int) -> tuple[pa.Table, list[str]]:
        pa_table = self.parquet_file.read_row_group(i=self.group_id, columns=columns)
        truncated_columns: list[str] = []
        if max_binary_length:
            for field_idx, field in enumerate(pa_table.schema):
                if self.features[field.name] == Value("binary") and pa_table[field_idx].nbytes > max_binary_length:
                    truncated_array = pc.binary_slice(pa_table[field_idx], 0, max_binary_length // len(pa_table))
                    pa_table = pa_table.set_column(field_idx, field, truncated_array)
                    truncated_columns.append(field.name)
        return cast_table_to_schema(pa_table, self.features.arrow_schema), truncated_columns
        return cast_table_to_schema(pa_table, self.schema)

    def read_size(self, columns: Optional[Iterable[str]] = None) -> int:
        if columns is None:
```
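The inline comment `# cast_table_to_schema adds null values to missing columns` refers to a libcommon helper that is not shown in this diff. A rough standalone equivalent of that behavior (an assumption for illustration, not the actual implementation) might look like:

```python
# Rough sketch (not libcommon's cast_table_to_schema): cast a table to a target
# schema, filling columns that are absent from the table with nulls.
import pyarrow as pa

def cast_with_missing_as_null(table: pa.Table, schema: pa.Schema) -> pa.Table:
    for field in schema:
        if field.name not in table.column_names:
            table = table.append_column(field, pa.nulls(len(table), type=field.type))
    return table.select(schema.names).cast(schema)

target = pa.schema({"a": pa.int64(), "b": pa.string()})
partial = pa.table({"a": pa.array([1, 2, 3])})
print(cast_with_missing_as_null(partial, target).column("b").null_count)  # 3
```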
@@ -179,32 +187,33 @@ def read_size(self, columns: Optional[Iterable[str]] = None) -> int:

```python
@dataclass
class ParquetIndexWithMetadata:
    files: list[ParquetFileMetadataItem]
```

> **Comment from the author** (on `files: list[ParquetFileMetadataItem]`): Storing the list of …

```python
    features: Features
    parquet_files_urls: list[str]
    metadata_paths: list[str]
    num_bytes: list[int]
    num_rows: list[int]
    httpfs: HTTPFileSystem
    max_arrow_data_in_memory: int
    partial: bool
    metadata_dir: Path

    file_offsets: np.ndarray = field(init=False)
    num_rows_total: int = field(init=False)

    def __post_init__(self) -> None:
        if self.httpfs._session is None:
            self.httpfs_session = asyncio.run(self.httpfs.set_session())
        else:
            self.httpfs_session = self.httpfs._session
        self.num_rows_total = sum(self.num_rows)

    def query_truncated_binary(self, offset: int, length: int) -> tuple[pa.Table, list[str]]:
        num_rows = np.array([f["num_rows"] for f in self.files])
        self.file_offsets = np.cumsum(num_rows)
        self.num_rows_total = np.sum(num_rows)

    def query(self, offset: int, length: int) -> tuple[pa.Table, list[str]]:
        """Query the parquet files

        Note that this implementation will always read at least one row group, to get the list of columns and always
        have the same schema, even if the requested rows are invalid (out of range).

        This is the same as query() except that:

        If binary columns are present, then:
        - it computes a maximum size to allocate to binary data in step "parquet_index_with_metadata.row_groups_size_check_truncated_binary"
        - it uses `read_truncated_binary()` in step "parquet_index_with_metadata.query_truncated_binary".
```
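The refactor replaces the parallel `parquet_files_urls` / `metadata_paths` / `num_bytes` / `num_rows` lists with a single sorted list of per-file metadata dicts. A small sketch of that bookkeeping, with invented values and keys following `ParquetFileMetadataItem`:

```python
# Sketch: per-file metadata kept as one sorted list of dicts rather than
# several parallel lists. All values below are invented.
files = [
    {"filename": "0001.parquet", "url": "https://example.com/0001.parquet",
     "size": 2048, "num_rows": 100, "parquet_metadata_subpath": "0001.metadata"},
    {"filename": "0000.parquet", "url": "https://example.com/0000.parquet",
     "size": 1024, "num_rows": 50, "parquet_metadata_subpath": "0000.metadata"},
]
files = sorted(files, key=lambda f: f["filename"])
print([f["num_rows"] for f in files])  # [50, 100] -> fed to np.cumsum for file_offsets
```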
@@ -219,27 +228,19 @@ def query_truncated_binary(self, offset: int, length: int) -> tuple[pa.Table, list[str]]:

```python
            `pa.Table`: The requested rows.
            `list[str]`: List of truncated columns.
        """
        all_columns = set(self.features)
        binary_columns = set(column for column, feature in self.features.items() if feature == Value("binary"))
        if not binary_columns:
            return self.query(offset=offset, length=length), []
        with StepProfiler(
            method="parquet_index_with_metadata.query", step="get the parquet files than contain the requested rows"
        ):
            parquet_file_offsets = np.cumsum(self.num_rows)

            last_row_in_parquet = parquet_file_offsets[-1] - 1
            last_row_in_parquet = self.file_offsets[-1] - 1
            first_row = min(offset, last_row_in_parquet)
            last_row = min(offset + length - 1, last_row_in_parquet)
            first_parquet_file_id, last_parquet_file_id = np.searchsorted(
                parquet_file_offsets, [first_row, last_row], side="right"
                self.file_offsets, [first_row, last_row], side="right"
            )
            parquet_offset = (
                offset - parquet_file_offsets[first_parquet_file_id - 1] if first_parquet_file_id > 0 else offset
                offset - self.file_offsets[first_parquet_file_id - 1] if first_parquet_file_id > 0 else offset
            )
            urls = self.parquet_files_urls[first_parquet_file_id : last_parquet_file_id + 1]  # noqa: E203
            metadata_paths = self.metadata_paths[first_parquet_file_id : last_parquet_file_id + 1]  # noqa: E203
            num_bytes = self.num_bytes[first_parquet_file_id : last_parquet_file_id + 1]  # noqa: E203
            files_to_scan = self.files[first_parquet_file_id : last_parquet_file_id + 1]  # noqa: E203

        with StepProfiler(
            method="parquet_index_with_metadata.query", step="load the remote parquet files using metadata from disk"
```
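The file selection above hinges on cumulative row counts plus `np.searchsorted`. A self-contained sketch with made-up per-file row counts:

```python
# Sketch: find which parquet files contain the requested [offset, offset + length) window.
import numpy as np

num_rows = np.array([100, 250, 50])  # rows per parquet file (made-up)
file_offsets = np.cumsum(num_rows)   # [100, 350, 400]

offset, length = 120, 300
last_row_in_parquet = file_offsets[-1] - 1
first_row = min(offset, last_row_in_parquet)
last_row = min(offset + length - 1, last_row_in_parquet)
first_file, last_file = np.searchsorted(file_offsets, [first_row, last_row], side="right")
print(first_file, last_file)  # 1 2 -> the request spans the second and third files
```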
@@ -248,17 +249,17 @@ def query_truncated_binary(self, offset: int, length: int) -> tuple[pa.Table, list[str]]:

```python
                pq.ParquetFile(
```

> **Comment from the author** (on `pq.ParquetFile(`): pq.ParquetFile can accept a …

```python
                    HTTPFile(
                        self.httpfs,
                        url,
                        f["url"],
                        session=self.httpfs_session,
                        size=size,
                        size=f["size"],
                        loop=self.httpfs.loop,
                        cache_type=None,
                        **self.httpfs.kwargs,
                    ),
                    metadata=pq.read_metadata(metadata_path),
                    metadata=pq.read_metadata(self.metadata_dir / f["parquet_metadata_subpath"]),
                    pre_buffer=True,
                )
                for url, metadata_path, size in zip(urls, metadata_paths, num_bytes)
                for f in files_to_scan
            ]

        with StepProfiler(
```
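The comment above notes that `pq.ParquetFile` accepts a file-like object, which is what lets the index hand it an fsspec `HTTPFile` plus locally cached metadata and `pre_buffer=True`. A hedged sketch of that pattern; the URL and metadata path are placeholders, and `fsspec.filesystem("http")` stands in for the preconfigured `HTTPFileSystem` used by the service:

```python
# Sketch: open a remote parquet file over HTTP while reusing locally cached
# metadata, so only the requested row groups are fetched. Paths are placeholders.
import fsspec
import pyarrow.parquet as pq

def open_remote_parquet(url: str, local_metadata_path: str) -> pq.ParquetFile:
    httpfs = fsspec.filesystem("http")
    metadata = pq.read_metadata(local_metadata_path)  # skip fetching the footer remotely
    return pq.ParquetFile(httpfs.open(url), metadata=metadata, pre_buffer=True)

# e.g. (hypothetical):
# parquet_file = open_remote_parquet("https://example.com/0000.parquet", "/cache/0000.metadata")
# table = parquet_file.read_row_group(0, columns=["text"])
```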
@@ -272,7 +273,7 @@ def query_truncated_binary(self, offset: int, length: int) -> tuple[pa.Table, list[str]]:

```python
                ]
            )
            row_group_readers = [
                RowGroupReader(parquet_file=parquet_file, group_id=group_id, features=self.features)
                RowGroupReader(parquet_file=parquet_file, group_id=group_id, schema=self.features.arrow_schema)
                for parquet_file in parquet_files
                for group_id in range(parquet_file.metadata.num_row_groups)
            ]
```
@@ -290,6 +291,28 @@ def query_truncated_binary(self, offset: int, length: int) -> tuple[pa.Table, list[str]]:

```python
                row_group_offsets, [first_row, last_row], side="right"
            )

        all_columns = set(self.features)
        binary_columns = set(column for column, feature in self.features.items() if feature == Value("binary"))
        if binary_columns:
            pa_table, truncated_columns = self._read_with_binary(
                row_group_readers, first_row_group_id, last_row_group_id, all_columns, binary_columns
            )
        else:
            pa_table, truncated_columns = self._read_without_binary(
```

> **Comment from the author** (on the `else` branch): Theoretically we could remove this branch as well; at most there would be negligible performance overhead, but trying to keep the old behavior.

```python
                row_group_readers, first_row_group_id, last_row_group_id
            )

        first_row_in_pa_table = row_group_offsets[first_row_group_id - 1] if first_row_group_id > 0 else 0
        return pa_table.slice(parquet_offset - first_row_in_pa_table, length), truncated_columns

    def _read_with_binary(
        self,
        row_group_readers: list[RowGroupReader],
        first_row_group_id: int,
        last_row_group_id: int,
        all_columns: set[str],
        binary_columns: set[str],
    ) -> tuple[pa.Table, list[str]]:
        with StepProfiler(
            method="parquet_index_with_metadata.row_groups_size_check_truncated_binary",
            step="check if the rows can fit in memory",
```
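Whichever branch runs, the final step concatenates the row groups that cover the request and slices the result down to the requested window. A small sketch with two synthetic row groups:

```python
# Sketch: concatenate the covered row groups, then slice to the requested window.
import pyarrow as pa

# two row groups of 100 rows each; the request is rows 150..169 of this file
row_group_tables = [pa.table({"x": list(range(0, 100))}), pa.table({"x": list(range(100, 200))})]
row_group_offsets = [100, 200]
parquet_offset, length = 150, 20
first_row_group_id, last_row_group_id = 1, 1  # found via np.searchsorted in the real code

pa_table = pa.concat_tables(row_group_tables[first_row_group_id : last_row_group_id + 1])
first_row_in_pa_table = row_group_offsets[first_row_group_id - 1] if first_row_group_id > 0 else 0
window = pa_table.slice(parquet_offset - first_row_in_pa_table, length)
print(window.num_rows, window["x"][0].as_py())  # 20 150
```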
@@ -329,100 +352,21 @@ def query_truncated_binary(self, offset: int, length: int) -> tuple[pa.Table, list[str]]:

```python
                columns = list(self.features.keys())
                truncated_columns: set[str] = set()
                for i in range(first_row_group_id, last_row_group_id + 1):
                    rg_pa_table, rg_truncated_columns = row_group_readers[i].read_truncated_binary(
                        columns, max_binary_length=max_binary_length
                    rg_pa_table = row_group_readers[i].read(columns)
                    rg_pa_table, rg_truncated_columns = truncate_binary_columns(
                        rg_pa_table, max_binary_length, self.features
                    )
                    pa_tables.append(rg_pa_table)
                    truncated_columns |= set(rg_truncated_columns)
                pa_table = pa.concat_tables(pa_tables)
            except ArrowInvalid as err:
                raise SchemaMismatchError("Parquet files have different schema.", err)
        first_row_in_pa_table = row_group_offsets[first_row_group_id - 1] if first_row_group_id > 0 else 0
        return pa_table.slice(parquet_offset - first_row_in_pa_table, length), list(truncated_columns)

    def query(self, offset: int, length: int) -> pa.Table:
        """Query the parquet files

        Note that this implementation will always read at least one row group, to get the list of columns and always
        have the same schema, even if the requested rows are invalid (out of range).

        Args:
            offset (`int`): The first row to read.
            length (`int`): The number of rows to read.

        Raises:
            [`TooBigRows`]: if the arrow data from the parquet row groups is bigger than max_arrow_data_in_memory

        Returns:
            `pa.Table`: The requested rows.
        """
        with StepProfiler(
            method="parquet_index_with_metadata.query", step="get the parquet files than contain the requested rows"
        ):
            parquet_file_offsets = np.cumsum(self.num_rows)

            last_row_in_parquet = parquet_file_offsets[-1] - 1
            first_row = min(offset, last_row_in_parquet)
            last_row = min(offset + length - 1, last_row_in_parquet)
            first_parquet_file_id, last_parquet_file_id = np.searchsorted(
                parquet_file_offsets, [first_row, last_row], side="right"
            )
            parquet_offset = (
                offset - parquet_file_offsets[first_parquet_file_id - 1] if first_parquet_file_id > 0 else offset
            )
            urls = self.parquet_files_urls[first_parquet_file_id : last_parquet_file_id + 1]  # noqa: E203
            metadata_paths = self.metadata_paths[first_parquet_file_id : last_parquet_file_id + 1]  # noqa: E203
            num_bytes = self.num_bytes[first_parquet_file_id : last_parquet_file_id + 1]  # noqa: E203

        with StepProfiler(
            method="parquet_index_with_metadata.query", step="load the remote parquet files using metadata from disk"
        ):
            parquet_files = [
                pq.ParquetFile(
                    HTTPFile(
                        self.httpfs,
                        url,
                        session=self.httpfs_session,
                        size=size,
                        loop=self.httpfs.loop,
                        cache_type=None,
                        **self.httpfs.kwargs,
                    ),
                    metadata=pq.read_metadata(metadata_path),
                    pre_buffer=True,
                )
                for url, metadata_path, size in zip(urls, metadata_paths, num_bytes)
            ]

        with StepProfiler(
            method="parquet_index_with_metadata.query", step="get the row groups than contain the requested rows"
        ):
            row_group_offsets = np.cumsum(
                [
                    parquet_file.metadata.row_group(group_id).num_rows
                    for parquet_file in parquet_files
                    for group_id in range(parquet_file.metadata.num_row_groups)
                ]
            )
            row_group_readers = [
                RowGroupReader(parquet_file=parquet_file, group_id=group_id, features=self.features)
                for parquet_file in parquet_files
                for group_id in range(parquet_file.metadata.num_row_groups)
            ]

            if len(row_group_offsets) == 0 or row_group_offsets[-1] == 0:  # if the dataset is empty
                if offset < 0:
                    raise IndexError("Offset must be non-negative")
                return cast_table_to_schema(parquet_files[0].read(), self.features.arrow_schema)

            last_row_in_parquet = row_group_offsets[-1] - 1
            first_row = min(parquet_offset, last_row_in_parquet)
            last_row = min(parquet_offset + length - 1, last_row_in_parquet)

            first_row_group_id, last_row_group_id = np.searchsorted(
                row_group_offsets, [first_row, last_row], side="right"
            )
        return pa_table, list(truncated_columns)

    def _read_without_binary(
        self, row_group_readers: list[RowGroupReader], first_row_group_id: int, last_row_group_id: int
    ) -> tuple[pa.Table, list[str]]:
        with StepProfiler(
            method="parquet_index_with_metadata.row_groups_size_check", step="check if the rows can fit in memory"
        ):
```
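The split between `_read_with_binary` and `_read_without_binary` is driven by whether the dataset's `Features` contain any binary columns; a minimal sketch of that check:

```python
# Sketch: detect binary columns in a datasets Features object, mirroring the
# `feature == Value("binary")` comparison used in query().
from datasets import Features, Value

features = Features({"text": Value("string"), "image_bytes": Value("binary")})
binary_columns = {column for column, feature in features.items() if feature == Value("binary")}
print(binary_columns)  # {'image_bytes'}
```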
@@ -443,8 +387,8 @@ def query(self, offset: int, length: int) -> pa.Table:

```python
            )
        except ArrowInvalid as err:
            raise SchemaMismatchError("Parquet files have different schema.", err)
        first_row_in_pa_table = row_group_offsets[first_row_group_id - 1] if first_row_group_id > 0 else 0
        return pa_table.slice(parquet_offset - first_row_in_pa_table, length)

        return pa_table, []

    @staticmethod
    def from_parquet_metadata_items(
```
@@ -458,40 +402,31 @@ def from_parquet_metadata_items(

```python
            raise EmptyParquetMetadataError("No parquet files found.")

        partial = parquet_export_is_partial(parquet_file_metadata_items[0]["url"])
        metadata_dir = Path(parquet_metadata_directory)

        with StepProfiler(
            method="parquet_index_with_metadata.from_parquet_metadata_items",
            step="get the index from parquet metadata",
        ):
            try:
                parquet_files_metadata = sorted(
                    parquet_file_metadata_items, key=lambda parquet_file_metadata: parquet_file_metadata["filename"]
                )
                parquet_files_urls = [parquet_file_metadata["url"] for parquet_file_metadata in parquet_files_metadata]
                metadata_paths = [
                    os.path.join(parquet_metadata_directory, parquet_file_metadata["parquet_metadata_subpath"])
                    for parquet_file_metadata in parquet_files_metadata
                ]
                num_bytes = [parquet_file_metadata["size"] for parquet_file_metadata in parquet_files_metadata]
                num_rows = [parquet_file_metadata["num_rows"] for parquet_file_metadata in parquet_files_metadata]
                files = sorted(parquet_file_metadata_items, key=lambda f: f["filename"])
            except Exception as e:
                raise ParquetResponseFormatError(f"Could not parse the list of parquet files: {e}") from e

        with StepProfiler(
            method="parquet_index_with_metadata.from_parquet_metadata_items", step="get the dataset's features"
        ):
            if features is None:  # config-parquet version<6 didn't have features
                features = Features.from_arrow_schema(pq.read_schema(metadata_paths[0]))
                first_arrow_schema = pq.read_schema(metadata_dir / files[0]["parquet_metadata_subpath"])
                features = Features.from_arrow_schema(first_arrow_schema)

        return ParquetIndexWithMetadata(
            files=files,
            features=features,
            parquet_files_urls=parquet_files_urls,
            metadata_paths=metadata_paths,
            num_bytes=num_bytes,
            num_rows=num_rows,
            httpfs=httpfs,
            max_arrow_data_in_memory=max_arrow_data_in_memory,
            partial=partial,
            metadata_dir=metadata_dir,
        )
```
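When the config-parquet response is too old to include `features` (version < 6), they are rebuilt from the first file's cached arrow schema with `Features.from_arrow_schema`; for example:

```python
# Sketch: recovering datasets Features from an arrow schema (invented columns).
import pyarrow as pa
from datasets import Features

arrow_schema = pa.schema({"text": pa.string(), "image_bytes": pa.binary()})
features = Features.from_arrow_schema(arrow_schema)
print(features)  # {'text': Value(...), 'image_bytes': Value(...)}
```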
@@ -551,28 +486,7 @@ def _init_parquet_index(

```python
    # note that this cache size is global for the class, not per instance
    @lru_cache(maxsize=1)
    def query(self, offset: int, length: int) -> pa.Table:
        """Query the parquet files

        Note that this implementation will always read at least one row group, to get the list of columns and always
        have the same schema, even if the requested rows are invalid (out of range).

        Args:
            offset (`int`): The first row to read.
            length (`int`): The number of rows to read.

        Returns:
            `pa.Table`: The requested rows.
        """
        logging.info(
            f"Query {type(self.parquet_index).__name__} for dataset={self.dataset}, config={self.config},"
            f" split={self.split}, offset={offset}, length={length}"
        )
        return self.parquet_index.query(offset=offset, length=length)

    # note that this cache size is global for the class, not per instance
    @lru_cache(maxsize=1)
    def query_truncated_binary(self, offset: int, length: int) -> tuple[pa.Table, list[str]]:
    def query(self, offset: int, length: int) -> tuple[pa.Table, list[str]]:
        """Query the parquet files

        Note that this implementation will always read at least one row group, to get the list of columns and always
```
@@ -590,4 +504,4 @@ def query_truncated_binary(self, offset: int, length: int) -> tuple[pa.Table, list[str]]:

```python
            f"Query {type(self.parquet_index).__name__} for dataset={self.dataset}, config={self.config},"
            f" split={self.split}, offset={offset}, length={length}, with truncated binary"
        )
        return self.parquet_index.query_truncated_binary(offset=offset, length=length)
        return self.parquet_index.query(offset=offset, length=length)
```
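As the `# note that this cache size is global for the class, not per instance` comment above says, `@lru_cache(maxsize=1)` keeps only the most recent `(self, offset, length)` result, so repeated requests for the same page skip the parquet read. A toy sketch of that behavior; `RowsIndexLike` and its data are hypothetical:

```python
# Toy sketch of lru_cache(maxsize=1) on a query method; the class is hypothetical.
from functools import lru_cache

class RowsIndexLike:
    def __init__(self, rows: list[int]) -> None:
        self.rows = rows
        self.reads = 0

    @lru_cache(maxsize=1)
    def query(self, offset: int, length: int) -> list[int]:
        self.reads += 1  # stands in for the expensive parquet read
        return self.rows[offset : offset + length]

index = RowsIndexLike(list(range(1000)))
index.query(0, 10)
index.query(0, 10)  # identical request: served from the cache
print(index.reads)  # 1
```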
> **Comment from the author:** I applied this change in the libviewer PR as well, but I'm pulling it out here to make the review process easier.