fix: replace deprecated ray parallelism arg with override_num_blocks #2876

Merged · 2 commits · Jun 27, 2024
4 changes: 2 additions & 2 deletions awswrangler/distributed/ray/modin/s3/_read_orc.py
@@ -26,7 +26,7 @@ def _read_orc_distributed(
     schema: pa.schema | None,
     columns: list[str] | None,
     use_threads: bool | int,
-    parallelism: int,
+    override_num_blocks: int,
     version_ids: dict[str, str] | None,
     s3_client: "S3Client" | None,
     s3_additional_kwargs: dict[str, Any] | None,
@@ -43,7 +43,7 @@ def _read_orc_distributed(
     )
     ray_dataset = read_datasource(
         datasource,
-        parallelism=parallelism,
+        override_num_blocks=override_num_blocks,
     )
     to_pandas_kwargs = _data_types.pyarrow2pandas_defaults(
         use_threads=use_threads,
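
For context on the rename: these modin readers pass the block count straight through to ray.data.read_datasource, where Ray 2.10 renamed parallelism to override_num_blocks. A minimal standalone sketch of the old vs. new spelling, assuming Ray >= 2.10 (the range datasource is used only for illustration, not what awswrangler reads):

import ray

ray.init(ignore_reinit_error=True)

# Old spelling: still accepted during Ray's deprecation window, but warns.
# ds = ray.data.range(1_000, parallelism=8)

# New spelling: explicitly override the number of output blocks.
ds = ray.data.range(1_000, override_num_blocks=8)
print(ds.materialize().num_blocks())  # typically 8, subject to Ray's block splitting
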
4 changes: 2 additions & 2 deletions awswrangler/distributed/ray/modin/s3/_read_parquet.py
@@ -34,7 +34,7 @@ def _read_parquet_distributed(
     columns: list[str] | None,
     coerce_int96_timestamp_unit: str | None,
     use_threads: bool | int,
-    parallelism: int,
+    override_num_blocks: int,
     version_ids: dict[str, str] | None,
     s3_client: "S3Client" | None,
     s3_additional_kwargs: dict[str, Any] | None,
@@ -60,7 +60,7 @@ def _read_parquet_distributed(
                 "dataset_kwargs": dataset_kwargs,
             },
         ),
-        parallelism=parallelism,
+        override_num_blocks=override_num_blocks,
     )
     return _to_modin(
         dataset=dataset,
4 changes: 2 additions & 2 deletions awswrangler/distributed/ray/modin/s3/_read_text.py
@@ -138,7 +138,7 @@ def _read_text_distributed(
     s3_additional_kwargs: dict[str, str] | None,
     dataset: bool,
     ignore_index: bool,
-    parallelism: int,
+    override_num_blocks: int,
     version_ids: dict[str, str] | None,
     pandas_kwargs: dict[str, Any],
 ) -> pd.DataFrame:
@@ -172,6 +172,6 @@ def _read_text_distributed(
             meta_provider=FastFileMetadataProvider(),
             **configuration,
         ),
-        parallelism=parallelism,
+        override_num_blocks=override_num_blocks,
     )
     return _to_modin(dataset=ray_dataset, ignore_index=ignore_index)
18 changes: 18 additions & 0 deletions awswrangler/s3/_read.py
@@ -29,6 +29,7 @@
 from awswrangler.catalog._utils import _catalog_id
 from awswrangler.distributed.ray import ray_get
 from awswrangler.s3._list import _path2list, _prefix_cleanup
+from awswrangler.typing import RaySettings
 
 if TYPE_CHECKING:
     from mypy_boto3_glue.type_defs import GetTableResponseTypeDef
@@ -377,3 +378,20 @@ def _get_paths_for_glue_table(
     )
 
     return paths, path_root, res
+
+
+def _get_num_output_blocks(
+    ray_args: RaySettings | None = None,
+) -> int:
+    ray_args = ray_args or {}
+    parallelism = ray_args.get("parallelism", -1)
+    override_num_blocks = ray_args.get("override_num_blocks")
+    if parallelism != -1:
+        _logger.warning(
+            "The argument ``parallelism`` is deprecated and will be removed in the next major release. "
+            "Please specify ``override_num_blocks`` instead."
+        )
+    elif override_num_blocks is not None:
+        parallelism = override_num_blocks
+    return parallelism

Review comment on the line `if parallelism != -1:`

Contributor: Shouldn't it raise an error instead of a warning? override_num_blocks does not accept -1, or does it? If not, then it would raise an error too.

@kukushking (Contributor, Author), Jun 27, 2024: No, this is a backwards-compatibility fix, same as in the Ray code. If parallelism is passed and it's != -1, it will be routed to override_num_blocks. We will remove the warning and stop supporting the parallelism arg in the next major release.
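
A standalone sketch to sanity-check the shim's routing (hypothetical, not part of the PR's test suite; the function body mirrors the helper above, and the logger setup and shortened warning text are illustrative):

import logging

logging.basicConfig(level=logging.WARNING)
_logger = logging.getLogger("awswrangler")


def _get_num_output_blocks(ray_args=None):
    # Same logic as the helper added in this diff, inlined so the sketch runs standalone.
    ray_args = ray_args or {}
    parallelism = ray_args.get("parallelism", -1)
    override_num_blocks = ray_args.get("override_num_blocks")
    if parallelism != -1:
        _logger.warning("``parallelism`` is deprecated; use ``override_num_blocks``.")
    elif override_num_blocks is not None:
        parallelism = override_num_blocks
    return parallelism


print(_get_num_output_blocks({"parallelism": 1000}))          # 1000, with a deprecation warning
print(_get_num_output_blocks({"override_num_blocks": 1000}))  # 1000, no warning
print(_get_num_output_blocks({}))                             # -1, Ray's auto-detect sentinel
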
7 changes: 3 additions & 4 deletions awswrangler/s3/_read_orc.py
@@ -28,6 +28,7 @@
     _apply_partition_filter,
     _check_version_id,
     _extract_partitions_dtypes_from_table_details,
+    _get_num_output_blocks,
     _get_path_ignore_suffix,
     _get_path_root,
     _get_paths_for_glue_table,
@@ -137,7 +138,7 @@ def _read_orc(
     schema: pa.schema | None,
     columns: list[str] | None,
     use_threads: bool | int,
-    parallelism: int,
+    override_num_blocks: int,
     version_ids: dict[str, str] | None,
     s3_client: "S3Client" | None,
     s3_additional_kwargs: dict[str, Any] | None,
@@ -283,8 +284,6 @@ def read_orc(
     >>> df = wr.s3.read_orc(path, dataset=True, partition_filter=my_filter)
 
     """
-    ray_args = ray_args if ray_args else {}
-
     s3_client = _utils.client(service_name="s3", session=boto3_session)
     paths: list[str] = _path2list(
         path=path,
@@ -330,7 +329,7 @@ def read_orc(
         schema=schema,
         columns=columns,
         use_threads=use_threads,
-        parallelism=ray_args.get("parallelism", -1),
+        override_num_blocks=_get_num_output_blocks(ray_args),
         s3_client=s3_client,
         s3_additional_kwargs=s3_additional_kwargs,
         arrow_kwargs=arrow_kwargs,
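
A hedged usage sketch of the public API after this change (the bucket path and block count are placeholders, and a Ray/Modin-enabled install of awswrangler is assumed):

import awswrangler as wr

# Preferred: the new key is forwarded through _get_num_output_blocks().
df = wr.s3.read_orc(
    path="s3://my-bucket/orc/",  # placeholder path
    dataset=True,
    ray_args={"override_num_blocks": 64},
)

# Legacy: still accepted for now, but logs the deprecation warning shown above.
df = wr.s3.read_orc(
    path="s3://my-bucket/orc/",
    dataset=True,
    ray_args={"parallelism": 64},
)
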
5 changes: 3 additions & 2 deletions awswrangler/s3/_read_parquet.py
@@ -34,6 +34,7 @@
     _apply_partition_filter,
     _check_version_id,
     _extract_partitions_dtypes_from_table_details,
+    _get_num_output_blocks,
     _get_path_ignore_suffix,
     _get_path_root,
     _get_paths_for_glue_table,
@@ -285,7 +286,7 @@ def _read_parquet(
     columns: list[str] | None,
     coerce_int96_timestamp_unit: str | None,
     use_threads: bool | int,
-    parallelism: int,
+    override_num_blocks: int,
     version_ids: dict[str, str] | None,
     s3_client: "S3Client" | None,
     s3_additional_kwargs: dict[str, Any] | None,
@@ -562,7 +563,7 @@ def read_parquet(
         columns=columns,
         coerce_int96_timestamp_unit=coerce_int96_timestamp_unit,
         use_threads=use_threads,
-        parallelism=ray_args.get("parallelism", -1),
+        override_num_blocks=_get_num_output_blocks(ray_args),
         s3_client=s3_client,
         s3_additional_kwargs=s3_additional_kwargs,
         arrow_kwargs=arrow_kwargs,
2 changes: 1 addition & 1 deletion awswrangler/s3/_read_parquet.pyi
@@ -28,7 +28,7 @@ def _read_parquet(
     columns: list[str] | None,
     coerce_int96_timestamp_unit: str | None,
     use_threads: bool | int,
-    parallelism: int,
+    override_num_blocks: int,
     version_ids: dict[str, str] | None,
     s3_client: "S3Client" | None,
     s3_additional_kwargs: dict[str, Any] | None,
6 changes: 3 additions & 3 deletions awswrangler/s3/_read_text.py
@@ -19,6 +19,7 @@
 from awswrangler.s3._read import (
     _apply_partition_filter,
     _check_version_id,
+    _get_num_output_blocks,
     _get_path_ignore_suffix,
     _get_path_root,
     _union,
@@ -52,7 +53,7 @@ def _read_text(
     s3_additional_kwargs: dict[str, str] | None,
     dataset: bool,
     ignore_index: bool,
-    parallelism: int,
+    override_num_blocks: int,
     version_ids: dict[str, str] | None,
     pandas_kwargs: dict[str, Any],
 ) -> pd.DataFrame:
@@ -131,7 +132,6 @@ def _read_text_format(
         **args,
     )
 
-    ray_args = ray_args if ray_args else {}
     return _read_text(
         read_format,
         paths=paths,
@@ -141,7 +141,7 @@ def _read_text_format(
         s3_additional_kwargs=s3_additional_kwargs,
         dataset=dataset,
         ignore_index=ignore_index,
-        parallelism=ray_args.get("parallelism", -1),
+        override_num_blocks=_get_num_output_blocks(ray_args),
         version_ids=version_ids,
         pandas_kwargs=pandas_kwargs,
     )
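
The text readers route through the same helper via _read_text_format. A sketch, assuming a Ray-enabled install (the path is a placeholder):

import awswrangler as wr

# read_csv, read_json, and read_fwf all funnel into _read_text_format,
# so the same ray_args key applies across the text readers.
df = wr.s3.read_csv(
    path="s3://my-bucket/csv/",  # placeholder path
    dataset=True,
    ray_args={"override_num_blocks": 32},
)
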
2 changes: 1 addition & 1 deletion awswrangler/s3/_read_text.pyi
@@ -19,7 +19,7 @@ def _read_text(
     s3_additional_kwargs: dict[str, str] | None,
     dataset: bool,
     ignore_index: bool,
-    parallelism: int,
+    override_num_blocks: int,
     version_ids: dict[str, str] | None,
     pandas_kwargs: dict[str, Any],
 ) -> pd.DataFrame | Iterator[pd.DataFrame]: ...
7 changes: 7 additions & 0 deletions awswrangler/typing.py
@@ -231,6 +231,13 @@ class RaySettings(TypedDict):
     Parallelism may be limited by the number of files of the dataset.
     Auto-detect by default.
     """
+    override_num_blocks: NotRequired[int]
+    """
+    Override the number of output blocks from all read tasks.
+    By default, the number of output blocks is dynamically decided based on
+    input data size and available resources. You shouldn't manually set this
+    value in most cases.
+    """
 
 
 class RayReadParquetSettings(RaySettings):
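
Since RaySettings is a TypedDict with NotRequired keys, both spellings type-check during the deprecation window; a small sketch:

from awswrangler.typing import RaySettings

# New key: preferred.
settings: RaySettings = {"override_num_blocks": 64}

# Old key: still declared on the TypedDict, so this type-checks too,
# though it triggers the runtime deprecation warning at call time.
legacy: RaySettings = {"parallelism": 64}
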
2 changes: 1 addition & 1 deletion tests/glue_scripts/ray_read_small_parquet.py
@@ -5,4 +5,4 @@
 import awswrangler as wr
 
 paths = wr.s3.list_objects(f"s3://{os.environ['data-gen-bucket']}/parquet/small/partitioned/")
-ray.data.read_parquet_bulk(paths=paths, parallelism=1000).to_modin()
+ray.data.read_parquet_bulk(paths=paths, override_num_blocks=1000).to_modin()
2 changes: 1 addition & 1 deletion tests/glue_scripts/wrangler_read_small_parquet.py
@@ -4,5 +4,5 @@
 
 wr.s3.read_parquet(
     path=f"s3://{os.environ['data-gen-bucket']}/parquet/small/partitioned/",
-    ray_args={"parallelism": 1000, "bulk_read": True},
+    ray_args={"override_num_blocks": 1000, "bulk_read": True},
 )
2 changes: 1 addition & 1 deletion tests/glue_scripts/wrangler_write_partitioned_parquet.py
@@ -4,7 +4,7 @@
 
 df = wr.s3.read_parquet(
     path=f"s3://{os.environ['data-gen-bucket']}/parquet/medium/partitioned/",
-    ray_args={"parallelism": 1000},
+    ray_args={"override_num_blocks": 1000},
 )
 
 wr.s3.to_parquet(