Draft: Add @distribute decorator for parallelizing functions
#123
base: master
Conversation
**Motivation.** Although EvoTorch has had a functional API for a while, it did not have easy-to-use parallelization capabilities for distributing computations across multiple devices. When writing evolutionary search code using the functional API, the only way to evaluate solutions in parallel across multiple devices was to combine the functional-paradigm code with the object-oriented-paradigm code (by instantiating a `Problem` object with multiple actors and then transforming that `Problem` object into a functional evaluator using the method `make_callable_evaluator`). This approach was both limiting (it had only the parallelization of solution evaluation in mind) and cumbersome (it forced the programmer to mix the usages of two APIs).

**The newly introduced feature.** This commit introduces a general-purpose decorator, `evotorch.decorators.distribute`, which can take a function and transform it into its parallelized counterpart. Like a `Problem` object, the `distribute` decorator can be configured in terms of the number of actors (`num_actors`) and the number of GPUs visible to each actor (`num_gpus_per_actor`). Alternatively, an explicit list of devices can be given (e.g. `devices=["cuda:0", "cuda:1"]`).

**How does it work?** Upon being called for the first time, the parallelized function will create ray actors. When it receives its arguments, the parallelized function will split those arguments into chunks, send each chunk to an actor, apply the wrapped function on the chunks remotely in parallel on the devices associated with those actors, and finally collect and combine the results.
The decorators `on_device`, `on_cuda`, and `on_aux_device` are re-designed so that they now do functional transformation (i.e. they actually do the moving of the tensors to the requested device), instead of merely marking the decorated functions. This is compatible with the newly introduced `distribute` decorator, which also does functional transformation instead of just marking the decorated function.
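For illustration, a minimal usage sketch (the objective function below is hypothetical; only the decorator and its parameters come from this commit):

```python
import torch

from evotorch.decorators import distribute


# With two actors, each call splits `solutions` into two chunks along the
# leftmost dimension, evaluates the chunks remotely in parallel, and
# concatenates the per-chunk results back into a single tensor.
@distribute(num_actors=2, num_gpus_per_actor=1)
def evaluate(solutions: torch.Tensor) -> torch.Tensor:
    return torch.linalg.norm(solutions, dim=-1)


scores = evaluate(torch.randn(1000, 50))  # shape: (1000,)
```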
Walkthrough

Adds distributed chunking/stacking for tensors, TensorFrame, ObjectArray and containers; shallow-container device movers and device-counting helpers; and expanded decorators (vectorized, on_device, on_aux_device, on_cuda, distribute) enabling chunked, device-aware, actor-backed execution.
Sequence Diagram(s)

```mermaid
sequenceDiagram
autonumber
actor User
participant Wrapper as "@distribute / DecoratorForDistributingFunctions"
participant Splitter as "split_into_chunks"
participant Scheduler as "Dispatcher / Actors"
participant Worker as "Remote Actor (i)"
participant Gatherer as "stack_chunks"
User->>Wrapper: call distributed_fn(*args, **kwargs)
Wrapper->>Splitter: split_arguments_into_chunks(args, split_arguments, num_actors, chunk_size)
Splitter-->>Wrapper: chunks_per_actor (args_i, kwargs_i) for i=1..N
Wrapper->>Scheduler: dispatch chunks to actors (parallel)
Scheduler->>Worker: actor_i: run user function with (args_i, kwargs_i)
Worker-->>Scheduler: result_i
Scheduler-->>Wrapper: collect [result_1..result_N]
Wrapper->>Gatherer: stack_chunks([result_1..result_N])
Gatherer-->>User: reassembled_result
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
Codecov Report

❌ Patch coverage report — additional details and impacted files:

```
@@            Coverage Diff             @@
##           master     #123      +/-   ##
==========================================
- Coverage   75.43%   75.03%   -0.41%
==========================================
  Files          59       61       +2
  Lines        9556    10111     +555
==========================================
+ Hits         7209     7587     +378
- Misses       2347     2524     +177
```

☔ View full report in Codecov by Sentry.
Actionable comments posted: 4
🧹 Nitpick comments (6)
tests/test_decorators.py (1)
185-199: Fix: use `chunk_size` parameter in `test_distribute`
`chunk_size` is parametrized but unused (Ruff ARG001). Pass it to `distribute(...)` when not None to exercise chunking.

```diff
@@
-@pytest.mark.parametrize(
+@pytest.mark.parametrize(
     "decoration_form, distribute_config, chunk_size",
@@
 def test_distribute(decoration_form: bool, distribute_config: dict, chunk_size: int | None):
@@
     if decoration_form:
-        @distribute(**distribute_config)
+        @distribute(**({**distribute_config, **({"chunk_size": chunk_size} if chunk_size is not None else {})}))
         def distributed_f(x: torch.Tensor) -> torch.Tensor:
             return f(x)
     else:
-        distributed_f = distribute(f, **distribute_config)
+        distributed_f = distribute(
+            f, **({**distribute_config, **({"chunk_size": chunk_size} if chunk_size is not None else {})})
+        )
```

Also applies to: 198-228
src/evotorch/_distribute.py (3)
447-460: Docstring mismatch in `_all_are_non_scalars`

Docstring says “Return True if all tensors are scalars” but the function returns True when none are scalars.
Update docstring to reflect behavior (True if all tensors are non-scalars).
976-986: Harden zip with `strict=True`

Guard against silent truncation if lengths diverge.

```diff
-        for split_arg, arg in zip(self._iter_split_arguments(args), args):
+        for split_arg, arg in zip(self._iter_split_arguments(args), args, strict=True):
```
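For intuition, `strict=True` (available since Python 3.10) turns silent truncation into an error:

```python
list(zip([1, 2, 3], ["a", "b"]))
# [(1, 'a'), (2, 'b')]  -- the third element is silently dropped

list(zip([1, 2, 3], ["a", "b"], strict=True))
# ValueError: zip() argument 2 is shorter than argument 1
```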
101-111: Relax overly strict `chunk_size` check

Disallowing `chunk_size >= tensor_size` forbids a valid single-chunk case. Prefer allowing equality (1 chunk).

```diff
-    if chunk_size >= tensor_size:
+    if chunk_size > tensor_size:
         raise ValueError(
             "Cannot split the tensor into chunks because the given chunk size"
-            " is larger than or equal to the original tensor size."
+            " is larger than the original tensor size."
         )
```

src/evotorch/decorators.py (2)
1231-1238: Type hint: `devices` should accept device specs, not booleans

`distribute(..., devices: Sequence[bool] | None = None)` is incorrect. It should be `Sequence[torch.device | str] | None`.

```diff
 def distribute(
     *arguments,
     num_actors: str | int | None = None,
     chunk_size: int | None = None,
     num_gpus_per_actor: int | float | str | None = None,
-    devices: Sequence[bool] | None = None,
+    devices: Sequence[torch.device | str] | None = None,
 ) -> Callable:
```
276-283: Remove unused `noqa`

`# noqa: C901` is flagged unused (RUF100). Consider removing.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- src/evotorch/_distribute.py (1 hunks)
- src/evotorch/decorators.py (6 hunks)
- src/evotorch/tools/_shallow_containers.py (1 hunks)
- src/evotorch/tools/tensorframe.py (1 hunks)
- tests/test_decorators.py (6 hunks)
🧰 Additional context used
🧬 Code graph analysis (4)
src/evotorch/tools/_shallow_containers.py (2)
src/evotorch/tools/objectarray.py (1)
- ObjectArray (39-534)

src/evotorch/tools/tensorframe.py (7)

- TensorFrame (53-1175), device (482-498), has_enforced_device (421-426), is_read_only (759-763), columns (501-505), to (457-467), without_enforced_device (428-438)
tests/test_decorators.py (1)
src/evotorch/decorators.py (10)
- distribute (1231-1485), on_aux_device (629-748), on_cuda (751-866), on_device (276-626), pass_info (171-209), vectorized (212-273), decorator (135-137), decorator (515-624), decorator (738-743), decorator (1223-1226)
src/evotorch/_distribute.py (4)
src/evotorch/core.py (9)
- Problem (365-3410), SolutionBatch (3590-4600), num_actors (2136-2142), cat (4581-4600), is_remote (2161-2165), actors (2145-2151), is_main (2168-2173), actor_index (2154-2158), aux_device (1657-1692)

src/evotorch/tools/objectarray.py (1)

- ObjectArray (39-534)

src/evotorch/tools/tensorframe.py (7)

- TensorFrame (53-1175), device (482-498), to (457-467), as_tensor (304-360), is_read_only (759-763), get_read_only_view (765-769), vstack (930-959)

src/evotorch/tools/_shallow_containers.py (1)

- move_shallow_container_to_device (104-159)
src/evotorch/decorators.py (3)
src/evotorch/core.py (5)
- device (1648-1654), device (4485-4489), device (4969-4973), Problem (365-3410), num_actors (2136-2142)

src/evotorch/_distribute.py (4)

- _loosely_find_leftmost_dimension_size (771-812), split_arguments_into_chunks (339-426), stack_chunks (668-702), DecoratorForDistributingFunctions (1191-1221)

src/evotorch/tools/_shallow_containers.py (2)

- most_favored_device_among_arguments (236-281), move_shallow_container_to_device (104-159)
🪛 Ruff (0.14.0)
src/evotorch/tools/_shallow_containers.py
135-135: Avoid specifying long messages outside the exception class
(TRY003)
142-145: Avoid specifying long messages outside the exception class
(TRY003)
152-155: Avoid specifying long messages outside the exception class
(TRY003)
157-157: Avoid specifying long messages outside the exception class
(TRY003)
213-213: Avoid specifying long messages outside the exception class
(TRY003)
215-215: Avoid specifying long messages outside the exception class
(TRY003)
217-217: Avoid specifying long messages outside the exception class
(TRY003)
231-231: Avoid specifying long messages outside the exception class
(TRY003)
261-261: Avoid specifying long messages outside the exception class
(TRY003)
263-263: Avoid specifying long messages outside the exception class
(TRY003)
tests/test_decorators.py
198-198: Unused function argument: chunk_size
(ARG001)
src/evotorch/_distribute.py
73-73: Avoid specifying long messages outside the exception class
(TRY003)
76-76: Avoid specifying long messages outside the exception class
(TRY003)
86-86: Avoid specifying long messages outside the exception class
(TRY003)
91-91: Avoid specifying long messages outside the exception class
(TRY003)
102-105: Avoid specifying long messages outside the exception class
(TRY003)
123-123: Avoid specifying long messages outside the exception class
(TRY003)
175-178: Avoid specifying long messages outside the exception class
(TRY003)
247-247: Avoid specifying long messages outside the exception class
(TRY003)
251-254: Avoid specifying long messages outside the exception class
(TRY003)
322-322: Avoid specifying long messages outside the exception class
(TRY003)
334-334: Avoid specifying long messages outside the exception class
(TRY003)
379-379: Avoid specifying long messages outside the exception class
(TRY003)
381-381: Avoid specifying long messages outside the exception class
(TRY003)
385-385: Avoid specifying long messages outside the exception class
(TRY003)
389-389: Avoid specifying long messages outside the exception class
(TRY003)
401-401: Avoid specifying long messages outside the exception class
(TRY003)
479-479: Avoid specifying long messages outside the exception class
(TRY003)
508-508: Avoid specifying long messages outside the exception class
(TRY003)
510-510: Avoid specifying long messages outside the exception class
(TRY003)
516-516: Avoid specifying long messages outside the exception class
(TRY003)
518-518: Avoid specifying long messages outside the exception class
(TRY003)
527-527: Avoid specifying long messages outside the exception class
(TRY003)
546-546: Avoid specifying long messages outside the exception class
(TRY003)
573-573: Avoid specifying long messages outside the exception class
(TRY003)
575-575: Avoid specifying long messages outside the exception class
(TRY003)
596-596: Avoid specifying long messages outside the exception class
(TRY003)
598-598: Avoid specifying long messages outside the exception class
(TRY003)
625-625: Avoid specifying long messages outside the exception class
(TRY003)
627-627: Avoid specifying long messages outside the exception class
(TRY003)
648-648: Avoid specifying long messages outside the exception class
(TRY003)
650-650: Avoid specifying long messages outside the exception class
(TRY003)
657-657: Avoid specifying long messages outside the exception class
(TRY003)
699-702: Avoid specifying long messages outside the exception class
(TRY003)
751-751: Avoid specifying long messages outside the exception class
(TRY003)
799-799: Avoid specifying long messages outside the exception class
(TRY003)
802-802: Avoid specifying long messages outside the exception class
(TRY003)
804-804: Avoid specifying long messages outside the exception class
(TRY003)
812-812: Avoid specifying long messages outside the exception class
(TRY003)
816-816: Mutable class attributes should be annotated with typing.ClassVar
(RUF012)
929-934: Avoid specifying long messages outside the exception class
(TRY003)
946-949: Avoid specifying long messages outside the exception class
(TRY003)
971-971: Avoid specifying long messages outside the exception class
(TRY003)
976-976: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
1014-1014: Avoid specifying long messages outside the exception class
(TRY003)
1016-1016: Avoid specifying long messages outside the exception class
(TRY003)
1025-1027: Avoid specifying long messages outside the exception class
(TRY003)
1109-1111: Avoid specifying long messages outside the exception class
(TRY003)
1110-1110: Use explicit conversion flag
Replace with conversion flag
(RUF010)
1114-1114: Avoid specifying long messages outside the exception class
(TRY003)
1128-1128: Avoid specifying long messages outside the exception class
(TRY003)
1133-1137: Avoid specifying long messages outside the exception class
(TRY003)
1136-1136: Use explicit conversion flag
Replace with conversion flag
(RUF010)
1141-1141: Avoid specifying long messages outside the exception class
(TRY003)
1145-1145: Avoid specifying long messages outside the exception class
(TRY003)
1149-1153: Avoid specifying long messages outside the exception class
(TRY003)
1152-1152: Use explicit conversion flag
Replace with conversion flag
(RUF010)
1161-1165: Avoid specifying long messages outside the exception class
(TRY003)
1164-1164: Use explicit conversion flag
Replace with conversion flag
(RUF010)
src/evotorch/decorators.py
276-276: Unused noqa directive (non-enabled: C901)
Remove unused noqa directive
(RUF100)
481-491: Avoid specifying long messages outside the exception class
(TRY003)
523-526: Avoid specifying long messages outside the exception class
(TRY003)
529-537: Avoid specifying long messages outside the exception class
(TRY003)
565-567: Avoid specifying long messages outside the exception class
(TRY003)
734-734: Avoid specifying long messages outside the exception class
(TRY003)
859-859: Avoid specifying long messages outside the exception class
(TRY003)
954-957: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: test
🔇 Additional comments (3)
src/evotorch/tools/tensorframe.py (1)
420-427: Property addition looks good

Accurately reflects enforced-device state and integrates with callers (e.g., _shallow_containers._move_tensorframe).
src/evotorch/tools/_shallow_containers.py (2)
32-72: LGTM on TensorFrame move semantics

Respects enforced-device state and move_only_from_cpu; consistent with TensorFrame API.
186-234: LGTM: device counting for shallow containers

Accurately handles tensors, TensorFrame columns, ObjectArray, and rejects nested containers.
```python
# We are being extra careful here for ensuring that `split_arguments` is a sequence of booleans.
# We want to actively prevent unexpected behavior that could be caused by these mistakes:
# - providing argument indices instead of a sequence of booleans
# - providing one or more argument names as strings, instead of a sequence of booleans
_actual_split_arguments = []
for split_arg in split_arguments:
    if isinstance(split_arg, torch.Tensor) and (split_arg.ndim == 0):
        _actual_split_arguments.append(bool(split_arg.to(device="cpu")))
    elif isinstance(split_arg, (bool, np.bool_)):
        _actual_split_arguments.append(bool(split_arg))
    else:
        raise TypeError("`split_arguments` was expected to contain booleans only")
split_arguments = tuple(_actual_split_arguments)

if devices is None:
    if (num_actors is None) or (num_actors <= 1):
        raise ValueError(
            "The argument `devices` was received as None."
            " When `devices` is None, `num_actors` is expected as an integer that is at least 2."
            f" However, the given value of `num_actors` is {repr(num_actors)}."
        )
    devices = tuple()
else:
    if isinstance(devices, tuple) and hasattr(devices, "_fields"):
        raise ValueError("`devices` in the form of a named tuple is not supported")
    devices = tuple(torch.device(item) for item in devices)
    num_devices = len(devices)
    if num_devices == 0:
        raise ValueError("`devices` cannot be given as an empty sequence")
    if num_actors is None:
        num_actors = num_devices
    else:
        raise ValueError(
            "The argument `devices` was provided as a value other than None."
            " When `devices` is not None, `num_actors` is expected to be left as None."
            f" However, it was received as {repr(num_actors)}."
        )

    # We are given an explicit sequence of devices.
    # Therefore, we assume that the actors must be able to see all the accelerator devices,
    # and therefore override `num_gpus_per_actor` as "all".
    if num_gpus_per_actor is None:
        num_gpus_per_actor = "all"
    else:
        raise ValueError(
            "The argument `devices` was provided as a value other than None."
            " When `devices` is not None, `num_gpus_per_actor` is expected to be left as None."
            f" However, it was received as {repr(num_gpus_per_actor)}."
        )
```
Handle num_actors when given as string; fix devices typing
Comparing num_actors <= 1 fails when num_actors is a string (e.g., "num_gpus"). Also, devices is typed as Sequence[torch.device | str] here but DecoratorForDistributingFunctions uses Sequence[bool].
```diff
-    if devices is None:
-        if (num_actors is None) or (num_actors <= 1):
+    if devices is None:
+        if num_actors is None:
             raise ValueError(
                 "The argument `devices` was received as None."
                 " When `devices` is None, `num_actors` is expected as an integer that is at least 2."
-                f" However, the given value of `num_actors` is {repr(num_actors)}."
+                f" However, the given value of `num_actors` is {num_actors!r}."
             )
+        if isinstance(num_actors, int) and (num_actors <= 1):
+            raise ValueError("`num_actors` must be at least 2 when `devices` is None")
         devices = tuple()
@@
-    devices = tuple(torch.device(item) for item in devices)
+    devices = tuple(torch.device(item) for item in devices)
```

Also update `DecoratorForDistributingFunctions.__init__` to correct the type:

```diff
-    devices: Sequence[bool] | None = None,
+    devices: Sequence[torch.device | str] | None = None,
```

Committable suggestion skipped: line range outside the PR's diff.
🧰 Tools
🪛 Ruff (0.14.0)
1128-1128: Avoid specifying long messages outside the exception class
(TRY003)
1133-1137: Avoid specifying long messages outside the exception class
(TRY003)
1136-1136: Use explicit conversion flag
Replace with conversion flag
(RUF010)
1141-1141: Avoid specifying long messages outside the exception class
(TRY003)
1145-1145: Avoid specifying long messages outside the exception class
(TRY003)
1149-1153: Avoid specifying long messages outside the exception class
(TRY003)
1152-1152: Use explicit conversion flag
Replace with conversion flag
(RUF010)
1161-1165: Avoid specifying long messages outside the exception class
(TRY003)
1164-1164: Use explicit conversion flag
Replace with conversion flag
(RUF010)
🤖 Prompt for AI Agents
In src/evotorch/_distribute.py around lines 1117-1166, validate num_actors
before numeric comparison and fix type mismatches: explicitly check that
num_actors is an integer (e.g., isinstance(num_actors, (int, np.integer)))
before applying `<= 1` and raise a clear TypeError if it is a non-integer (e.g.,
a string like "num_gpus"); keep the existing behavior that when devices is None
num_actors must be an int >= 2. Also update the typing/usage consistency for
devices and split_arguments: ensure the function/type annotations accept devices
as Sequence[torch.device | str] (you already convert each item with
torch.device(item)) and ensure DecoratorForDistributingFunctions.__init__
signature is corrected so split_arguments is typed as Sequence[bool] and devices
as Sequence[torch.device | str] to match this implementation.
```python
def expects_ndim_decorator(fn: Callable):
    if hasattr(fn, "__evotorch_distribute__"):
        raise ValueError(
            "Cannot apply `@expects_ndim` or `@rowwise` on a function"
            " that was previously subjected to `@distribute`"
        )
```
Critical: expects_ndim uses wrong variable name; causes NameError
In `tensor_info.update()`, the condition references `expected_arg_ndims` (plural), which is undefined. Should use `expected_arg_ndim`.
```diff
-    for expected_arg_ndim, arg in zip(expected_ndims, args):
-        if (expected_arg_ndims is not None) and isinstance(arg, torch.Tensor):
+    for expected_arg_ndim, arg in zip(expected_ndims, args):
+        if (expected_arg_ndim is not None) and isinstance(arg, torch.Tensor):
```

Also applies to: 1026-1048
🧰 Tools
🪛 Ruff (0.14.0)
954-957: Avoid specifying long messages outside the exception class
(TRY003)
🤖 Prompt for AI Agents
In src/evotorch/decorators.py around lines 952 to 958 (and similarly at lines
1026 to 1048), the call to tensor_info.update() references the undefined
variable expected_arg_ndims (plural) causing a NameError; change those
references to the correct singular variable expected_arg_ndim so the condition
uses the defined value, and run tests to confirm no other occurrences of the
plural name remain.
Actionable comments posted: 1
♻️ Duplicate comments (4)
src/evotorch/_distribute.py (4)
529-537: Bug: incorrect concatenation for ObjectArray chunks
`list(chain(chunks))` yields a list of ObjectArray objects, not their elements. This stacks arrays-of-arrays. Use `chain.from_iterable` (or `chain(*chunks)`) to flatten:

```diff
-    resulting_stack = as_tensor(list(chain(chunks)), dtype=object)
+    resulting_stack = as_tensor(list(chain.from_iterable(chunks)), dtype=object)
```
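For intuition, in plain Python:

```python
from itertools import chain

chunks = [[1, 2], [3, 4]]
list(chain(chunks))                # [[1, 2], [3, 4]] -- a list of lists, not flattened
list(chain.from_iterable(chunks))  # [1, 2, 3, 4]     -- elements flattened as intended
```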
1086-1089: Bug: wrong attribute access for `__evotorch_pass_info__`

Accessing `self.wrap_info.function__evotorch_pass_info__` raises AttributeError. Should read from the wrapped function:

```diff
 if hasattr(self.wrap_info.function, "__evotorch_pass_info__"):
-    self.__evotorch_pass_info__ = self.wrap_info.function__evotorch_pass_info__
+    self.__evotorch_pass_info__ = self.wrap_info.function.__evotorch_pass_info__
```
1131-1137: Handle `num_actors` when given as string.

Comparing `num_actors <= 1` fails when `num_actors` is a string (e.g., `"num_gpus"`). Validate type before numeric comparison:

```diff
 if devices is None:
-    if (num_actors is None) or (num_actors <= 1):
+    if num_actors is None:
         raise ValueError(
             "The argument `devices` was received as None."
             " When `devices` is None, `num_actors` is expected as an integer that is at least 2."
-            f" However, the given value of `num_actors` is {repr(num_actors)}."
+            f" However, the given value of `num_actors` is {num_actors!r}."
         )
+    if isinstance(num_actors, int) and (num_actors <= 1):
+        raise ValueError("`num_actors` must be at least 2 when `devices` is None")
     devices = tuple()
```
1205-1205: Fix incorrect type annotation for `devices`.

The `devices` parameter is typed as `Sequence[bool]` but should be `Sequence[torch.device | str]` to match the implementation:

```diff
-    devices: Sequence[bool] | None = None,
+    devices: Sequence[torch.device | str] | None = None,
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- src/evotorch/_distribute.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/evotorch/_distribute.py (4)
src/evotorch/core.py (9)
- Problem (365-3410), SolutionBatch (3590-4600), num_actors (2136-2142), cat (4581-4600), is_remote (2161-2165), actors (2145-2151), is_main (2168-2173), actor_index (2154-2158), aux_device (1657-1692)

src/evotorch/tools/objectarray.py (1)

- ObjectArray (39-534)

src/evotorch/tools/tensorframe.py (7)

- TensorFrame (53-1175), device (482-498), to (457-467), pick (508-524), is_read_only (759-763), get_read_only_view (765-769), vstack (930-959)

src/evotorch/tools/_shallow_containers.py (1)

- move_shallow_container_to_device (104-159)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: test
```python
num_explicit_devices = len(self.__devices)
prepared_args = []

for split_arg, arg in zip(self._iter_split_arguments(args), args):
```
🛠️ Refactor suggestion | 🟠 Major
Add `strict=True` to `zip()` for safety.

The `zip()` call should explicitly set `strict=True` to catch length mismatches between `split_arguments` and `args`:
```diff
-        for split_arg, arg in zip(self._iter_split_arguments(args), args):
+        for split_arg, arg in zip(self._iter_split_arguments(args), args, strict=True):
```
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```diff
-for split_arg, arg in zip(self._iter_split_arguments(args), args):
+for split_arg, arg in zip(self._iter_split_arguments(args), args, strict=True):
```
🧰 Tools
🪛 Ruff (0.14.0)
976-976: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
🤖 Prompt for AI Agents
In src/evotorch/_distribute.py around line 976, the
zip(self._iter_split_arguments(args), args) call should be made strict to catch
length mismatches; change the call to zip(self._iter_split_arguments(args),
args, strict=True) (ensure project runs on Python 3.10+ where strict is
supported) and run tests to verify no unintended mismatches occur.
Actionable comments posted: 0
♻️ Duplicate comments (5)
src/evotorch/decorators.py (2)
1054-1055: Critical: Loop variable name should be singular (second occurrence).

Similar to line 976, the loop variable `expected_arg_ndims` should be `expected_arg_ndim` (singular) since it represents a single item.

Apply this diff:

```diff
-    for i_arg, (expected_arg_ndims, arg) in enumerate(zip(expected_ndims, args)):
-        if (expected_arg_ndims is None) or isinstance(arg, torch.Tensor):
+    for i_arg, (expected_arg_ndim, arg) in enumerate(zip(expected_ndims, args)):
+        if (expected_arg_ndim is None) or isinstance(arg, torch.Tensor):
```

Also update all subsequent references to use the singular form on lines 1074, etc.
976-977: Critical: Loop variable name should be singular.

The loop variable is named `expected_arg_ndims` (plural), but it represents a single item from `expected_ndims`. Both the loop variable declaration and its usage should use the singular form `expected_arg_ndim`.

Apply this diff:

```diff
-    for expected_arg_ndims, arg in zip(expected_ndims, args):
-        if (expected_arg_ndims is not None) and isinstance(arg, torch.Tensor):
+    for expected_arg_ndim, arg in zip(expected_ndims, args):
+        if (expected_arg_ndim is not None) and isinstance(arg, torch.Tensor):
```

src/evotorch/_distribute.py (3)
528-537: Critical: Incorrect ObjectArray concatenation.

`list(chain(chunks))` produces a list of ObjectArray objects rather than flattening their elements, resulting in an array-of-arrays instead of a flat array.

Apply this diff:

```diff
-    resulting_stack = as_tensor(list(chain(chunks)), dtype=object)
+    resulting_stack = as_tensor(list(chain.from_iterable(chunks)), dtype=object)
```
976-976: Add `strict=True` to zip for safety.

Without `strict=True`, length mismatches between `self._iter_split_arguments(args)` and `args` would silently truncate, potentially causing subtle bugs.

Apply this diff:

```diff
-        for split_arg, arg in zip(self._iter_split_arguments(args), args):
+        for split_arg, arg in zip(self._iter_split_arguments(args), args, strict=True):
```
1205-1205: Critical: Wrong type annotation for `devices` parameter.

The `devices` parameter is typed as `Sequence[bool]` but should be `Sequence[torch.device | str]` to match its usage in `_prepare_distributed_function` (line 1142) where items are converted via `torch.device(item)`.

Apply this diff:

```diff
-    devices: Sequence[bool] | None = None,
+    devices: Sequence[torch.device | str] | None = None,
```
🧹 Nitpick comments (4)
src/evotorch/_distribute.py (4)
531-537: Verify read-only ObjectArray handling.

The code preserves read-only status by checking all chunks (lines 531-537). This is correct, but consider if a single read-only chunk should make the result read-only (current behavior) or if all chunks must be read-only. Document the intended semantics.
816-817: Document thread-safety of global cache.

The `_Wrapped` class uses a global dictionary and lock for caching distributed functions. While the lock protects access, consider documenting that this cache is process-global and persists across multiple decorator invocations, which may have memory implications for long-running processes.
1177-1187: Cache invalidation strategy needed.

The global cache in `_Wrapped.functions` grows unbounded as different functions are wrapped. For long-running processes that dynamically create many distributed functions, this could lead to memory leaks. Consider adding a cache size limit or eviction policy.

Do you want me to suggest a simple LRU cache implementation? A minimal sketch follows below.
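For instance, a minimal sketch of such a bounded cache (the class name, key type, and `max_size` below are illustrative, not existing EvoTorch API):

```python
from collections import OrderedDict
from threading import Lock


class BoundedFunctionCache:
    """Tiny LRU-style cache sketch for wrapped functions."""

    def __init__(self, max_size: int = 128):
        self._data: OrderedDict = OrderedDict()
        self._lock = Lock()
        self._max_size = max_size

    def get_or_create(self, key, factory):
        with self._lock:
            if key in self._data:
                self._data.move_to_end(key)  # mark as most recently used
                return self._data[key]
            value = factory()  # build the distributed counterpart once
            self._data[key] = value
            if len(self._data) > self._max_size:
                self._data.popitem(last=False)  # evict the least recently used entry
            return value
```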
89-110: Add tests for tensor-splitting edge cases.

Cover when `tensor_size < num_actors` (expect one-element chunks) and when `chunk_size >= tensor_size` (expect a `ValueError`). A sketch of such tests follows below.
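A sketch of what such tests could look like — the import path and the signature `split_into_chunks(tensor, num_chunks, chunk_size=None)` are assumptions to adapt to the real helper in src/evotorch/_distribute.py:

```python
import pytest
import torch

# Hypothetical import and signature; adapt to the actual chunk-splitting helper.
from evotorch._distribute import split_into_chunks


def test_fewer_rows_than_actors():
    # 3 rows split for 5 actors: expect one-element chunks, no empty chunks.
    chunks = split_into_chunks(torch.arange(3), num_chunks=5)
    assert all(len(chunk) == 1 for chunk in chunks)


def test_chunk_size_not_smaller_than_tensor_size():
    # A chunk size >= the tensor size should be rejected.
    with pytest.raises(ValueError):
        split_into_chunks(torch.arange(4), num_chunks=2, chunk_size=4)
```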
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- src/evotorch/_distribute.py (1 hunks)
- src/evotorch/decorators.py (6 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
src/evotorch/_distribute.py (4)
src/evotorch/core.py (10)
- Problem (365-3410), SolutionBatch (3590-4600), num_actors (2136-2142), cat (4581-4600), is_remote (2161-2165), evaluate (2532-2571), actors (2145-2151), is_main (2168-2173), actor_index (2154-2158), aux_device (1657-1692)

src/evotorch/tools/objectarray.py (1)

- ObjectArray (39-534)

src/evotorch/tools/tensorframe.py (7)

- TensorFrame (53-1175), device (482-498), to (457-467), as_tensor (304-360), is_read_only (759-763), get_read_only_view (765-769), vstack (930-959)

src/evotorch/tools/_shallow_containers.py (1)

- move_shallow_container_to_device (104-159)
src/evotorch/decorators.py (2)
src/evotorch/_distribute.py (4)
- _loosely_find_leftmost_dimension_size (771-812), split_arguments_into_chunks (339-426), stack_chunks (668-702), DecoratorForDistributingFunctions (1191-1221)

src/evotorch/tools/_shallow_containers.py (2)

- most_favored_device_among_arguments (236-281), move_shallow_container_to_device (104-159)
🪛 Ruff (0.14.0)
src/evotorch/decorators.py
276-276: Unused noqa directive (non-enabled: C901)
Remove unused noqa directive
(RUF100)
481-491: Avoid specifying long messages outside the exception class
(TRY003)
523-526: Avoid specifying long messages outside the exception class
(TRY003)
529-537: Avoid specifying long messages outside the exception class
(TRY003)
565-567: Avoid specifying long messages outside the exception class
(TRY003)
734-734: Avoid specifying long messages outside the exception class
(TRY003)
859-859: Avoid specifying long messages outside the exception class
(TRY003)
954-957: Avoid specifying long messages outside the exception class
(TRY003)
976-976: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: test
🔇 Additional comments (4)
src/evotorch/decorators.py (2)
1231-1485: LGTM: Well-documented `distribute` decorator.

The decorator provides a clean interface for distributed execution with comprehensive documentation and examples. The delegation to `DecoratorForDistributingFunctions` follows a good separation of concerns.
276-502: Add tests for inline `on_device` transformation

No existing tests cover the inline forms (`on_device(fn, device=…)` and `on_device(fn, tuple_of_booleans, device=…)`). Add unit tests for:
- Valid cases: function-only and function+tuple selectors.
- Invalid cases: wrong tuple types/lengths, empty tuples to confirm TypeError is raised.
src/evotorch/_distribute.py (2)
1131-1137: LGTM: Proper validation for `num_actors`.

The validation now correctly handles string values (like `"num_gpus"`) by checking `isinstance(num_actors, Integral)` before numeric comparison, preventing TypeErrors.
911-927: Potential race condition in actor creation.

While the lock on line 919 protects `__parallelized`, the check on line 928 for `self.actors` happens outside the lock. If multiple threads call this method simultaneously, they might all pass line 920 before any completes actor creation, leading to redundant evaluations.

Consider moving the final check inside the lock:

```diff
 with self.__parallelization_lock:
     if not self.__parallelized:
         dummy_batch = SolutionBatch(self, popsize=1)
         self.evaluate(dummy_batch)
         self.__parallelized = True
+
+    if (self.actors is None) or (len(self.actors) < 2):
+        raise RuntimeError(
+            "Failed to create the distributed counterpart of the original function."
+            " Hint: this can happen if the arguments given to the `@distribute` decorator imply a non-distributed"
+            " environment, e.g., if one sets `num_actors='num_gpus'` when one has only 1 GPU,"
+            " or if one sets `num_actors` as an integer that is smaller than 2."
+        )
-if (self.actors is None) or (len(self.actors) < 2):
-    raise RuntimeError(
-        "Failed to create the distributed counterpart of the original function."
-        " Hint: this can happen if the arguments given to the `@distribute` decorator imply a non-distributed"
-        " environment, e.g., if one sets `num_actors='num_gpus'` when one has only 1 GPU,"
-        " or if one sets `num_actors` as an integer that is smaller than 2."
-    )
```

Likely an incorrect or invalid review comment.
Motivation.

Although EvoTorch has had a functional API for a while, this functional API did not have easy-to-use parallelization capabilities for distributing computations across multiple devices. When writing evolutionary search code using the functional API, the only way to evaluate solutions in parallel across multiple devices was to combine the functional-paradigm code with the object-oriented-paradigm code (by instantiating a `Problem` object with multiple actors and then transforming that `Problem` object into a functional evaluator using the method `make_callable_evaluator`). This approach was limiting (it had only the parallelization of solution evaluation in mind) and was cumbersome (forcing the programmer to mix the usages of two APIs).

The newly introduced feature.

This commit introduces a general-purpose decorator, `evotorch.decorators.distribute`, which can take a function and transform it into its parallelized counterpart. Like a `Problem` object, the `distribute` decorator can be configured in terms of number of actors (`num_actors`) and number of GPUs visible to each actor (`num_gpus_per_actor`). Alternatively, an explicit list of devices can be given (e.g. `devices=["cuda:0", "cuda:1"]`).

How does it work?

Upon being called for the first time, the parallelized function will create ray actors. When it receives its arguments, the parallelized function follows these steps:

1. Split the received arguments into chunks along their leftmost dimensions.
2. Send each chunk to an actor.
3. Apply the wrapped function on the chunks remotely, in parallel, on the devices associated with those actors.
4. Collect the per-chunk results and combine them into a single result.
Example.
A distributed function might look like this:
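Below is a minimal sketch of such a function (the body shown here is illustrative; the decorator configuration follows the description that comes next):

```python
import torch

from evotorch.decorators import distribute


@distribute(devices=["cuda:0", "cuda:1"])
def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # `x` and `y` share the same leftmost dimension size,
    # and the result preserves that leftmost dimension size.
    return x + y
```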
We assume that the function expects `x` and `y` to have the same leftmost dimension size, and that the function will return a resulting tensor with the same leftmost dimension size. Once called, this example distributed function will split `x` and `y` into two (along their leftmost dimensions), send the first halves to the first actor (the one using `cuda:0`) and the second halves to the second actor (the one using `cuda:1`), wait for the parallel computation, and collect and concatenate the resulting chunks into a single resulting tensor.

A distributed function can work with the following input argument and result types:

- `torch.Tensor`
- `evotorch.tools.ReadOnlyTensor`
- `evotorch.tools.TensorFrame`
- `evotorch.tools.ObjectArray`
- containers of `ReadOnlyTensor`s and/or `TensorFrame`s and/or `ObjectArray`s

Summary by CodeRabbit
New Features
Tests