diff --git a/src/evotorch/_distribute.py b/src/evotorch/_distribute.py new file mode 100644 index 0000000..d715420 --- /dev/null +++ b/src/evotorch/_distribute.py @@ -0,0 +1,1221 @@ +# Copyright 2025 NNAISENSE SA +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from collections.abc import Callable, Iterable, Mapping, Sequence +from itertools import chain +from numbers import Integral +from threading import Lock +from typing import Any, NamedTuple + +import numpy as np +import torch +from ray.util import ActorPool + +from .core import Problem, SolutionBatch +from .tools import ObjectArray, TensorFrame +from .tools._shallow_containers import move_shallow_container_to_device + +TensorLike = torch.Tensor | TensorFrame | ObjectArray + + +class _TensorSplittingResult(NamedTuple): + chunks: list[TensorLike] + original_size: int + + +def _split_tensor( + x: TensorLike, + num_actors: int, + *, + chunk_size: int | None = None, + expect_size: int | None = None, + target_device: str | torch.device | None = None, +) -> _TensorSplittingResult: + """ + Split a tensor, or a TensorFrame, or an ObjectArray into chunks. + + Args: + x: The tensor or ObjectArray or TensorFrame to be split into chunks. + num_actors: Number of remote actors, as an integer that is at least 2. + If `chunk_size` is not provided (i.e. left as None), the number + of chunks and the size of each chunk will be determined by this + `num_actors`. + chunk_size: The size of a chunk when splitting a tensor/array into + chunks. If this is provided, then this will be the main factor for + determining the chunk size and also the number of chunks. + Can be left as None if the number of chunks and the chunk size is + to be determined by `num_actors` instead. + expect_size: If provided, the leftmost dimension of the given tensor + or array or the number of rows of the given TensorFrame will be + compared to this given number. If the size of `x` does not match + `expect_size`, an error will be raised. + target_device: If provided, the chunks will be moved into this device + (except when `x` is an `ObjectArray` which will then stay on the + cpu regardless of the given `target_device`). + Returns: + A named tuple in which the attribute `chunks` stores the chunks in a + list, and `original_size` represents the size of `x` before the + splitting operation. + """ + if not isinstance(x, (torch.Tensor, TensorFrame, ObjectArray)): + raise TypeError(f"Expected a tensor or a TensorFrame or an ObjectArray, but got an instance of {type(x)}") + + if isinstance(x, torch.Tensor) and (x.ndim == 0): + raise ValueError("Cannot split a 0-dimensional tensor into chunks") + + if (target_device is not None) and isinstance(x, (torch.Tensor, TensorFrame)): + # If we are given a target device, we move the original `x` to that device, so that its chunks will be + # on that device as well. 
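# Note that `ObjectArray`s are intentionally excluded by the isinstance check above; they stay on the cpu,
# as promised in the docstring of this function.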
+ x = x.to(device=target_device) + + tensor_size = len(x) + + if (expect_size is not None) and (tensor_size != expect_size): + raise ValueError("While trying to split tensors into chunks, encountered incompatible tensor sizes") + + # Compute the chunk sizes + if chunk_size is None: + if tensor_size == 0: + raise ValueError("Cannot split a tensor whose leftmost dimension size is 0") + elif tensor_size < num_actors: + chunk_sizes = [1 for _ in range(tensor_size)] + else: + min_chunk_size = tensor_size // num_actors + remaining = tensor_size % num_actors + chunk_sizes = [min_chunk_size for _ in range(num_actors)] + for i in range(remaining): + chunk_sizes[i] += 1 + else: + if chunk_size >= tensor_size: + raise ValueError( + "Cannot split the tensor into chunks because the given chunk size" + " is larger than or equal to the original tensor size." + ) + min_num_chunks = tensor_size // chunk_size + last_chunk_size = tensor_size % chunk_size + chunk_sizes = [chunk_size for _ in range(min_num_chunks)] + if last_chunk_size > 0: + chunk_sizes.append(last_chunk_size) + + # Prepare the chunks + chunks = [] + i = 0 + j = 0 + for chunk_size in chunk_sizes: + j = i + chunk_size + if isinstance(x, (torch.Tensor, ObjectArray)): + chunk = x[i:j] + elif isinstance(x, TensorFrame): + chunk = x.pick[i:j, :] + else: + raise TypeError("Execution should not have reached this point. This is most probably a bug.") + chunks.append(chunk) + i = j + + return _TensorSplittingResult(chunks=chunks, original_size=tensor_size) + + +class _DictSplittingResult(NamedTuple): + chunks: list[dict[Any, TensorLike]] + original_size: int + + +def _split_dict( + x: Mapping[Any, TensorLike], + num_actors: int, + *, + chunk_size: int | None = None, + expect_size: int | None = None, + target_device: str | torch.device | None = None, +) -> _DictSplittingResult: + """ + Split the tensors/`TensorFrame`s/`ObjectArray`s in a dictionary-like object. + + Args: + x: A shallow (non-nested) dictionary-like object (that is an instance + of `collections.abc.Mapping`) that contains tensors and/or + `TensorFrame`s and/or `ObjectArray`s. + num_actors: Number of remote actors, as an integer that is at least 2. + If `chunk_size` is not provided (i.e. left as None), the number + of chunks and the size of each chunk will be determined by this + `num_actors`. + chunk_size: The size of a chunk when splitting a tensor/array into + chunks. If this is provided, then this will be the main factor for + determining the chunk size and also the number of chunks. + Can be left as None if the number of chunks and the chunk size is + to be determined by `num_actors` instead. + expect_size: If provided, the leftmost dimension of the contained + tensors or `ObjectArray`s or the number of rows of the contained + `TensorFrame`s will be compared to this given number. + If the tensor/array sizes within `x` do not match `expect_size`, an + error will be raised. + target_device: If provided, the chunks will be moved into this device + (except when an `ObjectArray` is encountered which will be kept + on the cpu regardless of the given `target_device`). + Returns: + A named tuple in which the attribute `chunks` stores a list of + dictionaries (the values in each dictionary being the chunks of + tensors/arrays), `original_size` represents the original tensor/array + size within `x` before the splitting operation. 
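For illustration, a minimal usage sketch (the dictionary below is a hypothetical example; both of its values
share the same leftmost size):

```python
import torch

data = {"a": torch.randn(10, 5), "b": torch.randn(10)}
chunks, original_size = _split_dict(data, num_actors=3)
# original_size == 10; chunks is a list of 3 dictionaries whose "a" and "b"
# values have leftmost sizes 4, 3 and 3 respectively.
```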
+ """ + + if len(x) == 0: + raise ValueError( + "Cannot split the tensor values into chunks within the given dictionary," + " because the given dictionary is empty" + ) + + dict_chunks: list[dict[Any, TensorLike]] | None = None + original_tensor_size: int | None = expect_size + num_chunks: int | None = None + for k, v in x.items(): + chunks, tensor_size = _split_tensor( + v, num_actors, chunk_size=chunk_size, expect_size=original_tensor_size, target_device=target_device + ) + + if original_tensor_size is None: + original_tensor_size = tensor_size + + if dict_chunks is None: + num_chunks = len(chunks) + dict_chunks = [{} for _ in range(num_chunks)] + + for i_chunk in range(num_chunks): + dict_chunks[i_chunk][k] = chunks[i_chunk] + + return _DictSplittingResult(chunks=dict_chunks, original_size=original_tensor_size) + + +class _SequenceSplittingResult(NamedTuple): + chunks: list[list[TensorLike] | tuple[TensorLike, ...]] + original_size: int + + +def _split_sequence( + x: Sequence[TensorLike], + num_actors: int, + *, + chunk_size: int | None = None, + expect_size: int | None = None, + target_device: str | torch.device | None = None, +) -> _SequenceSplittingResult: + """ + Split the tensors/`TensorFrame`s/`ObjectArray`s in a sequence. + + Args: + x: A shallow (non-nested) sequence (that is an instance of + `collections.abc.Sequence`) that contains tensors and/or + `TensorFrame`s and/or `ObjectArray`s. + num_actors: Number of remote actors, as an integer that is at least 2. + If `chunk_size` is not provided (i.e. left as None), the number + of chunks and the size of each chunk will be determined by this + `num_actors`. + chunk_size: The size of a chunk when splitting a tensor/array into + chunks. If this is provided, then this will be the main factor for + determining the chunk size and also the number of chunks. + Can be left as None if the number of chunks and the chunk size is + to be determined by `num_actors` instead. + expect_size: If provided, the leftmost dimension of the contained + tensors or `ObjectArray`s or the number of rows of the contained + `TensorFrame`s will be compared to this given number. + If the tensor/array sizes within `x` do not match `expect_size`, an + error will be raised. + target_device: If provided, the chunks will be moved into this device. + (except when an `ObjectArray` is encountered which will be kept + on the cpu regardless of the given `target_device`). + Returns: + A named tuple in which the attribute `chunks` is a list of sequences + (the items within each sequence being the chunks of tensors/arrays), + `original_size` represents the original tensor/array size within `x` + before the splitting operation. 
+ """ + result_must_be_tuple = False + if isinstance(x, tuple): + if hasattr(x, "_fields"): + raise TypeError("Named tuples are not supported") + result_must_be_tuple = True + + if len(x) == 0: + raise ValueError( + "Cannot split the tensor values into chunks within the given sequence," + " because the given sequence is empty" + ) + + sequence_chunks: list[list[TensorLike]] | None = None + original_tensor_size: int | None = expect_size + num_chunks: int | None + for v in x: + chunks, tensor_size = _split_tensor( + v, num_actors, chunk_size=chunk_size, expect_size=original_tensor_size, target_device=target_device + ) + + if original_tensor_size is None: + original_tensor_size = tensor_size + + if sequence_chunks is None: + num_chunks = len(chunks) + sequence_chunks = [[] for _ in range(num_chunks)] + + for i_chunk in range(num_chunks): + sequence_chunks[i_chunk].append(chunks[i_chunk]) + + if result_must_be_tuple: + sequence_chunks = [tuple(item) for item in sequence_chunks] + + return _SequenceSplittingResult(chunks=sequence_chunks, original_size=original_tensor_size) + + +def split_into_chunks( + x: TensorLike | Sequence[TensorLike] | Mapping[Any, TensorLike], + num_actors: int, + *, + chunk_size: int | None = None, + expect_size: int | None = None, + target_device: str | torch.device | None = None, +) -> _TensorSplittingResult | _DictSplittingResult | _SequenceSplittingResult: + """ + Split into chunks a tensor/ObjectArray/TensorFrame or a container of them. + + Args: + x: A tensor, or a `TensorFrame`, or an `ObjectArray`, or a shallow + (non-nested) dictionary-like container or a sequence containing one + or more tensor/`TensorFrame`/`ObjectArray`. This is the input that + is subject to splitting into chunks. + num_actors: Number of remote actors, as an integer that is at least 2. + If `chunk_size` is not provided (i.e. left as None), the number + of chunks and the size of each chunk will be determined by this + `num_actors`. + chunk_size: The size of a chunk when splitting a tensor/array into + chunks. If this is provided, then this will be the main factor for + determining the chunk size and also the number of chunks. + Can be left as None if the number of chunks and the chunk size is + to be determined by `num_actors` instead. + expect_size: If provided, the leftmost dimension of the contained + tensors or `ObjectArray`s or the number of rows of the contained + `TensorFrame`s will be compared to this given number. + If the tensor/array sizes within `x` do not match `expect_size`, an + error will be raised. + target_device: If provided, the chunks will be moved into this device. + (except when an `ObjectArray` is encountered which will be kept + on the cpu regardless of the given `target_device`). + Returns: + A named tuple in which `chunks` is a list containing the chunks of `x`, + `original_size` is the original size of `x`. 
+ """ + if expect_size is not None: + expect_size = int(expect_size) + if isinstance(x, (str, np.str_, bytes, bytearray)): + # Here, we actively prevent objects that are technically instances of collections.abc.Sequence + # but cannot contain any tensor/TensorFrame/ObjectArray + raise TypeError(f"Unsupported type: {type(x)}") + elif isinstance(x, (torch.Tensor, TensorFrame, ObjectArray)): + result = _split_tensor( + x, num_actors, chunk_size=chunk_size, expect_size=expect_size, target_device=target_device + ) + elif isinstance(x, Mapping): + result = _split_dict(x, num_actors, chunk_size=chunk_size, expect_size=expect_size, target_device=target_device) + elif isinstance(x, Sequence): + result = _split_sequence( + x, num_actors, chunk_size=chunk_size, expect_size=expect_size, target_device=target_device + ) + else: + raise TypeError(f"Unsupported type: {type(x)}") + + return result + + +def split_arguments_into_chunks( + args: Sequence, + split_arguments: Sequence[bool], + num_actors: int, + *, + chunk_size: int | None = None, + target_device: str | torch.device | None = None, +) -> list: + """ + Split the specified arguments within the given sequence into chunks. + + Splittable arguments are tensors, `TensorFrame`s, `ObjectArray`s, + or shallow dictionary-like containers or sequences consisting of + tensors and/or `TensorFrame`s and/or `ObjectArray`s. + + Args: + args: A sequence (e.g. list or tuple) of arguments. + split_arguments: A sequence of booleans. Within this sequence, + if the i-th element is True, then the i-th element of `args` + is subject to being split into chunks. On the other hand, + if the i-th element of `split_arguments` is False, then the + i-th element of `args` is going to be duplicated as it is, + instead of being split. + num_actors: Number of remote actors, as an integer that is at least 2. + If `chunk_size` is not provided (i.e. left as None), the number + of chunks and the size of each chunk will be determined by this + `num_actors`. + chunk_size: The size of a chunk when splitting a tensor/array into + chunks. If this is provided, then this will be the main factor for + determining the chunk size and also the number of chunks. + Can be left as None if the number of chunks and the chunk size is + to be determined by `num_actors` instead. + target_device: If provided, the chunks will be moved into this device. + (except when an `ObjectArray` is encountered which will be kept + on the cpu regardless of the given `target_device`). + Returns: + A list of argument chunks. This returned list has the same length + with `args`, but within it, each element is a list of chunks. + """ + if isinstance(args, tuple) and hasattr(args, "_fields"): + raise TypeError("`args` cannot be given in the form of a named tuple") + if not isinstance(args, Sequence): + raise TypeError(f"Expected `args` as a Sequence, but received it as an instance of {type(args)}") + if isinstance(args, (str, np.str_, bytes, bytearray)): + # Here, we actively prevent `args` from being given as instances of types that are technically + # sequences, but that cannot contain arguments. 
+ raise TypeError(f"Received `args` as an instance of {type(args)}, which is not supported") + + num_args = len(args) + if len(split_arguments) != num_args: + raise TypeError(f"Expected {len(split_arguments)} positional arguments, but got {num_args}") + + # Understand which arguments are subject to splitting, and which arguments are subject to duplication + arg_indices_to_split = [] + arg_indices_to_duplicate = [] + for i_arg, split_arg in enumerate(split_arguments): + if split_arg: + arg_indices_to_split.append(i_arg) + else: + arg_indices_to_duplicate.append(i_arg) + + if len(arg_indices_to_split) == 0: + raise ValueError("None of the positional arguments were marked for being split into chunks") + + # The following list is to store chunks for each argument + result = [None for _ in range(num_args)] + + # Loop over the arguments to split first + original_size = None + num_chunks = None + for i_arg in arg_indices_to_split: + # Split the argument into chunks + chunks, tensor_size = split_into_chunks( + args[i_arg], num_actors, chunk_size=chunk_size, expect_size=original_size, target_device=target_device + ) + # Make sure that we know the original size and the number of chunks + if original_size is None: + original_size = tensor_size + num_chunks = len(chunks) + # Put the chunks + result[i_arg] = chunks + + # Now that we have split all the arguments that are marked for splitting and that we know what is the chunk size, + # we can now loop over the arguments that are marked for duplication. + for i_arg in arg_indices_to_duplicate: + result[i_arg] = [args[i_arg] for _ in range(num_chunks)] + + return result + + +def _all_are_instances(objects: Iterable, type_or_types: type | tuple[type, ...]) -> bool: + """ + Return True if all the given objects match the given type(s). + + Args: + objects: An iterable of objects whose types are being queried + type_or_types: A type or a tuple of types + Returns: + True if the types of the given `objects` match the provided + `type_or_types`; + False otherwise. + """ + for obj in objects: + if not isinstance(obj, type_or_types): + return False + return True + + +def _all_are_non_scalars(tensors: Iterable[torch.Tensor]) -> bool: + """ + Return True if `tensors` are all non-scalars (having 1 or more dimensions). + + Args: + tensors: An iterable of PyTorch tensors. + Returns: + True if all `tensors` are non-scalars; False otherwise. + """ + for t in tensors: + if t.ndim == 0: + return False + return True + + +def _ensure_chunk_lengths_are_valid(objects_with_length: Sequence, expected_lengths: Sequence[int] | None): + """ + Ensure that the lengths of the objects match the desired lengths. + + Args: + objects_with_length: A sequence of objects that have lengths (i.e. + that have the method `__len__`). + expected_lengths: A sequence of integers. The length of the i-th + element within `objects_with_length` must match the i-th integer + within `expected_lengths`. + Raises: + ValueError: if the lengths do not match. + """ + if expected_lengths is None: + return + for obj, expected_len in zip(objects_with_length, expected_lengths, strict=True): + if len(obj) != expected_len: + raise ValueError("Received a chunk with an unexpected size") + + +def _stack_chunked_tensors( + chunks: Sequence[TensorLike], *, expect_chunk_sizes: Sequence[int] | None = None +) -> TensorLike: + """ + Stack chunks of tensors/`TensorFrame`s/`ObjectArray`s. + + If `chunks` consists of tensors, those tensors will be concatenated along + their leftmost dimensions. 
+ If `chunks` consists of `TensorFrame`s, those frames will be vertically + stacked. + If `chunks` consists of `ObjectArray`s, those arrays will be concatenated. + + Args: + chunks: A sequence consisting of tensors/`TensorFrame`s/`ObjectArray`s. + expect_chunk_sizes: If given, it will be expected that the size of the + i-th item within `chunks` matches the i-th integer within + `expect_chunk_sizes`. + Returns: + The combined tensor/TensorFrame/ObjectArray. + Raises: + ValueError: if the i-th chunk does not have the size specified by the + i-th integer within the provided `expect_chunk_sizes`. + """ + from .tools import as_tensor + + if not isinstance(chunks, Sequence): + raise TypeError(f"`chunks` was expected as a Sequence, but it was received as an instance of {type(chunks)}") + if isinstance(chunks, tuple) and hasattr(chunks, "_fields"): + raise TypeError("`chunks` in the form of a named tuple is not supported") + if not isinstance(chunks, list): + chunks = list(chunks) + + num_chunks = len(chunks) + if num_chunks == 0: + raise ValueError("Cannot operate on an empty list of chunks") + if (expect_chunk_sizes is not None) and (num_chunks != len(expect_chunk_sizes)): + raise ValueError("Received an unexpected number of chunks") + + resulting_stack: torch.Tensor | TensorFrame + + if _all_are_instances(chunks, torch.Tensor): + if _all_are_non_scalars(chunks): + _ensure_chunk_lengths_are_valid(chunks, expect_chunk_sizes) + resulting_stack = torch.cat(chunks) + else: + raise ValueError("Received a chunk in the form of a scalar tensor, which is unexpected") + elif _all_are_instances(chunks, ObjectArray): + _ensure_chunk_lengths_are_valid(chunks, expect_chunk_sizes) + resulting_stack = as_tensor(list(chain(*chunks)), dtype=object) + got_read_only = False + for chunk in chunks: + if chunk.is_read_only: + got_read_only = True + break + if got_read_only: + resulting_stack = resulting_stack.get_read_only_view() + elif _all_are_instances(chunks, TensorFrame): + _ensure_chunk_lengths_are_valid(chunks, expect_chunk_sizes) + for i_chunk, chunk in enumerate(chunks): + if i_chunk == 0: + resulting_stack = chunk + else: + resulting_stack = resulting_stack.vstack(chunk) + else: + raise TypeError("Encountered some unsupported types in the chunk, or their types are inconsistent") + + return resulting_stack + + +def _keys_of_all_dicts(dicts: Iterable[Mapping]) -> list: + """ + Get the keys of all dictionary-like objects. + + Args: + dicts: An iterable of dictionary-like objects (i.e. of instances of + `collections.abc.Mapping`). + Returns: + The keys of the given dictionary-like objects. + Raises: + KeyError: if the keys of the given dictionary-like objects are not + consistent. + ValueError: if the given iterable does not provide any dictionary. + """ + key_list: list | None = None + key_set: set | None = None + for d in dicts: + if key_set is None: + key_list = list(d.keys()) + key_set = set(d.keys()) + else: + if set(d.keys()) != key_set: + raise KeyError("The dictionaries have inconsistent keys") + if key_list is None: + raise ValueError("Cannot get the keys from an empty iterable of dictionaries") + return key_list + + +def _stack_chunked_dicts( + chunks: Sequence[Mapping[Any, TensorLike]], *, expect_chunk_sizes: Sequence[int] | None = None +) -> dict[Any, TensorLike]: + """ + Combine multiple dictionaries of chunked tensors into a single dictionary. + + Args: + chunks: A sequence of dictionary-like objects (i.e. 
of instances of + `collections.abc.Mapping`), in which each dictionary-like object + represents a single chunk. + expect_chunk_sizes: If provided, each tensor/TensorFrame/ObjectArray + within the i-th dictionary-like object will be expected to have + its size equal to the i-th integer within `expect_chunk_sizes`. + Returns: + A dictionary in which all the tensors/arrays are combined. + """ + if not isinstance(chunks, Sequence): + raise TypeError(f"`chunks` was expected as a Sequence, but it was received as an instance of {type(chunks)}") + if not _all_are_instances(chunks, Mapping): + raise TypeError("Some or all of the elements within the given sequence are not dictionaries") + keys = _keys_of_all_dicts(chunks) + return { + k: _stack_chunked_tensors([dict_chunk[k] for dict_chunk in chunks], expect_chunk_sizes=expect_chunk_sizes) + for k in keys + } + + +def _length_of_all_sequences(sequences: Iterable[Sequence]) -> int: + """ + Get the length of the given sequences. + + Args: + sequences: An iterable of sequences. Within this iterable, each + sequence's length will be checked. + Returns: + The length of the given sequences. + Raises: + ValueError: if the given sequences have inconsistent lengths, or if the + iterable does not provide any sequence at all. + """ + n: int | None = None + for s in sequences: + if n is None: + n = len(s) + else: + if len(s) != n: + raise ValueError("The sequences have inconsistent lengths") + if n is None: + raise ValueError("Cannot get the sequence length from an empty iterable of sequences") + return n + + +def _stack_chunked_sequences( + chunks: Sequence[Sequence[TensorLike]], *, expect_chunk_sizes: Sequence[int] | None = None +) -> list[TensorLike]: + """ + Combine multiple sequences of chunked tensors into a single list. + + Args: + chunks: A sequence consisting of sequences of + tensors/`TensorFrame`s/`ObjectArray`s. + expect_chunk_sizes: If this is provided, the + tensors/`TensorFrame`s/`ObjectArray`s within the i-th sequence + is expected to have their lengths equal to the i-th integer within + `expect_chunk_sizes`. + Returns: + A list in which all the tensors/arrays are combined. + """ + if not isinstance(chunks, Sequence): + raise TypeError(f"`chunks` was expected as a Sequence, but it was received as an instance of {type(chunks)}") + if not _all_are_instances(chunks, Sequence): + raise TypeError("Some or all of the elements within the given sequence are not sequences") + + sequence_maker = list + for c in chunks: + if isinstance(c, tuple): + sequence_maker = tuple + if hasattr(c, "_fields"): + raise TypeError("Chunks in the form of named tuple are not supported") + + n = _length_of_all_sequences(chunks) + result = sequence_maker( + _stack_chunked_tensors([sequence_chunk[k] for sequence_chunk in chunks], expect_chunk_sizes=expect_chunk_sizes) + for k in range(n) + ) + + return result + + +def stack_chunks( + chunks: Sequence[Sequence[TensorLike]] | Sequence[Mapping[Any, TensorLike]] | Sequence[TensorLike], + *, + expect_chunk_sizes: Sequence[int] | None = None, +) -> list[TensorLike] | Mapping[Any, TensorLike] | TensorLike: + """ + Stack the given tensors/arrays across the given chunks. + + Each chunk can be a tensor or a TensorFrame or an ObjectArray, or a + dictionary-like object or a sequence consisting of + tensors/`TensorFrame`s/`ObjectArray`s. + + Args: + chunks: A sequence in which each item is a chunk. 
+ expect_chunk_sizes: If this is given, the tensor/array sizes within + the i-th chunk will be expected to be equal to the i-th integer + within `expect_chunk_sizes`. + Returns: + A combined tensor/array, or a dictionary-like object or a sequence + which contains combined tensors/arrays. + Raises: + ValueError: if the tensor/array lengths do not match with what is + specified within `expect_chunk_sizes`. + """ + if _all_are_instances(chunks, (torch.Tensor, ObjectArray, TensorFrame)): + return _stack_chunked_tensors(chunks, expect_chunk_sizes=expect_chunk_sizes) + elif _all_are_instances(chunks, Sequence): + return _stack_chunked_sequences(chunks, expect_chunk_sizes=expect_chunk_sizes) + elif _all_are_instances(chunks, Mapping): + return _stack_chunked_dicts(chunks, expect_chunk_sizes=expect_chunk_sizes) + else: + raise TypeError( + "Received a sequence in which some or all elements have unsupported types," + " or in which the element types are inconsistent" + ) + + +class _FunctionWrapInfo(NamedTuple): + function: Callable + num_actors: str | int | None + chunk_size: int | None + num_gpus_per_actor: int | float | str | None + split_arguments: tuple[bool, ...] + devices: tuple[torch.device, ...] + + +class _LockForTheMainProcess: + """ + A lock (in the context of threading) that is meant for the main process. + + Just like a regular `threading.Lock`, the instances of this class can be + used with the help of a `with` statement: + + ```python + my_lock = _LockForTheMainProcess() + + ... + + with my_lock: + ... # critical actions go here + ``` + + The differences of this class from `threading.Lock` are as follows: + + - Instances of this class are picklable (no error will be raised). + - Although picklable, based on the assumption that this type of lock + is meant only for the main process, the locking capabilities + of the instances of this class will disappear once they are + pickled and unpickled. + - If the locking capabilities of an instance of this class have + disappeared, trying the `with` statement on them will cause an + error. + - Objects containing the instances of this class can be serialized + and distributed by the `ray` library. However, the actual locking + capabilities will be available only to the ones on the main process. + """ + + def __init__(self): + self._lock: Lock | None = Lock() + + def _ensure_lock_exists(self): + if self._lock is None: + selfname = type(self).__name__ + raise RuntimeError(f"This {selfname} was pickled and then unpickled. It cannot be used anymore as a lock.") + + def __enter__(self): + self._ensure_lock_exists() + self._lock.acquire() + + def __exit__(self, exc_type, exc_value, traceback): + self._ensure_lock_exists() + self._lock.release() + + def __getstate__(self) -> dict: + result = {} + for k, v in self.__dict__.items(): + if k == "_lock": + result[k] = None + else: + result[k] = v + return result + + +def _loosely_find_leftmost_dimension_size( + x: TensorLike | Sequence[TensorLike] | Mapping[Any, TensorLike], *, _recurse: bool = True +) -> int: + """ + Find the leftmost dimension of the given tensors/arrays. + + If `x` is given as a tensor or as a TensorFrame or as an ObjectArray, + its leftmost dimension's size will be returned. + If `x` is given as a sequence or as a dictionary-like object, the leftmost + dimension size of the first tensor/array encountered within it will be + returned. + + This function assumes that the tensor/array size consistency within the + given container is checked elsewhere. 
With this assumption in mind, + consistency check will not be performed by this function. + + Args: + x: A tensor or a TensorFrame or an ObjectArray, or a sequence or a + dictionary-like object containing tensors/arrays. + _recurse: For internal usage. + Returns: + An integer representing the leftmost dimension size. + """ + if isinstance(x, (ObjectArray, TensorFrame)): + return len(x) + elif isinstance(x, torch.Tensor): + return x.shape[0] + elif isinstance(x, (str, np.str_, bytes, bytearray)): + raise TypeError(f"Received a sequence of this unexpected type: {type(x)}") + elif isinstance(x, (Mapping, Sequence)): + if not _recurse: + raise TypeError("Found a container when expecting a tensor or a TensorFrame or an ObjectArray") + if len(x) == 0: + raise ValueError("Encountered an empty container, which is unexpected") + if isinstance(x, Mapping): + elements = x.values() + else: + elements = x + first_element = next(iter(elements)) + return _loosely_find_leftmost_dimension_size(first_element, _recurse=False) + else: + raise TypeError(f"Encountered an object of this unexpected type: {type(x)}") + + +class _Wrapped: + functions: dict[_FunctionWrapInfo, Callable] = {} + lock = _LockForTheMainProcess() + + +class _DistributedFunctionHandler(Problem): + """ + Handler for a function that is decorated via `@distribute`. + + Although this handler is not meant to express an optimization problem, + it is built as a subclass of `evotorch.Problem`, for taking advantage + of multi-actor parallelization capabilities of the Problem class. + + **How does it work internally?** + This handler declares itself as a dummy optimization problem which + requires parallelization. The configuration arguments it receives regarding + parallelization are passed to the initializer of its parent class, + `Problem`. Additionally, upon its initialization, it receives the original + form of the decorated function and stores a reference to that function + within itself. + + Once this handler receives a request to execute the referenced function + in a distributed (i.e. parallelized) manner (via its method + `call_wrapped_function`), it forces its superclass (Problem) to create + remote actors by performing a dummy solution batch evaluation. + Once the remote actors are created, the input arguments to the wrapped + function are split into chunks, those chunks are then sent to the + remote actors along with a request to apply the wrapped function on them, + and finally the results of the actors are collected and combined. + Note that the wrapped function is called by the actors in parallel, which + is the main goal of this handler. + """ + + def __init__( + self, + *, + function: Callable, + num_actors: str | int | None = None, + chunk_size: int | None = None, + num_gpus_per_actor: int | float | str | None, + split_arguments: tuple[bool, ...], + devices: tuple[torch.device, ...], + ): + """ + `__init__(...)`: Initialize the `_DistributedFunctionHandler`. + + Args: + function: The reference to the original form of the function + to be distributed across multiple remote actors. + num_actors: Number of remote actors. + chunk_size: Optionally, the size of a chunk as an integer. + If this is given, then the original arguments will be split + into chunks with at most this given size. + num_gpus_per_actor: Number of GPUs to be allocated by each actor. 
+ split_arguments_into_chunks: A tuple of booleans, in which the i-th + boolean says if the i-th positional argument for the wrapped + function is expected as split into chunks (True), or is to be + duplicated for each remote actor (False). + If this is given as an empty tuple, it will be assumed that + all the positional arguments are to be split into chunks. + devices: A tuple of devices. If this tuple is not empty, then the + i-th actor will use the i-th device listed within `devices`. + If this argument is to be provided as a non-empty tuple, + and if `devices` are going to be other than just cpus, + then it is highly recommended to set `num_gpus_per_actor` as + "all", so that the same devices will be visible to all actors. + The `@distribute` decorator, when using this handler class + internally, automatically sets `num_gpus_per_actor` as "all" + when a non-empty `devices` argument is provided. + """ + self.__function = function + self.__chunk_size = chunk_size + self.__split_arguments = split_arguments + self.__devices = devices + self.__parallelized = False + self.__parallelization_lock = _LockForTheMainProcess() + self.__actor_pool = None + + super().__init__( + objective_sense="min", + solution_length=2, + initial_bounds=(-1.0, 1.0), + dtype=torch.float32, + device="cpu", + num_actors=num_actors, + num_gpus_per_actor=num_gpus_per_actor, + store_solution_stats=False, + ) + + def _evaluate_batch(self, x: SolutionBatch): + """ + Just a filler batch evaluation procedure. + """ + z = torch.zeros(len(x), dtype=x.eval_dtype, device=x.device) + x.set_evals(z) + + def _ensure_dummy_problem_is_parallelized(self): + """ + Internal method for ensuring that the remote actors are created. + """ + if self.is_remote: + # If we are on a remote actor, this check is not necessary. We just exit the function. + return + + with self.__parallelization_lock: + if not self.__parallelized: + # This is the case where the problem has not been parallelized yet (i.e. we do not have actors yet). + # To trigger the creation of the actors, we generate a dummy SolutionBatch and evaluate it. + # The creation of the actors is then managed by the `evaluate` method of the parent Problem class. + dummy_batch = SolutionBatch(self, popsize=1) + self.evaluate(dummy_batch) + self.__parallelized = True + + if (self.actors is None) or (len(self.actors) < 2): + raise RuntimeError( + "Failed to create the distributed counterpart of the original function." + " Hint: this can happen if the arguments given to the `@distribute` decorator imply a non-distributed" + " environment, e.g., if one sets `num_actors='num_gpus'` when one has only 1 GPU," + " or if one sets `num_actors` as an integer that is smaller than 2." + ) + + # NOTE: do we need this, or could we actually use the actor pool of the underlying Problem? + self.__actor_pool = ActorPool(self.actors) + + def _iter_split_arguments(self, args: Sequence): + num_split_arguments = len(self.__split_arguments) + if num_split_arguments == 0: + for _ in range(len(args)): + yield True + else: + if num_split_arguments != len(args): + raise TypeError( + f"The number of received positional arguments ({len(args)})" + f" is different than what is expected ({len(self.__split_arguments)})" + ) + + for split_arg in self.__split_arguments: + yield split_arg + + def _call_wrapped_function_remotely(self, task_index: int, args: tuple) -> tuple[int, Any]: + """ + Internal helper method for calling the wrapped function on an actor. + + Args: + task_index: The index of the task. 
+ args: Positional arguments to be passed to the wrapped function. + The positional arguments that were marked to be split into + chunks will be moved to the accelerator device associated with + this actor. + Returns: + A tuple in the form `(task_index, result)` where `task_index` is + the index of the task that was given, and `result` is the result + of the wrapped function, moved back to the cpu. + """ + + if self.is_main: + raise RuntimeError("This function should not be executed from the main actor") + + num_explicit_devices = len(self.__devices) + prepared_args = [] + + for split_arg, arg in zip(self._iter_split_arguments(args), args): + if split_arg: + if num_explicit_devices > 0: + target_device = self.__devices[self.actor_index % num_explicit_devices] + else: + target_device = self.aux_device + prepared_arg = move_shallow_container_to_device(arg, device=target_device) + else: + prepared_arg = arg + prepared_args.append(prepared_arg) + + result = self.__function(*prepared_args) + + # Move the result of this function back to the cpu, and return it. + return task_index, move_shallow_container_to_device(result, device="cpu") + + def call_wrapped_function(self, *args) -> Any: + """ + Run the wrapped function across the remote actors. + + Args: + args: Positional arguments to be passed to the wrapped function. + If this class was initialized with a non-empty + `split_arguments` tuple: i-th argument will be split into + chunks if the i-th element within `split_arguments` is True, + and the i-th chunk of arguments will be sent to the i-th actor. + If this class was initialized with an empty `split_arguments` + tuple: it will be assumed that all positional arguments are + to be split into chunks. + Note also that each actor will move its received chunks + to its own associated accelerator device before applying the + wrapped function on them. + Returns: + Combined result of the parallel computation of the remote actors. + The results will be on the cpu. + """ + + if len(args) == 0: + raise TypeError("Calling a distributed function without any positional arguments is not supported") + if not self.is_main: + raise RuntimeError("This method must be executed only from within the main actor") + self._ensure_dummy_problem_is_parallelized() + + first_split_arg_index = None + for i_arg, split_arg in enumerate(self._iter_split_arguments(args)): + if split_arg: + first_split_arg_index = i_arg + break + if first_split_arg_index is None: + raise ValueError( + "None of the arguments is marked for being split into chunks, which is not a supported configuration." 
+ ) + + # split the arguments into chunks, BUT ONLY IF the argument is marked via `split_arguments` + chunked_args = split_arguments_into_chunks( + args, + list(self._iter_split_arguments(args)), + self.num_actors, + chunk_size=self.__chunk_size, + target_device="cpu", + ) + num_chunks = len(chunked_args[first_split_arg_index]) + + args_per_task = [[arg_chunk[i_task] for arg_chunk in chunked_args] for i_task in range(num_chunks)] + chunk_size_per_task = [ + _loosely_find_leftmost_dimension_size(args_per_task[i_task][first_split_arg_index]) + for i_task in range(num_chunks) + ] + + call_args_per_task = [ + ["_call_wrapped_function_remotely", [i_task, args_per_task[i_task]], {}] for i_task in range(num_chunks) + ] + + unordered_map_result = list( + self.__actor_pool.map_unordered( + (lambda actor, chunk: actor.call.remote(*chunk)), + call_args_per_task, + ) + ) + + assert len(unordered_map_result) == num_chunks + + ordered_map_result = [None for _ in range(num_chunks)] + for i_task, returned_chunk in unordered_map_result: + ordered_map_result[i_task] = returned_chunk + + # collect the remote results and combine the tensors + result = stack_chunks(ordered_map_result, expect_chunk_sizes=chunk_size_per_task) + + return result + + +class _DistributedFunction: + """ + A function that was decorated via `@distribute`. + + Please use the `@distribute` decorator instead of instantiating this class + manually. + """ + + def __init__(self, wrap_info: _FunctionWrapInfo): + self.wrap_info = wrap_info + self.problem = _DistributedFunctionHandler( + function=wrap_info.function, + num_actors=wrap_info.num_actors, + chunk_size=wrap_info.chunk_size, + num_gpus_per_actor=wrap_info.num_gpus_per_actor, + split_arguments=wrap_info.split_arguments, + devices=wrap_info.devices, + ) + if hasattr(self.wrap_info.function, "__evotorch_vectorized__"): + self.__evotorch_vectorized__ = self.wrap_info.function.__evotorch_vectorized__ + if hasattr(self.wrap_info.function, "__evotorch_pass_info__"): + self.__evotorch_pass_info__ = self.wrap_info.function.__evotorch_pass_info__ + self.__evotorch_distribute__ = True + + def __call__(self, *args): + return self.problem.call_wrapped_function(*args) + + +def _prepare_distributed_function( + function: Callable, + *, + split_arguments: Sequence[bool] | np.ndarray | torch.Tensor | None = None, + num_actors: int | str | None = None, + chunk_size: int | None = None, + num_gpus_per_actor: int | float | str | None = None, + devices: Sequence[torch.device | str], +) -> Callable: + if split_arguments is None: + split_arguments = tuple() + + if (not isinstance(split_arguments, Sequence)) or (isinstance(split_arguments, (str, np.str_, bytes, bytearray))): + raise TypeError( + f"`split_arguments` was expected as a Sequence of booleans, not as an instance of {repr(split_arguments)}" + ) + + if isinstance(split_arguments, tuple) and hasattr(split_arguments, "_fields"): + raise ValueError("`split_arguments` in the form of named tuples is not supported") + + if len(split_arguments) > 0: + # We are being extra careful here for ensuring that `split_arguments` is a sequence of booleans. 
+ # We want to actively prevent unexpected behavior that could be caused by these mistakes: + # - providing argument indices instead of a sequence of booleans + # - providing one or more argument names as strings, instead of a sequence of booleans + _actual_split_arguments = [] + for split_arg in split_arguments: + if isinstance(split_arg, torch.Tensor) and (split_arg.ndim == 0): + _actual_split_arguments.append(bool(split_arg.to(device="cpu"))) + elif isinstance(split_arg, (bool, np.bool_)): + _actual_split_arguments.append(bool(split_arg)) + else: + raise TypeError("`split_arguments` was expected to contain booleans only") + split_arguments = tuple(_actual_split_arguments) + + if devices is None: + if (num_actors is None) or ((isinstance(num_actors, Integral)) and (num_actors <= 1)): + raise ValueError( + "The argument `devices` was received as None." + " When `devices` is None, `num_actors` is expected as an integer that is at least 2." + f" However, the given value of `num_actors` is {repr(num_actors)}." + ) + devices = tuple() + else: + if isinstance(devices, tuple) and hasattr(devices, "_fields"): + raise ValueError("`devices` in the form of a named tuple is not supported") + devices = tuple(torch.device(item) for item in devices) + num_devices = len(devices) + if num_devices == 0: + raise ValueError("`devices` cannot be given as an empty sequence") + if num_actors is None: + num_actors = num_devices + else: + raise ValueError( + "The `argument` devices was received as provided as a value other than None." + " When `devices` is not None, `num_actors` is expected to be left as None." + f" However, it was received as {repr(num_actors)}." + ) + + # We are given an explicit sequence of devices. + # Therefore, we assume that the actors must be able to see all the accelerator devices, + # and therefore override `num_gpus_per_actor` as "all". + if num_gpus_per_actor is None: + num_gpus_per_actor = "all" + else: + raise ValueError( + "The `argument` devices was received as provided as a value other than None." + " When `devices` is not None, `num_gpus_per_actor` is expected to be left as None." + f" However, it was received as {repr(num_gpus_per_actor)}." + ) + + # Prepare a wrap_info tuple which stores information about which function was wrapped with what configuration. + wrap_info = _FunctionWrapInfo( + function=function, + split_arguments=split_arguments, + num_actors=num_actors, + chunk_size=chunk_size, + num_gpus_per_actor=num_gpus_per_actor, + devices=devices, + ) + + with _Wrapped.lock: + if wrap_info in _Wrapped.functions: + # According to our global wrapped functions dictionary, if this particular function was wrapped before + # with these exact settings, we return the already wrapped version of the function. + result = _Wrapped.functions[wrap_info] + else: + # If this is the first time we are wrapping this function with these settings, then we create a wrapped + # version of this function, and put it into our global wrapped functions dictionary. + result = _DistributedFunction(wrap_info) + _Wrapped.functions[wrap_info] = result + + return result + + +class DecoratorForDistributingFunctions: + """ + Parameterized wrapper for making distributed counterparts of functions. + + It is highly recommended to use the `@distribute` decorator instead. 
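For illustration, a hedged sketch of direct usage (hypothetical function and settings; a working ray setup
capable of creating at least 2 actors is assumed):

```python
import torch


@DecoratorForDistributingFunctions(num_actors=4, split_arguments=(True, False))
def scaled_sum(x: torch.Tensor, factor: float) -> torch.Tensor:
    # `x` arrives in chunks on each actor; `factor` is duplicated as-is.
    return torch.sum(x * factor, dim=-1)
```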
+ """ + + def __init__( + self, + *, + split_arguments: Sequence[bool] | np.ndarray | torch.Tensor | None = None, + num_actors: str | int | None = None, + chunk_size: int | None = None, + num_gpus_per_actor: int | float | str | None = None, + devices: Sequence[bool] | None = None, + ): + self.split_arguments = split_arguments + self.num_actors = num_actors + self.chunk_size = chunk_size + self.num_gpus_per_actor = num_gpus_per_actor + self.devices = devices + + def __call__(self, function: Callable) -> Callable: + return _prepare_distributed_function( + function, + split_arguments=self.split_arguments, + num_actors=self.num_actors, + chunk_size=self.chunk_size, + num_gpus_per_actor=self.num_gpus_per_actor, + devices=self.devices, + ) diff --git a/src/evotorch/decorators.py b/src/evotorch/decorators.py index d8e40c0..473a5f7 100644 --- a/src/evotorch/decorators.py +++ b/src/evotorch/decorators.py @@ -14,8 +14,9 @@ """Module defining decorators for evotorch.""" +from collections.abc import Iterable, Sequence from numbers import Number -from typing import Callable, Iterable, Optional, Union +from typing import Callable, Optional, Union import numpy as np import torch @@ -208,406 +209,661 @@ def __init__(self, obs_length: int, act_length: int, **kwargs): return _simple_decorator("__evotorch_pass_info__", args, decorator_name="pass_info") -def on_device(device: Device) -> Callable: +def vectorized(*args) -> Callable: """ - Decorator that informs a problem object that this function wants to - receive its solutions on the specified device. + Decorates a fitness function so that the problem object (which can be an instance + of [evotorch.Problem][evotorch.core.Problem]) will send the fitness function a 2D + tensor containing all the solutions, instead of a 1D tensor containing a single + solution. - What this decorator does is that it injects a `device` attribute onto - the decorated callable object. Then, this callable object itself is - returned. Upon seeing the `device` attribute, the `evaluate(...)` method - of the [Problem][evotorch.core.Problem] object will attempt to move the - solutions to that device. + What this decorator does is that it adds the decorated fitness function a new + attribute named `__evotorch_vectorized__`, the value of this new attribute being + True. Upon seeing this new attribute, the problem object will send this function + multiple solutions so that vectorized operations on multiple solutions can be + performed by this fitness function. - Let us imagine a fitness function `f` whose definition looks like: + Let us imagine that we have the following fitness function which works on a + single solution `x`, and returns a single fitness value: ```python import torch def f(x: torch.Tensor) -> torch.Tensor: - return torch.sum(x, dim=-1) + return torch.sum(x**2) ``` - In its not-yet-decorated form, the function `f` would be given `x` on the - default device of the associated problem object. However, if one decorates - `f` as follows: + ...and let us now define the optimization problem associated with this fitness + function: ```python - from evotorch.decorators import on_device + p1 = Problem("min", f, initial_bounds=(-10.0, 10.0), solution_length=5) + ``` + + While the fitness function `f` and the definition `p1` form a valid problem + description, it does not use PyTorch to its full potential in terms of performance. 
+ If we were to request the evaluation results on a population of solutions via + `p1.evaluate(population)`, `p1` would use a classic `for` loop to evaluate every + single solution within `population` one by one. + We could greatly increase our performance by: + (i) re-defining our fitness function in a vectorized manner, i.e. in such a way + that it will operate on many solutions and compute all of their fitnesses at once; + (ii) label our fitness function via `@vectorized`, so that the problem object + will be aware that this new fitness function expects `n` solutions and returns + `n` fitnesses. The re-designed and labeled fitness function looks like this: + ```python + from evotorch.decorators import vectorized - @on_device("cuda:0") - def f(x: torch.Tensor) -> torch.Tensor: - return torch.sum(x, dim=-1) + + @vectorized + def f2(x: torch.Tensor) -> torch.Tensor: + return torch.sum(x**2, dim=-1) + ``` + + The problem description for `f2` is: + + ```python + p2 = Problem("min", f2, initial_bounds=(-10.0, 10.0), solution_length=5) ``` - then the Problem object will first move `x` onto the device cuda:0, and - then will call `f`. + In this last example, `p2` will realize that `f2` is decorated via `@vectorized`, + and will send it `n` solutions, and will receive and process `n` fitnesses. + """ + return _simple_decorator("__evotorch_vectorized__", args, decorator_name="vectorized") + + +def on_device( # noqa: C901 + *positional_args, + move_only_from_cpu: bool = False, + chunk_size: int | None = None, + device: Device | None = None, +) -> Callable: + """ + Transform a function so that it will compute on the specified device. + + A function decorated via `@on_device` will first move its positional + arguments to the specified device, then perform the operations listed + within the body of the original function definition, and then move + the result back to the most encountered device within its arguments. + + For a function to be decorated via `@on_device`, the assumption is that + its positional arguments and its output are of these types: + + - Pytorch tensor + - `ReadOnlyTensor` + - `TensorFrame` + - `ObjectArray` + - shallow (non-nested) sequence or dictionary-like container consisting of + objects that are instances of the types listed above + + Additionally, a `device` attribute is added onto the decorated counterpart + of the function. This `device` attribute is not meant for changing, but for + informing an observer regarding where the computation will take place. + + **Note.** + Although an `on_device`-decorated function moves its arguments to the + specified target device for encouraging the computation to take place on + that device, it is still possible for the inner body of the function to + move the tensors to any device. + + **Special behavior for evaluation methods of Problem objects.** + In addition to simple functions, these specific methods of a `Problem` + class can be decorated via `@on_device`: - This decorator is useful on multi-GPU settings. For details, please see - the following example: + - `_evaluate` + - `_evaluate_batch` + + If the decorated function receives a Problem object as its first argument, + and a Solution or a SolutionBatch as its second argument, the decorator + will assume that the decorated function is one of the methods listed above, + and will do nothing other than simply passing the arguments to the original + version of the decorated function. 
Instead, it is the `Problem` object + which moves the solutions to the correct device by looking at the `device` + attribute created by the `@on_device` decorator. + + Decorating arbitrary methods (other than these solution or batch evaluation + methods of the `Problem` class) is not supported. + + **Example usage 1.** + + Assuming that the cuda device is available: ```python - from evotorch import Problem from evotorch.decorators import on_device @on_device("cuda") - def f(x: torch.Tensor) -> torch.Tensor: ... + def my_function(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + # Thanks to the decorator, x and y should be on the 'cuda' device. + result = x + y - - problem = Problem( - "min", - f, - num_actors=4, - num_gpus_per_actor=1, - device="cpu", - ) + return result # the result will be moved back to the most encountered + # device among the original x and y tensors. ``` - In the example code above, we assume that there are 4 GPUs available. - The main device of the problem is "cpu", which means the populations - will be generated on the cpu. When evaluating a population, the population - will be split into 4 subbatches (because we have 4 actors), and each - subbatch will be sent to an actor. Thanks to the decorator `@on_device`, - the [Problem][evotorch.core.Problem] instance on each actor will first move - its [SolutionBatch][evotorch.core.SolutionBatch] to the cuda device visible - to its actor, and then the fitness function `f` will perform its evaluation - operations on that [SolutionBatch][evotorch.core.SolutionBatch] on the - the visible cuda. In summary, the actors will use their associated cuda - devices to evaluate the fitnesses of the solutions in parallel. + **Example usage 2.** - This decorator can also be used to decorate the method `_evaluate` or - `_evaluate_batch` belonging to a custom subclass of - [Problem][evotorch.core.Problem]. Please see the example below: + Assuming that the cuda device is available: ```python - from evotorch import Problem + import torch + from evotorch.decorators import on_device + from evotorch import Problem, SolutionBatch - class MyCustomProblem(Problem): + class SphereProblem(Problem): def __init__(self): super().__init__( - ..., - device="cpu", # populations will be created on the cpu - ..., + objective_sense="min", + solution_length=20, + initial_bounds=(-1.0, 1.0), + dtype=torch.float32, + device="cpu", # the populations are to be stored on the cpu ) - @on_device("cuda") # fitness evaluations will happen on cuda - def _evaluate_batch(self, solutions: SolutionBatch): - fitnesses = ... - solutions.set_evals(fitnesses) + @on_device("cuda") + def _evaluate_batch(self, batch: SolutionBatch): + # Upon seeing that this method is decorated by `@on_device`, + # the `Problem` object will move the `batch` to the cuda device + # while calling this method. + # Therefore, the computation below is expected to happen on cuda. + evals = torch.sum(batch.values**2.0, dim=-1) + batch.set_evals(evals) ``` - The attribute `device` that is added by this decorator can be used to - query the fitness device, and also to modify/update it: + **Specifying which arguments are to be moved to the target device.** + One might want to decorate, via `on_device` a function whose arguments + are not all tensors. For example, consider the following function: ```python - @on_device("cpu") - def f(x: torch.Tensor) -> torch.Tensor: ... + def my_fn(x: torch.Tensor, s: str) -> torch.Tensor: ... 
+ ``` + As can be seen, the example function above has an argument `x` which can + be transferred to a computation device, but has another argument `s` whose + type is `str` and therefore cannot be subject to moving operations. + To tell `on_device` which arguments are subject to moving operations, + one can decorate `my_fn` like this: - print(f.device) # Prints: torch.device("cpu") - f.device = "cuda:0" # Evaluations will be done on cuda:0 from now on + ```python + @on_device("cuda", True, False) + def my_fn(x: torch.Tensor, s: str) -> torch.Tensor: ... ``` - Args: - device: The device on which the decorated fitness function will work. - """ + or alternatively, like this: - # Take the `torch.device` counterpart of `device` - device = torch.device(device) - - def decorator(fn: Callable) -> Callable: - setattr(fn, "__evotorch_on_device__", True) - setattr(fn, "device", device) - return fn + ```python + @on_device(True, False, device="cuda") + def my_fn(x: torch.Tensor, s: str) -> torch.Tensor: ... + ``` - return decorator + The first boolean (True) says that the first positional argument (x) is to + be moved to cuda. The second boolean (False) says that the second + positional argument (s) is NOT to be operated on by `on_device` and to be + passed as it is. + **Moving the input tensors in chunks.** + Let us imagine that we have a function that we want to run on cuda, and + that this function has high memory demands. For example, if the input + tensor has the leftmost dimension 10, it will work, but it will raise + a memory error if the input tensor's leftmost dimension is larger than + 10. In such cases, we can use the keyword argument `chunk_size`, so that + `on_device` will split the input tensors into chunks, and run the + underlying function on each chunk. Example usage: -def on_cuda(*args) -> Callable: - """ - Decorator that informs a problem object that this function wants to - receive its solutions on a cuda device (optionally of the specified - cuda index). + ```python + @on_device("cuda", chunk_size=10) + def memory_hungry_fn(x: torch.Tensor) -> torch.Tensor: ... + ``` - Decorating a fitness function like this: + If the function has some non-tensor arguments, those non-tensor arguments + can be marked so that they will not be subject to chunking: - ``` - @on_cuda - def f(...): - ... + ```python + @on_device("cuda", True, False, chunk_size=10) + # alternatively: @on_device(True, False, chunk_size=10, device="cuda") + def memory_hungry_fn2(x: torch.Tensor, s: str) -> torch.Tensor: ... ``` - is equivalent to: + In the case of the decoration of `memory_hungry_fn2`, the first boolean + (True) tells that the first positional argument (x) is to be split into + chunks of size 10, and then each chunk is to be moved to cuda. + The second boolean (False) tells that the second positional argument (s) + is not to be split into chunks, and not to be moved to any device, + but to be passed as it is, for every chunk of x. - ``` - @on_device("cuda") - def f(...): - ... + **Using `on_device` in its inline form.** + Instead of as a decorator, one may use `on_device` as an immediate + functional transformation tool. Examples: + + ```python + transformed_function = on_device(existing_function, device="cuda") ``` - Decorating a fitness function like this: + or if you want to mark the positional arguments of the existing + function: - ``` - @on_cuda(0) - def f(...): - ... 
+    ```python
+    transformed_function2 = on_device(existing_function2, (True, False), device="cuda")
    ```
-    is equivalent to:
+    Inline functional transformation with chunk size:

-    ```
-    @on_device("cuda:0")
-    def f(...):
-        ...
+    ```python
+    transformed_function3 = on_device(
+        existing_function3, (True, False), chunk_size=10, device="cuda"
+    )
    ```

-    Please see the documentation of [on_device][evotorch.decorators.on_device]
-    for further details.

    Args:
-        args: An optional positional arguments using which one can specify
-            the index of the cuda device to use.
+        device: The device to which the arguments will be moved.
+        move_only_from_cpu: If this is True, only the tensors which are on the
+            cpu will be moved to the specified target device.
+        chunk_size: Optionally, an integer. If provided as a positive integer,
+            the arguments will be split into chunks of this given size,
+            and those chunks will be moved to the target device and passed to
+            the decorated function one by one, and then the results will be
+            combined and returned. Please note that, for this feature to work,
+            the decorated function's arguments and result must have the same
+            leftmost dimension size.
    """
-    # Get the number of arguments
-    nargs = len(args)
+    from ._distribute import _loosely_find_leftmost_dimension_size, split_arguments_into_chunks, stack_chunks
+    from .core import Problem, Solution, SolutionBatch
+    from .tools._shallow_containers import most_favored_device_among_arguments, move_shallow_container_to_device

-    if nargs == 0:
-        # If the number of arguments is 0, then we assume that we are in this situation:
-        #
-        #     @on_cuda()
-        #     def f(...):
-        #         ...
-        #
-        # There is no specified index, and we are not yet given which object to decorate.
-        # Therefore, we set both of them as None.
-        index = None
-        fn = None
-    elif nargs == 1:
-        # The number of arguments is 1. We begin by storing that single argument using a variable named `arg`.
-        arg = args[0]
-
-        if isinstance(arg, Callable):
-            # If the argument is a callable object, we assume that we are in this situation:
-            #
-            #     @on_cuda
-            #     def f(...):
-            #         ...
+    if (len(positional_args) >= 1) and isinstance(positional_args[0], Callable):
+        complain_about_args = (len(positional_args) not in (1, 2)) or (
+            (len(positional_args) == 2) and (not isinstance(positional_args[1], tuple))
+        )
+        if complain_about_args:
+            raise TypeError(
+                "The first argument of `on_device` is given as a callable object."
+                " In this situation, it is assumed that the user is using `on_device` not in its decorator"
+                " form, but in its inline function transformation form."
+                " The interface for inline transformation is:"
+                " `on_device(my_function, device=...)` or `on_device(my_function, tuple_of_booleans, device=...)`"
+                " where `tuple_of_booleans` is a tuple specifying which argument is to be moved to the device"
+                " (and to be split into chunks, if the keyword argument `chunk_size` is also provided)."
+                " However, the provided positional arguments do not seem to match the interface of the"
+                " inline transformation."
+ ) + function_to_transform = positional_args[0] + if len(positional_args) == 1: + arguments_to_process = tuple() + else: + arguments_to_process = positional_args[1] + return on_device( + *arguments_to_process, + move_only_from_cpu=move_only_from_cpu, + chunk_size=chunk_size, + device=device, + )(function_to_transform) + + if device is None: + device = torch.device(positional_args[0]) + positional_args = positional_args[1:] + + process_args = tuple(bool(positional_arg) for positional_arg in positional_args) + if len(process_args) == 0: + process_args = None + + # Make sure that the device is expressed as an instance of `torch.device` + device = torch.device(device) - # We are not given a cuda index - index = None + def decorator(original_behavior: Callable) -> Callable: - # We are given our function to decorate. We store that function using a variable named `fn`. - fn = arg - else: - # If the argument is not a callable object, we assume that it is a cuda index, and that we are in the - # following situation: - # - # @on_cuda(index) - # def f(...): - # ... - - # We are given a cuda index. After making sure that it is an integer, we store it by a variable named - # `index`. - index = int(arg) - - # At this moment, we do not know the function that is being decorated. So, we set `fn` as None. - fn = None - else: - # If the number of arguments is neither 0 nor 1, then this is an unexpected case. - # We raise an error to inform the user. - raise TypeError("`on_cuda(...)` received invalid number of arguments") + def modified_behavior(*args) -> Callable: - # Prepare the device as "cuda" - device_str = "cuda" + is_evaluation_method = False + if isinstance(args[0], Problem): + if (len(args) == 2) and isinstance(args[1], (Solution, SolutionBatch)): + if chunk_size is not None: + raise ValueError( + "When decorating a `Problem` method via `@on_device` (or `@on_aux_device` or `@on_cuda`)," + " `chunk_size` is not supported" + ) + is_evaluation_method = True + else: + raise TypeError( + " The function decorated by `@on_device` (or `@on_aux_device` or `@on_cuda`) has received" + " a Problem object as its first argument." + " In this case, it is assumed that the decorated function is an overriden version" + " of the method `Problem._evaluate(self, solution: Solution)`" + " or `Problem._evaluate_batch(self, batch: SolutionBatch)`." + " However, either the number of arguments or the type of the received non-self argument" + " is unexpected." + ) - if index is not None: - # If a cuda index is given, then we add ":N" (where N is the index) to the end of `device_str`. - device_str += ":" + str(index) + if is_evaluation_method: + # This seems to be an evaluation method (like, e.g. _evaluate_batch). + # In this case, we assume that the Problem object, while calling this method, already saw the + # `device` attribute of the decorated function, and did the necessary move operations on the + # solution batch. + # So, we just pass the positional arguments to the original function: + return original_behavior(*args) + + num_args = len(args) + if process_args is None: + which_args = [True for _ in range(num_args)] + else: + which_args = process_args + + # Get the most favored device among the tensors of the received arguments. + # This most favored device is the target device for the produced output. 
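+            # ("Most favored" here means the device encountered most frequently among
+            # the marked arguments, as determined by `most_favored_device_among_arguments`.)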
+ result_device = most_favored_device_among_arguments( + [args[i_arg] for i_arg in range(num_args) if which_args[i_arg]], slightly_favor_cpu=True + ) - # Prepare the decorator function which, upon being called with a function argument, wraps that function. - decorator = on_device(device_str) + first_process_arg_index = None + for i_arg, process_this_arg in enumerate(which_args): + if process_this_arg: + first_process_arg_index = i_arg + break + if first_process_arg_index is None: + raise ValueError( + "None of the arguments is marked for moving to a device, which is not a supported configuration." + ) + + if chunk_size is None: + # Move each argument to the target device, and apply the wrapped function on the moved data. + result_value = original_behavior( + *[ + ( + move_shallow_container_to_device( + args[i_arg], device=device, move_only_from_cpu=move_only_from_cpu + ) + if which_args[i_arg] + else args[i_arg] + ) + for i_arg in range(num_args) + ] + ) + # Move the result back to the most favored device among the input arguments. + result_value = move_shallow_container_to_device(result_value, device=result_device) + else: + chunked_args = split_arguments_into_chunks(args, which_args, 1, chunk_size=chunk_size) + num_chunks = len(chunked_args[first_process_arg_index]) + args_per_task = [[arg_chunk[i_task] for arg_chunk in chunked_args] for i_task in range(num_chunks)] + chunk_size_per_task = [ + _loosely_find_leftmost_dimension_size(args_per_task[i_task][first_process_arg_index]) + for i_task in range(num_chunks) + ] + result_value = stack_chunks( + [ + move_shallow_container_to_device( + original_behavior( + *[ + ( + move_shallow_container_to_device( + task_args[i_arg], device=device, move_only_from_cpu=move_only_from_cpu + ) + if which_args[i_arg] + else task_args[i_arg] + ) + for i_arg in range(num_args) + ] + ), + device=result_device, + ) + for task_args in args_per_task + ], + expect_chunk_sizes=chunk_size_per_task, + ) + + # Finally, we return the result here. + return result_value + + if hasattr(original_behavior, "__evotorch_vectorized__"): + modified_behavior.__evotorch_vectorized__ = original_behavior.__evotorch_vectorized__ + modified_behavior.__evotorch_on_device__ = True + modified_behavior.device = device + if move_only_from_cpu: + modified_behavior.__evotorch_move_only_from_cpu__ = True + return modified_behavior - # If the function that is being decorated is not known yet (i.e. if `fn` is None), then we return the - # decorator function. If the function is known, then we decorate and return it. - return decorator if fn is None else decorator(fn) + return decorator def on_aux_device(*args) -> Callable: """ - Decorator that informs a problem object that this function wants to - receive its solutions on the auxiliary device of the problem. - - According to its default (non-overriden) implementation, a problem - object returns `torch.device("cuda")` as its auxiliary device if - PyTorch's cuda backend is available and if there is a visible cuda - device. Otherwise, the auxiliary device is returned as - `torch.device("cpu")`. - The auxiliary device is meant as a secondary device (in addition - to the main device reported by the problem object's `device` - attribute) used mainly for boosting the performance of fitness - evaluations. - This decorator, therefore, tells a problem object that the fitness - function requests to receive its solutions on this secondary device. 
-
-    What this decorator does is that it injects a new attribute named
-    `__evotorch_on_aux_device__` onto the decorated callable object,
-    then sets that new attribute to `True`, and then return the decorated
-    callable object itself. Upon seeing this new attribute with the
-    value `True`, a [Problem][evotorch.core.Problem] object will attempt
-    to move the solutions to its auxiliary device before calling the
-    decorated fitness function.
-
-    Let us imagine a fitness function `f` whose definition looks like:
+    Transform a function so that it will compute on the auxiliary device.
+
+    By default, the auxiliary device is cuda if cuda is available, and
+    cpu if cuda is not available.
+
+    A function decorated via `@on_aux_device` will first move its positional
+    arguments to the auxiliary device if their original device is the cpu,
+    then perform the operations listed within the body of the original function
+    definition, and then move the result back to the most encountered device
+    within its arguments.
+
+    For a function to be decorated via `@on_aux_device`, the assumption is that
+    its positional arguments and its output are of these types:
+
+    - PyTorch tensor
+    - `ReadOnlyTensor`
+    - `TensorFrame`
+    - `ObjectArray`
+    - shallow (non-nested) sequence or dictionary-like container consisting of
+      objects that are instances of the types listed above
+
+    Additionally, a `device` attribute is added onto the decorated counterpart
+    of the function. This `device` attribute is not meant for changing, but for
+    informing an observer regarding where the computation will take place.
+    An attribute `__evotorch_on_aux_device__=True` is also registered to the
+    decorated function, to inform an outside observer that the function is
+    decorated via `@on_aux_device`.
+
+    **Note.**
+    Although an `on_aux_device`-decorated function moves its cpu-residing
+    arguments to the auxiliary device for encouraging the computation to take
+    place on that auxiliary device, it is still possible for the inner body of
+    the function to move the tensors to any device.
+
+    **Special behavior for evaluation methods of Problem objects.**
+    In addition to simple functions, these specific methods of a `Problem`
+    class can be decorated via `@on_aux_device`:
+
+    - `_evaluate`
+    - `_evaluate_batch`
+
+    If the decorated function receives a Problem object as its first argument,
+    and a Solution or a SolutionBatch as its second argument, the decorator
+    will assume that the decorated function is one of the methods listed above,
+    and will do nothing other than simply passing the arguments to the original
+    version of the decorated function. Instead, it is the `Problem` object
+    which moves the solutions to its own auxiliary device by looking at its own
+    `aux_device` property.
+
+    Decorating arbitrary methods (other than these solution or batch evaluation
+    methods of the `Problem` class) is not supported.
+
+    **Example usage 1.**
    ```python
-    import torch
+    from evotorch.decorators import on_aux_device

-    def f(x: torch.Tensor) -> torch.Tensor:
-        return torch.sum(x, dim=-1)
+    @on_aux_device
+    def my_function(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+        # Thanks to the decorator, x and y will be on the cuda device
+        # if cuda is available.
+        result = x + y
+
+        return result  # the result will be moved back to the most encountered
+        # device among the original x and y tensors.
    ```
-    In its not-yet-decorated form, the function `f` would be given `x` on the
-    main device of the associated problem object.
-    However, if one decorates
-    `f` as follows:
+    **Example usage 2.**
    ```python
-    from evotorch.decorators import on_aux_device
+    import torch
+    from evotorch.decorators import on_aux_device
+    from evotorch import Problem, SolutionBatch

-    @on_aux_device
-    def f(x: torch.Tensor) -> torch.Tensor:
-        return torch.sum(x, dim=-1)
+    class SphereProblem(Problem):
+        def __init__(self):
+            super().__init__(
+                objective_sense="min",
+                solution_length=20,
+                initial_bounds=(-1.0, 1.0),
+                dtype=torch.float32,
+                device="cpu",  # the populations are to be stored on the cpu
+            )
+
+        @on_aux_device
+        def _evaluate_batch(self, batch: SolutionBatch):
+            # Upon seeing that this method is decorated by `@on_aux_device`,
+            # the `Problem` object will move the `batch` to the auxiliary
+            # device declared by its property named `aux_device`.
+            evals = torch.sum(batch.values**2.0, dim=-1)
+            batch.set_evals(evals)
    ```
+    """
-    then the Problem object will first move `x` onto its auxiliary device,
-    then will call `f`.
+    num_args = len(args)

-    This decorator is useful on multi-GPU settings. For details, please see
-    the following example:
+    if num_args == 0:
+        func_to_wrap = None
+    elif num_args == 1:
+        [func_to_wrap] = args
+    else:
+        raise TypeError("`on_aux_device` received an unexpected number of positional arguments")

-    ```python
-    from evotorch import Problem
-    from evotorch.decorators import on_aux_device
+    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

+    def decorator(fn: Callable) -> Callable:
+        fn = on_device(target_device, move_only_from_cpu=True)(fn)
+        if hasattr(fn, "__evotorch_on_device__"):
+            del fn.__evotorch_on_device__
+        fn.__evotorch_on_aux_device__ = True
+        return fn
-
-    @on_aux_device
-    def f(x: torch.Tensor) -> torch.Tensor: ...
+    result = decorator
+    if func_to_wrap is not None:
+        result = result(func_to_wrap)
+    return result

-    problem = Problem(
-        "min",
-        f,
-        num_actors=4,
-        num_gpus_per_actor=1,
-        device="cpu",
-    )
-    ```

+def on_cuda(*args) -> Callable:
+    """
+    Transform a function so that it will compute on the specified cuda device.

-    In the example code above, we assume that there are 4 GPUs available.
-    The main device of the problem is "cpu", which means the populations
-    will be generated on the cpu. When evaluating a population, the population
-    will be split into 4 subbatches (because we have 4 actors), and each
-    subbatch will be sent to an actor. Thanks to the decorator `@on_aux_device`,
-    the [Problem][evotorch.core.Problem] instance on each actor will first move
-    its [SolutionBatch][evotorch.core.SolutionBatch] to its auxiliary device
-    visible to the actor, and then the fitness function will perform its
-    fitness evaluations on that device. In summary, the actors will use their
-    associated auxiliary devices (most commonly "cuda") to evaluate the
-    fitnesses of the solutions in parallel.
+    A function decorated via `@on_cuda` will first move its positional
+    arguments to the specified cuda device, then perform the operations listed
+    within the body of the original function definition, and then move
+    the result back to the most encountered device within its arguments.

-    This decorator can also be used to decorate the method `_evaluate` or
-    `_evaluate_batch` belonging to a custom subclass of
-    [Problem][evotorch.core.Problem].
Please see the example below: + For a function to be decorated via `@on_cuda`, the assumption is that + its positional arguments and its output are of these types: - ```python - from evotorch import Problem + - Pytorch tensor + - `ReadOnlyTensor` + - `TensorFrame` + - `ObjectArray` + - shallow (non-nested) sequence or dictionary-like container consisting of + objects that are instances of the types listed above + Additionally, a `device` attribute is added onto the decorated counterpart + of the function. This `device` attribute is not meant for changing, but for + informing an observer regarding where the computation will take place. - class MyCustomProblem(Problem): - def __init__(self): - super().__init__( - ..., - device="cpu", # populations will be created on the cpu - ..., - ) + **Note.** + Although an `on_cuda`-decorated function moves its arguments to the + specified cuda device for encouraging the computation to take place on + cuda, it is still possible for the inner body of the function to move + the tensors to any device. - @on_aux_device("cuda") # evaluations will be on the auxiliary device - def _evaluate_batch(self, solutions: SolutionBatch): - fitnesses = ... - solutions.set_evals(fitnesses) - ``` - """ - return _simple_decorator("__evotorch_on_aux_device__", args, decorator_name="on_aux_device") + **Special behavior for evaluation methods of Problem objects.** + In addition to simple functions, these specific methods of a `Problem` + class can be decorated via `@on_cuda`: + - `_evaluate` + - `_evaluate_batch` -def vectorized(*args) -> Callable: - """ - Decorates a fitness function so that the problem object (which can be an instance - of [evotorch.Problem][evotorch.core.Problem]) will send the fitness function a 2D - tensor containing all the solutions, instead of a 1D tensor containing a single - solution. + If the decorated function receives a Problem object as its first argument, + and a Solution or a SolutionBatch as its second argument, the decorator + will assume that the decorated function is one of the methods listed above, + and will do nothing other than simply passing the arguments to the original + version of the decorated function. Instead, it is the `Problem` object + which moves the solutions to the correct cuda device by looking at the + `device` attribute created by the `@on_cuda` decorator. - What this decorator does is that it adds the decorated fitness function a new - attribute named `__evotorch_vectorized__`, the value of this new attribute being - True. Upon seeing this new attribute, the problem object will send this function - multiple solutions so that vectorized operations on multiple solutions can be - performed by this fitness function. + Decorating arbitrary methods (other than these solution or batch evaluation + methods of the `Problem` class) is not supported. - Let us imagine that we have the following fitness function which works on a - single solution `x`, and returns a single fitness value: + **Example usage 1.** - ```python - import torch + Assuming that the cuda device is available: + ```python + from evotorch.decorators import on_device - def f(x: torch.Tensor) -> torch.Tensor: - return torch.sum(x**2) - ``` - ...and let us now define the optimization problem associated with this fitness - function: + @on_cuda # Note: could also be, e.g., @on_cuda(0) for 'cuda:0' + def my_function(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + # Thanks to the decorator, x and y should be on the 'cuda' device. 
+ result = x + y - ```python - p1 = Problem("min", f, initial_bounds=(-10.0, 10.0), solution_length=5) + return result # the result will be moved back to the most encountered + # device among the original x and y tensors. ``` - While the fitness function `f` and the definition `p1` form a valid problem - description, it does not use PyTorch to its full potential in terms of performance. - If we were to request the evaluation results on a population of solutions via - `p1.evaluate(population)`, `p1` would use a classic `for` loop to evaluate every - single solution within `population` one by one. - We could greatly increase our performance by: - (i) re-defining our fitness function in a vectorized manner, i.e. in such a way - that it will operate on many solutions and compute all of their fitnesses at once; - (ii) label our fitness function via `@vectorized`, so that the problem object - will be aware that this new fitness function expects `n` solutions and returns - `n` fitnesses. The re-designed and labeled fitness function looks like this: + **Example usage 2.** + + Assuming that the cuda device is available: ```python - from evotorch.decorators import vectorized + import torch + from evotorch.decorators import on_device + from evotorch import Problem, SolutionBatch - @vectorized - def f2(x: torch.Tensor) -> torch.Tensor: - return torch.sum(x**2, dim=-1) + class SphereProblem(Problem): + def __init__(self): + super().__init__( + objective_sense="min", + solution_length=20, + initial_bounds=(-1.0, 1.0), + dtype=torch.float32, + device="cpu", # the populations are to be stored on the cpu + ) + + @on_cuda # Note: could also be, e.g., @on_cuda(0) for 'cuda:0' + def _evaluate_batch(self, batch: SolutionBatch): + # Upon seeing that this method is decorated by `@on_cuda`, + # the `Problem` object will move the `batch` to the cuda device + # while calling this method. + # Therefore, the computation below is expected to happen on cuda. + evals = torch.sum(batch.values**2.0, dim=-1) + batch.set_evals(evals) ``` + """ + num_args = len(args) - The problem description for `f2` is: + if num_args == 0: + func_to_wrap = None + target_device = torch.device("cuda") + elif num_args == 1: + [first_arg] = args + if isinstance(first_arg, Callable): + func_to_wrap = first_arg + target_device = torch.device("cuda") + else: + func_to_wrap = None + target_device = torch.device("cuda", int(first_arg)) + else: + raise TypeError("`on_cuda` received an unexpected number of positional arguments") - ```python - p2 = Problem("min", f2, initial_bounds=(-10.0, 10.0), solution_length=5) - ``` + decorator = on_device(target_device) - In this last example, `p2` will realize that `f2` is decorated via `@vectorized`, - and will send it `n` solutions, and will receive and process `n` fitnesses. - """ - return _simple_decorator("__evotorch_vectorized__", args, decorator_name="vectorized") + if func_to_wrap is None: + return decorator + else: + return decorator(func_to_wrap) def expects_ndim( # noqa: C901 @@ -694,6 +950,12 @@ def f(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: ... ) def expects_ndim_decorator(fn: Callable): + if hasattr(fn, "__evotorch_distribute__"): + raise ValueError( + "Cannot apply `@expects_ndim` or `@rowwise` on a function" + " that was previously subjected to `@distribute`" + ) + def expects_ndim_decorated(*args): # The inner class below is responsible for accumulating the dtype and device info of the tensors # encountered across the arguments received by the decorated function. 
@@ -702,8 +964,8 @@ def expects_ndim_decorated(*args):
            class tensor_info:
                # At first, we initialize the set of encountered dtype and device info as None.
                # They will be lazily filled if we ever need such information.
-                encountered_dtypes: Optional[set] = None
-                encountered_devices: Optional[set] = None
+                encountered_dtypes: set | None = None
+                encountered_devices: set | None = None

                @classmethod
                def update(cls):
@@ -711,7 +973,7 @@ def update(cls):
                    if (cls.encountered_dtypes is None) or (cls.encountered_devices is None):
                        cls.encountered_dtypes = set()
                        cls.encountered_devices = set()
-                        for expected_arg_ndim, arg in zip(expected_ndims, args):
+                        for expected_arg_ndims, arg in zip(expected_ndims, args):
                            if (expected_arg_ndims is not None) and isinstance(arg, torch.Tensor):
                                # If the argument has a declared expected ndim, and also if it is a PyTorch tensor,
                                # then we add its dtype and device information to the sets `encountered_dtypes` and
@@ -764,6 +1026,7 @@ def convert_scalar_to_tensor(cls, scalar: Number) -> torch.Tensor:
                    if isinstance(scalar, (bool, np.bool_)):
                        # If the given scalar argument is a boolean, we declare the dtype of its tensor counterpart as
                        # torch.bool.
+                        scalar = bool(scalar)
                        dtype = torch.bool
                    else:
                        # If the given scalar argument is not a boolean, we declare the dtype of its tensor counterpart
@@ -963,3 +1226,260 @@ def decorator(fn: Callable) -> Callable:  # <- inner decorator
        return decorated

    return decorator
+
+
+def distribute(
+    *arguments,
+    num_actors: str | int | None = None,
+    chunk_size: int | None = None,
+    num_gpus_per_actor: int | float | str | None = None,
+    devices: Sequence[str | torch.device] | None = None,
+) -> Callable:
+    """
+    Transform a function such that its computations are distributed.
+
+    Let us assume that we have the following function which expects two tensors
+    as arguments, and returns a new tensor, with the constraint that the
+    leftmost dimension sizes of all these tensors (of its input arguments and
+    of its returned output) are the same:
+
+    ```python
+    def update_and_concat(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+        assert a.ndim == 2
+        assert b.ndim == 2
+
+        # ====
+        # Let us imagine some very heavy computation here which modifies a and b
+        # such that their values are updated, but their sizes remain the same.
+        ...
+        # ====
+
+        return torch.hstack([a, b])
+    ```
+
+    Let us now imagine that, because of the heavy computation part, we want to
+    run this function in a distributed manner, across two cuda devices.
+    To achieve this, we can decorate this function as follows:
+
+    ```python
+    @distribute(devices=["cuda:0", "cuda:1"])
+    def update_and_concat(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: ...
+    ```
+
+    The decorated version of this function, upon being called for the first
+    time, will do the following:
+
+    - create two remote actors, one for `cuda:0`, one for `cuda:1`;
+    - split the input arguments `a` and `b` into 2 chunks (because of 2 actors)
+      along their leftmost dimensions;
+    - send the first chunk of arguments to the actor dedicated to `cuda:0`
+      and the second chunk of arguments to the actor dedicated to `cuda:1`;
+    - initiate parallelized computation across the actors (each actor
+      moving its received chunk of arguments to its assigned device);
+    - collect the resulting chunks produced by the actors and combine the
+      chunks.
+
+    The finally collected and combined result is the output of the decorated
+    function.
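+
+    In other words, ignoring the remote actors and the device transfers, the
+    returned value is conceptually equivalent to splitting the arguments into
+    two roughly equal chunks along their leftmost dimensions, applying the
+    function to each pair of chunks, and combining the partial results.
+    A rough, purely illustrative sketch of that equivalence (where
+    `update_and_concat` refers to the original, undecorated function):
+
+    ```python
+    half = (len(a) + 1) // 2
+    result = torch.cat(
+        [
+            update_and_concat(a[:half], b[:half]),
+            update_and_concat(a[half:], b[half:]),
+        ]
+    )
+    ```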
+ + The following types are supported for splitting into chunks of arguments + and for combining to form the final output: + - `torch.Tensor` + - `evotorch.tools.ReadOnlyTensor` + - `evotorch.tools.ObjectArray` + - `evotorch.tools.TensorFrame` + - a (non-nested) dictionary-like object (i.e. Mapping) in which the values + are `Tensor`, `ReadOnlyTensor`, `ObjectArray` or `TensorFrame` + - a (non-nested) sequence in which the values are `Tensor`, + `ReadOnlyTensor`, `ObjectArray`, `TensorFrame` + + **Combining with other decorators.** + A function that was previously decorated via `@expects_ndim` or `@rowwise` + or `@torch.vmap` can be decorated via `@distribute`. However, the opposite + is NOT true (e.g. a function that was previously decorated via + `@distribute` cannot be then decorated via `@expects_ndim`). + + **Inline function transformation.** + The `distribute` function can also be used in this alternative form if + decoration is not desired: + + ```python + distributed_update_and_concat = distribute( + update_and_concat, devices=["cuda:0", "cuda:1"] + ) + ``` + + **Alternative ways of declaring number of actors.** + Like in the example above, if we have two cuda devices and we want to + explicitly target them, we decorate our function like this: + + ```python + @distribute(devices=["cuda:0", "cuda:1"]) + def update_and_concat(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: ... + ``` + + The devices do not have to be different. For example, for having 4 + actors which share the available CPUs, one could do: + + ```python + @distribute(devices=["cpu", "cpu", "cpu", "cpu"]) + def update_and_concat(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: ... + ``` + + For distributing the computation across 4 GPUs: + + ```python + @distribute(num_actors=4, num_gpus_per_actor=1) + def update_and_concat(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: ... + ``` + + For having two actors, each using the half of a single GPU: + + ```python + @distribute(num_actors=2, num_gpus_per_actor=0.5) + def update_and_concat(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: ... + ``` + + For having an actor for each available GPU: + + ```python + @distribute(num_actors="num_gpus", num_gpus_per_actor=1) + def update_and_concat(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: ... + ``` + + For having `n` actors, where `n` is the minimum between the number of + CPUs and the number of GPUs: + + ```python + @distribute(num_actors="num_devices", num_gpus_per_actor=1) + def update_and_concat(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: ... + ``` + + For having a CPU-only actor for each available CPU: + + ```python + @distribute(num_actors="num_cpus") # or: num_actors="max" + def update_and_concat(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + # Note: setting `num_actors` without setting `num_gpus_per_actor` + # will cause the actors to be CPU-only (they will not see the GPUs). + ... + ``` + + **Specifying which argument to split into chunks.** + Sometimes one has to distribute a function in which some of the arguments + are not tensors, but are flags to configure the behavior of the function. + While sending the arguments to remote actors, such flags are usually + expected to be duplicated, instead of being sent in chunks. 
+ As an example, please take a look at the function below: + + ```python + def update_and_combine( + a: torch.Tensor, b: torch.Tensor, combine_how: str + ) -> torch.Tensor: + assert a.ndim == 2 + assert b.ndim == 2 + + # ==== + # Let us imagine some very heavy computation here which modifies a and b + # such that their values are updated, but their sizes remain the same. + ... + # ==== + + if combine_how == "add": + return a + b + elif combine_how == "hstack": + return torch.hstack([a, b]) + else: + raise ValueError("Unsupported combine_how value") + ``` + + In the case of this example, we inform `@distribute` that the first two + positional arguments are to be split into chunks, and the third positional + argument is to be duplicated: + + ```python + @distribute(True, True, False, devices=...) + def update_and_combine( + a: torch.Tensor, b: torch.Tensor, combine_how: str + ) -> torch.Tensor: ... + ``` + + Notice how `@distribute` is given booleans as positional arguments. + The first boolean (True) tells that the first argument of + `update_and_combine`, `a`, is to be split into chunks. + The second boolean (True) tells that the second argument of + `update_and_combine`, `b`, is to be split into chunks. + The third boolean (False) tells that the third argument of + `update_and_combine`, `combine_how`, is to be duplicated + (i.e. to be sent as it is, instead of being split into chunks). + + The non-decorator alternative looks like this: + + ```python + dist_update_and_combine = distribute( + update_and_combine, (True, True, False), devices=... + ) + ``` + + **Specifying a chunk size.** + The `@distribute` decorator has an optional integer argument named + `chunk_size`. If this is given, then the original arguments will be + split into chunks with at most this given size. + + Example: + + ```python + @distribute(devices=["cpu", "cpu"], chunk_size=10) + def function_to_be_distributed(x: torch.Tensor) -> torch.Tensor: ... + + + large_data = ... # some large tensor here + + # The call below will split `large_data` into chunks. + # Each chunk is a subtensor of `large_data`, and the leftmost dimension + # size of each chunk is at most 10. + # Parallelized processing of these chunks will be scheduled for the two + # available remote actors. + result = function_to_be_distributed(large_data) + ``` + + **Distributing across multiple computers.** + This `@distribute` decorator uses the `ray` library for parallelizing + the wrapped function. Thanks to this, if the program is placed upon + a `ray`-powered cluster consisting of multiple computers (and also + if the main program has addressed and initialized the `ray` cluster using + `ray.init` before executing this decorator), the computation of the + wrapped function will be distributed across all the devices that are + visible to the cluster. + + **NOTE.** + If a distributed counterpart of a function cannot be created due to its + distribution configuration (e.g. if one sets `num_actors` as 1 or 0, or if + one sets `num_actors` as `"num_gpus"` when there is only 1 GPU available), + an error will be raised. 
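+
+    As a rough illustrative sketch of the cluster scenario mentioned above
+    (the function name `heavy_fn` is hypothetical, and `address="auto"` is just
+    one way of pointing `ray.init` at an already-running cluster):
+
+    ```python
+    import ray
+
+    ray.init(address="auto")  # attach to the existing ray cluster
+
+
+    @distribute(num_actors="num_gpus", num_gpus_per_actor=1)
+    def heavy_fn(x: torch.Tensor) -> torch.Tensor: ...
+    ```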
+ """ + + from ._distribute import DecoratorForDistributingFunctions + + if (len(arguments) == 1) and isinstance(arguments[0], Callable): + function_to_decorate = arguments[0] + split_arguments = None + elif (len(arguments) == 2) and isinstance(arguments[0], Callable) and isinstance(arguments[1], tuple): + function_to_decorate = arguments[0] + split_arguments = arguments[1] + else: + function_to_decorate = None + split_arguments = arguments + + result = DecoratorForDistributingFunctions( + split_arguments=split_arguments, + num_actors=num_actors, + chunk_size=chunk_size, + num_gpus_per_actor=num_gpus_per_actor, + devices=devices, + ) + + if function_to_decorate is not None: + result = result(function_to_decorate) + + return result diff --git a/src/evotorch/tools/_shallow_containers.py b/src/evotorch/tools/_shallow_containers.py new file mode 100644 index 0000000..abf9c93 --- /dev/null +++ b/src/evotorch/tools/_shallow_containers.py @@ -0,0 +1,281 @@ +# Copyright 2025 NNAISENSE SA +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +Internal utility functions for operating on shallow (non-nested) containers. +""" + + +from collections.abc import Mapping, Sequence +from typing import Any + +import numpy as np +import torch + +from . import ObjectArray, TensorFrame + +TensorLike = torch.Tensor | TensorFrame | ObjectArray + + +def _move_tensorframe(x: TensorFrame, *, device: str | torch.device, move_only_from_cpu: bool = False) -> TensorFrame: + """ + Move a TensorFrame to the specified device. + + Args: + x: The TensorFrame to be moved to the specified device. + device: The target device. + move_only_from_cpu: If True, only columns that are on cpu will be moved. + Returns: + The new TensorFrame with its columns moved. + """ + + if move_only_from_cpu: + # This is the case where we only move the columns that are on the cpu. + device_to_be_enforced = x.has_enforced_device + original_is_read_only = x.is_read_only + + new_devices = set() # the set which is to store the devices after the moving operation + coldata = {} # the dictionary which is to store the new columns + for colname in x.columns: + # Move each column on the cpu to the target device + coldata[colname] = _move(x[colname], device=device, move_only_from_cpu=move_only_from_cpu) + # Add the new column's device into the set of new devices + new_devices.add(coldata[colname].device) + + device_kwarg = {} + if device_to_be_enforced and (len(new_devices) == 1): + # If the original TensorFrame has an enforced device, and the number of new devices is 1, + # then we declare this new device as the enforced device of the new TensorFrame. + device_kwarg["device"] = device + + # Prepare the new TensorFrame and return it. + return TensorFrame(coldata, read_only=original_is_read_only, **device_kwarg) + else: + # This is the case where we move all columns of the TensorFrame. 
+ device_to_be_enforced = x.has_enforced_device + x = x.to(device=device) + if not device_to_be_enforced: + x = x.without_enforced_device() + return x + + +def _move(x: TensorLike, *, device: str | torch.device, move_only_from_cpu: bool = False) -> TensorLike: + """ + Move a tensor or a TensorFrame or an ObjectArray to the target device. + + Note that an ObjectArray cannot really be moved to any device other + than the cpu, so, any instance of ObjectArray is simply returned + as it is. + + Args: + x: A tensor, or a TensorFrame, or an ObjectArray subject to moving. + device: The target device. + move_only_from_cpu: If True, the tensor, or the columns of the given + TensorFrame will be moved only if they currently reside on the cpu. + If False, all the tensor data will be moved to the target device. + Returns: + The counterpart of the original tensor/array after the moving. + """ + if isinstance(x, TensorFrame): + return _move_tensorframe(x, device=device, move_only_from_cpu=move_only_from_cpu) + if isinstance(x, ObjectArray): + return x + + if move_only_from_cpu and (x.device != torch.device("cpu")): + # If we are to move the tensor only when it is on the cpu, and also we observe that it is NOT on the cpu, + # then we simply return the tensor as it is, without moving it. + return x + + return x.to(device=device) + + +def move_shallow_container_to_device( + x: TensorLike | Sequence[TensorLike] | Mapping[Any, TensorLike], + *, + device: str | torch.device, + move_only_from_cpu: bool = False, +) -> TensorLike | Sequence[TensorLike] | Mapping[Any, TensorLike]: + """ + Move a tensor or a shallow container of tensors to the given device. + + Args: + x: A tensor or a TensorFrame or an ObjectArray, or a dictionary-like + object or a sequence of tensors/arrays. Any encountered tensors + and `TensorFrame`s within `x` will be moved to the given `device`. + Any encountered ObjectArray within `x` will be put back as it is + (without raising any error), since an ObjectArray can reside only + on the cpu. + device: The target device that can be given as an instance of `str` + or `torch.device`. + move_only_from_cpu: If this is given as True, then the tensors/arrays + will be moved to the target device ONLY if they currently reside + on the cpu. + Returns: + The counterpart of `x` that resides on the given device. 
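+
+    As a small illustrative sketch (assuming a cuda device is available), a
+    shallow dictionary of cpu tensors could be moved like this:
+
+    ```python
+    moved = move_shallow_container_to_device(
+        {"a": torch.zeros(3), "b": torch.ones(3)},
+        device="cuda",
+        move_only_from_cpu=True,
+    )
+    # `moved` is a new dictionary whose values now reside on the cuda device.
+    ```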
+ """ + + def move(obj: TensorLike) -> TensorLike: + return _move(obj, device=device, move_only_from_cpu=move_only_from_cpu) + + if isinstance(x, (torch.Tensor, TensorFrame, ObjectArray)): + result = move(x) + elif isinstance(x, (str, np.str_, bytes, bytearray)): + raise TypeError(f"Cannot move an object of type {type(x)} to the device {device}") + elif isinstance(x, Mapping): + result = {} + for k, v in x.items(): + if isinstance(v, (torch.Tensor, TensorFrame, ObjectArray)): + result[k] = move(v) + else: + raise TypeError( + "While trying to move the tensors within a dictionary-like object," + f" encountered an element of this unexpected type: {type(v)}" + ) + elif isinstance(x, Sequence): + result = [] + for item in x: + if isinstance(item, (torch.Tensor, TensorFrame, ObjectArray)): + result.append(move(item)) + else: + raise TypeError( + "While trying to move the tensors within a sequence," + f" encountered an element of this unexpected type: {type(item)}" + ) + else: + raise TypeError(f"Cannot move an object of type {type(x)} to the device {device}") + + return result + + +def _update_dict_additively(left: dict, right: dict): + """ + Additively update the left dict using the items of the right dict. + + In more details, if a key within the right dictionary does not exist within + the left dictionary, that item is put into the left dictionary. On the + other hand, if a key within the right dictionary does exist within the left + dictionary, the value within the right dictionary is added (using the `+=` + operator) onto the value of the left dictionary. + + This function returns nothing, and the update is done in-place. + + Args: + left: The dictionary to be updated. + right: The dictionary whose values will be used to update the left + dictionary. + """ + for k, v in right.items(): + if k in left: + left[k] += v + else: + left[k] = v + + +def count_devices_within_shallow_container( + x: TensorLike | Mapping[Any, TensorLike] | Sequence[TensorLike], + *, + _already_within_container: bool = False, +) -> dict[torch.device, float]: + """ + Given a shallow (non-nested) container of tensors, count devices in it. + + The returned object is a dictionary in which each key is a `torch.device`, + and each value represents how many times a device is encountered. + + Args: + x: A shallow container in which the devices are to be counted. + _already_within_container: For internal usage. + Returns: + The dictionary which stores the device counts. 
+ """ + + devices = {} + + if isinstance(x, (torch.Tensor, ObjectArray)): + _update_dict_additively(devices, {x.device: 1.0}) + elif isinstance(x, TensorFrame): + for col in x.columns: + _update_dict_additively(devices, {x[col].device: 1.0}) + elif isinstance(x, (Sequence, Mapping)): + if _already_within_container: + raise TypeError("Nested containers are not supported") + if isinstance(x, (str, np.str_, bytes, bytearray)): + raise TypeError(f"Unsupported type: {type(x)}") + if isinstance(x, tuple) and hasattr(x, "_fields"): + raise TypeError("Named tuples are not supported") + if isinstance(x, Mapping): + values_of_x = x.values() + else: + values_of_x = x + for v in values_of_x: + _update_dict_additively( + devices, + count_devices_within_shallow_container( + v, + _already_within_container=True, + ), + ) + else: + raise TypeError(f"Encountered an object of this unexpected type: {type(x)}") + + return devices + + +def most_favored_device_among_arguments( + args: Sequence[TensorLike | Mapping[Any, TensorLike] | Sequence[TensorLike]], + *, + slightly_favor_cpu: bool = True, +) -> torch.device: + """ + Given arguments in a tuple, find the most favored PyTorch device. + + It is expected that the arguments consist of PyTorch tensors or + `TensorFrame`s or `ObjectArray`s, or shallow (non-nested) sequences or + dictionary-like containers containing tensors and/or `TensorFrame`s + and/or `ObjectArray`s. + + Args: + x: A tuple of arguments. + slightly_favor_cpu: If this is given as True, and if there are + multiple devices that are encountered equally, and if cpu + is one of those devices, then cpu will be picked. + Returns: + The most favored torch.device. + """ + + weights = {} + + if not isinstance(args, Sequence): + raise TypeError(f"`args` was received as an instance of this unexpected type: {type(args)}") + if isinstance(args, tuple) and hasattr(args, "_fields"): + raise TypeError("Providing `args` as a named tuple is not supported") + + for arg in args: + _update_dict_additively( + weights, + count_devices_within_shallow_container(arg), + ) + + if slightly_favor_cpu: + _update_dict_additively(weights, {"cpu": 0.1}) + + device_with_max_weight = torch.device("cpu") + max_weight = 0.0 + for d, w in weights.items(): + if w > max_weight: + device_with_max_weight = d + max_weight = w + + return device_with_max_weight diff --git a/src/evotorch/tools/tensorframe.py b/src/evotorch/tools/tensorframe.py index 8db7370..a41582e 100644 --- a/src/evotorch/tools/tensorframe.py +++ b/src/evotorch/tools/tensorframe.py @@ -417,6 +417,14 @@ def __getattr__(self, column_name: str) -> torch.Tensor: else: raise AttributeError(column_name) + @property + def has_enforced_device(self) -> bool: + """ + True if this TensorFrame was initialized with a `device`. + False otherwise. + """ + return self.__device is not None + def without_enforced_device(self) -> "TensorFrame": """ Make a shallow copy of this TensorFrame without any enforced device. diff --git a/tests/test_decorators.py b/tests/test_decorators.py index c98d097..a10ff81 100644 --- a/tests/test_decorators.py +++ b/tests/test_decorators.py @@ -12,12 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any - +import numpy as np import pytest import torch -from evotorch.decorators import on_aux_device, on_cuda, on_device, pass_info, vectorized +from evotorch.decorators import distribute, on_aux_device, on_cuda, on_device, pass_info, vectorized +from evotorch.tools import ObjectArray, as_tensor @pytest.mark.parametrize( @@ -44,7 +44,7 @@ def g(): assert getattr(g, attribute) is True -@pytest.mark.parametrize("decorator", [pass_info, on_aux_device, on_device, on_cuda, vectorized]) +@pytest.mark.parametrize("decorator", [pass_info, vectorized]) def test_decorating_fails_with_too_many_args(decorator): def g(x): pass @@ -53,17 +53,26 @@ def g(x): decorator("foo", 2)(g) -@pytest.mark.parametrize("decorator", [pass_info, on_aux_device, on_device("cpu"), on_cuda, vectorized]) +@pytest.mark.parametrize("decorator", [pass_info, on_aux_device, on_device("cpu"), vectorized]) def test_decorator_does_not_modify_function(decorator): - def g(): - return 42 + test_matrix = torch.LongTensor( + [ + [1, 2], + [3, 4], + ] + ) + + def g(x: torch.Tensor) -> torch.Tensor: + return 2 * x g = decorator(g) - assert g() == 42 + result = g(test_matrix).to(device="cpu") + assert bool(torch.all(result == (2 * test_matrix))) -@pytest.mark.parametrize("decorator", [pass_info, on_aux_device, on_device("cpu"), on_cuda, vectorized]) + +@pytest.mark.parametrize("decorator", [pass_info, vectorized]) def test_decorator_preserves_signature(decorator): def g(x: float, y: int) -> float: return x + y @@ -73,7 +82,7 @@ def g(x: float, y: int) -> float: assert g.__annotations__ == {"x": float, "y": int, "return": float} -@pytest.mark.parametrize("decorator", [pass_info, on_aux_device, on_device("cpu"), on_cuda, vectorized]) +@pytest.mark.parametrize("decorator", [pass_info, vectorized]) def test_decorator_preserves_docstring(decorator): def g(): """Docstring""" @@ -84,7 +93,7 @@ def g(): assert g.__doc__ == "Docstring" -@pytest.mark.parametrize("decorator", [pass_info, on_aux_device, on_device("cpu"), on_cuda, vectorized]) +@pytest.mark.parametrize("decorator", [pass_info, vectorized]) def test_decorator_preserves_name(decorator): def g(): pass @@ -124,3 +133,126 @@ def g(): assert hasattr(g, "device") assert g.device == torch.device(expected) + + +def test_on_device_moves_input_tensors(): + + @on_device("meta") + def f(x: torch.Tensor) -> torch.Tensor: + if x.device == torch.device("meta"): + x = torch.ones_like(x, device="cpu") + return torch.sum(x) + + input_tensor = torch.arange(10, dtype=torch.int64, device="cpu") + result = f(input_tensor) + + assert int(torch.sum(result)) == len(input_tensor) + + +@pytest.mark.parametrize("decoration_form", [True, False]) +def test_on_device_chunking(decoration_form: bool): + + input_tensor = torch.LongTensor( + [ + [1, 2, 3], + [4, 5, 6], + [10, 20, 30], + [40, 50, 60], + [-1, -2, -3], + ] + ) + + def f(x: torch.Tensor) -> torch.Tensor: + return torch.sum(x, dim=-1) + + chunk_size = 2 + if decoration_form: + + @on_device("cpu", chunk_size=chunk_size) + def chunking_f(x: torch.Tensor) -> torch.Tensor: + return f(x) + + else: + chunking_f = on_device(f, device="cpu", chunk_size=chunk_size) + + recombined_result = chunking_f(input_tensor) + expected_result = f(input_tensor) + + assert recombined_result.shape == expected_result.shape + assert bool(torch.all(recombined_result == expected_result)) + + +@pytest.mark.parametrize( + "decoration_form, distribute_config, chunk_size", + [ + (True, {"devices": ["cpu", "cpu"]}, None), + (True, {"num_actors": 2}, None), + (False, 
{"devices": ["cpu", "cpu"]}, None), + (False, {"num_actors": 2}, None), + (True, {"devices": ["cpu", "cpu"]}, 2), + (True, {"num_actors": 2}, 2), + (False, {"devices": ["cpu", "cpu"]}, 2), + (False, {"num_actors": 2}, 2), + ], +) +def test_distribute(decoration_form: bool, distribute_config: dict, chunk_size: int | None): + + input_tensor = torch.LongTensor( + [ + [1, 2, 3], + [4, 5, 6], + [10, 20, 30], + [40, 50, 60], + [-1, -2, -3], + [-4, -5, -6], + [-30, -60, -90], + ] + ) + + def f(x: torch.Tensor) -> torch.Tensor: + return torch.sum(x, dim=-1) + + if decoration_form: + + @distribute(**distribute_config) + def distributed_f(x: torch.Tensor) -> torch.Tensor: + return f(x) + + else: + distributed_f = distribute(f, **distribute_config) + + recombined_result = distributed_f(input_tensor) + expected_result = f(input_tensor) + + assert recombined_result.shape == expected_result.shape + assert bool(torch.all(recombined_result == expected_result)) + + +def test_distribute_with_objectarray(): + + input_array = as_tensor( + [ + [1, 2, 3], + [5, 6], + [10, 20, 30, 40], + [100], + ], + dtype=object, + ) + + def f(x: ObjectArray) -> ObjectArray: + n = len(x) + y = ObjectArray(n) + for i in range(n): + y[i] = sum(x[i]) + return y + + distributed_f = distribute(f, devices=["cpu", "cpu"]) + + recombined_result = distributed_f(input_array) + expected_result = f(input_array) + + assert isinstance(recombined_result, ObjectArray) + assert isinstance(expected_result, ObjectArray) + assert len(recombined_result) == len(expected_result) + assert np.all(expected_result.numpy() == recombined_result.numpy())