Some vmap- and HPO-related features and bug fixes #226

Merged: 6 commits, Mar 24, 2025

22 changes: 10 additions & 12 deletions src/evox/operators/selection/non_dominate.py
@@ -1,7 +1,6 @@
from typing import Tuple

import torch

from evox.core import compile
from evox.utils import lexsort, register_vmap_op


@@ -61,7 +60,7 @@ def update_dc_and_rank(
return rank, dominate_count


_compiled_update_dc_and_rank = torch.compile(update_dc_and_rank, fullgraph=True)
_compiled_update_dc_and_rank = compile(update_dc_and_rank, fullgraph=True)


def _igr_fake(
@@ -78,28 +77,28 @@ def _igr_fake_vmap(
dominate_count: torch.Tensor,
rank: torch.Tensor,
pareto_front: torch.Tensor,
) -> Tuple[torch.Tensor, int]:
return rank.new_empty(dominate_count.size()), 0
) -> torch.Tensor:
return rank.new_empty(dominate_count.size())


def _vmap_iterative_get_ranks(
dominate_relation_matrix: torch.Tensor,
dominate_count: torch.Tensor,
rank: torch.Tensor,
pareto_front: torch.Tensor,
) -> Tuple[torch.Tensor, int]:
) -> torch.Tensor:
current_rank = 0
while pareto_front.any():
rank, dominate_count = _compiled_update_dc_and_rank(
rank, dominate_count = (_compiled_update_dc_and_rank if torch.compiler.is_compiling() else update_dc_and_rank)(
dominate_relation_matrix, dominate_count, pareto_front, rank, current_rank
)
current_rank += 1
new_pareto_front = dominate_count == 0
pareto_front = torch.where(pareto_front.any(dim=1, keepdim=True), new_pareto_front, pareto_front)
return rank, 0
pareto_front = torch.where(pareto_front.any(dim=-1, keepdim=True), new_pareto_front, pareto_front)
return rank


@register_vmap_op(fake_fn=_igr_fake, vmap_fn=_vmap_iterative_get_ranks, fake_vmap_fn=_igr_fake_vmap)
@register_vmap_op(fake_fn=_igr_fake, vmap_fn=_vmap_iterative_get_ranks, fake_vmap_fn=_igr_fake_vmap, max_vmap_level=2)
def _iterative_get_ranks(
dominate_relation_matrix: torch.Tensor,
dominate_count: torch.Tensor,
@@ -108,7 +107,7 @@ def _iterative_get_ranks(
) -> torch.Tensor:
current_rank = 0
while pareto_front.any():
rank, dominate_count = _compiled_update_dc_and_rank(
rank, dominate_count = (_compiled_update_dc_and_rank if torch.compiler.is_compiling() else update_dc_and_rank)(
dominate_relation_matrix, dominate_count, pareto_front, rank, current_rank
)
current_rank += 1
@@ -166,7 +165,6 @@ def crowding_distance(costs: torch.Tensor, mask: torch.Tensor):
inverted_mask = inverted_mask.unsqueeze(1).expand(-1, costs.size(1)).to(costs.dtype)

rank = lexsort([costs, inverted_mask], dim=0)
# TODO: num_valid_elem preventing vmap
costs = torch.gather(costs, dim=0, index=rank)
distance_range = costs[num_valid_elem - 1] - costs[0]
distance = torch.empty(costs.size(), device=costs.device)
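
Note on the `non_dominate.py` changes above: the compiled `update_dc_and_rank` helper is now invoked only when the caller is itself being traced by `torch.compile` (checked via `torch.compiler.is_compiling()`), with a fallback to the plain eager function otherwise. Below is a minimal, self-contained sketch of that dispatch pattern; `rank_update` and `step` are toy stand-ins, not the evox implementation.

```python
import torch


def rank_update(rank: torch.Tensor, front: torch.Tensor, current: int) -> torch.Tensor:
    # Toy stand-in for update_dc_and_rank: write `current` into the current front.
    return torch.where(front, torch.full_like(rank, current), rank)


_compiled_rank_update = torch.compile(rank_update, fullgraph=True)


def step(rank: torch.Tensor, front: torch.Tensor, current: int) -> torch.Tensor:
    # Use the pre-compiled kernel only when already inside a torch.compile trace.
    fn = _compiled_rank_update if torch.compiler.is_compiling() else rank_update
    return fn(rank, front, current)


rank = torch.zeros(5, dtype=torch.long)
front = torch.tensor([True, False, True, False, False])
print(step(rank, front, 1))  # eager path -> tensor([1, 0, 1, 0, 0])
```

Eager callers take the uncompiled path and avoid a compilation round-trip, while compiled callers keep the helper inside the outer traced graph, which appears to be the intent of the check.
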
227 changes: 192 additions & 35 deletions src/evox/problems/hpo_wrapper.py
@@ -1,18 +1,46 @@
import copy
import weakref
from abc import ABC
from typing import Any, Callable, Dict, Optional
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple

import torch
from torch import nn

from evox.core import Monitor, Mutable, Problem, Workflow, use_state, vmap
from evox.core import Monitor, Mutable, Problem, Workflow, compile, use_state, vmap


def _vmap_vmap_mean_fit_aggregation(info, in_dims, fit: torch.Tensor) -> Tuple[torch.Tensor, int]:
return torch.mean(fit.movedim(in_dims[0], 0), dim=0, keepdim=True), 0


@torch.library.custom_op("evox::_hpo_vmap_mean_fit_aggregation", mutates_args=())
def _vmap_mean_fit_aggregation(fit: torch.Tensor) -> torch.Tensor:
return fit.clone()


_vmap_mean_fit_aggregation.register_fake(lambda f: f.new_empty(f.size()))
_vmap_mean_fit_aggregation.register_vmap(_vmap_vmap_mean_fit_aggregation)


@torch.library.custom_op("evox::_hpo_mean_fit_aggregation", mutates_args=())
def _mean_fit_aggregation(fit: torch.Tensor) -> torch.Tensor:
return fit.clone()


_mean_fit_aggregation.register_fake(lambda f: f.new_empty(f.size()))
_mean_fit_aggregation.register_vmap(lambda info, in_dims, fit: (_vmap_mean_fit_aggregation(fit.movedim(in_dims[0], 0)), 0))


class HPOMonitor(Monitor, ABC):
"""The base class for hyper parameter optimization (HPO) monitors used in `HPOProblem.workflow.monitor`."""

def __init__(self):
def __init__(
self,
num_repeats: int = 1,
fit_aggregation: Optional[Callable[[torch.Tensor, int], torch.Tensor]] = _mean_fit_aggregation,
):
super().__init__()
self.num_repeats = num_repeats
self.fit_aggregation = fit_aggregation

def tell_fitness(self) -> torch.Tensor:
"""Get the best fitness found so far in the optimization process that this monitor is monitoring.
@@ -25,14 +53,19 @@ def tell_fitness(self) -> torch.Tensor:
class HPOFitnessMonitor(HPOMonitor):
"""The monitor for hyper parameter optimization (HPO) that records the best fitness found so far in the optimization process."""

def __init__(self, multi_obj_metric: Optional[Callable] = None):
def __init__(
self,
num_repeats: int = 1,
fit_aggregation: Optional[Callable[[torch.Tensor, int], torch.Tensor]] = _mean_fit_aggregation,
multi_obj_metric: Optional[Callable] = None,
):
"""
Initialize the HPO fitness monitor.

:param multi_obj_metric: The metric function to use for multi-objective optimization, unused in single-objective optimization.
Currently we only support "IGD" or "HV" for multi-objective optimization. Defaults to `None`.
"""
super().__init__()
super().__init__(num_repeats, fit_aggregation)
assert multi_obj_metric is None or callable(multi_obj_metric), (
f"Expect `multi_obj_metric` to be `None` or callable, got {multi_obj_metric}"
)
@@ -46,7 +79,7 @@ def pre_tell(self, fitness: torch.Tensor):

:raises AssertionError: If the dimensionality of the fitness tensor is not 1 or 2.
"""
assert 1 <= fitness.ndim <= 2
fitness = self.fit_aggregation(fitness) if self.num_repeats > 1 else fitness
if fitness.ndim == 1:
# single-objective
self.best_fitness = torch.min(torch.min(fitness), self.best_fitness)
@@ -74,55 +107,156 @@ def get_sub_state(state: Dict[str, Any], name: str):
return state


class HPOData(NamedTuple):
workflow_step: Callable[[Dict[str, torch.Tensor]], Tuple[Dict[str, torch.Tensor]]] # workflow_step
compiled_workflow_step: Callable[[Dict[str, torch.Tensor]], Tuple[Dict[str, torch.Tensor]]] # compiled_workflow_step
state_keys: List[str] # state_keys or param_keys
buffer_keys: Optional[List[str]] # optional buffer_keys


__hpo_data__: Dict[int, HPOData] = {}


def _fake_hpo_evaluate_loop(id: int, iterations: int, state_values: List[torch.Tensor]) -> List[torch.Tensor]:
return [v.new_empty(v.size()) for v in state_values]


@torch.library.custom_op("evox::_hpo_evaluate_loop", mutates_args=())
def _hpo_evaluate_loop(id: int, iterations: int, state_values: List[torch.Tensor]) -> List[torch.Tensor]:
global __hpo_data__
workflow_step, compiled_workflow_step, state_keys, buffer_keys = __hpo_data__[id]
if buffer_keys is None:
state = {k: v.clone() for k, v in zip(state_keys, state_values)}
for _ in range(iterations):
if torch.compiler.is_compiling():
state = compiled_workflow_step(state)
else:
state = workflow_step(state)
return [state[k] for k in state_keys]
else:
param_keys, buffer_keys = state_keys, buffer_keys
params = {k: v.clone() for k, v in zip(param_keys, state_values)}
buffers = {k: v.clone() for k, v in zip(buffer_keys, state_values[len(param_keys) :])}
for _ in range(iterations):
if torch.compiler.is_compiling():
params, buffers = compiled_workflow_step(params, buffers)
else:
params, buffers = workflow_step(params, buffers)
return [params[k] for k in param_keys] + [buffers[k] for k in buffer_keys]


_hpo_evaluate_loop.register_fake(_fake_hpo_evaluate_loop)


class HPOProblemWrapper(Problem):
"""The problem for hyper parameter optimization (HPO).

## Usage
```
## Example
```python
algo = SomeAlgorithm(...)
prob = SomeProblem(...)
monitor = HPOFitnessMonitor()
workflow = StdWorkflow(algo, prob, monitor=monitor)
hpo_prob = HPOProblemWrapper(iterations=..., num_instances=...)
params = HPOProblemWrapper.extract_parameters(hpo_prob.init_state)
params = hpo_prob.get_init_params()
# alter `params` ...
hpo_prob.evaluate(params) # execute the evaluation
# ...
```
"""

def __init__(self, iterations: int, num_instances: int, workflow: Workflow, copy_init_state: bool = True):
def __init__(
self,
iterations: int,
num_instances: int,
workflow: Workflow,
num_repeats: int = 1,
copy_init_state: bool = False,
):
"""Initialize the HPO problem wrapper.

:param iterations: The number of iterations to be executed in the optimization process.
:param num_instances: The number of instances to be executed in parallel in the optimization process.
:param num_instances: The number of instances to be executed in parallel in the optimization process, i.e., the population size of the outer algorithm.
:param workflow: The workflow to be used in the optimization process. Must be wrapped by `core.jit_class`.
:param num_repeats: The number of times to repeat the evaluation process for each instance. Defaults to 1.
:param copy_init_state: Whether to copy the initial state of the workflow for each evaluation. Defaults to `False`. If your workflow contains operations that modify the tensor(s) in the initial state in-place, this should be set to `True`. Otherwise, it can be left as `False` to save memory.
"""
super().__init__()
assert iterations > 0, f"`iterations` should be greater than 0, got {iterations}"
assert num_instances > 0, f"`num_instances` should be greater than 0, got {num_instances}"
self.iterations = iterations
self.num_instances = num_instances
self.num_repeats = num_repeats
self.copy_init_state = copy_init_state
# check monitor
monitor = workflow.monitor
assert isinstance(monitor, HPOMonitor), f"Expect workflow monitor to be `HPOMonitor`, got {type(monitor)}"
self.hpo_monitor = monitor
monitor.num_repeats = num_repeats

# compile workflow steps
state_step = use_state(workflow.step)

# JIT workflow step
vmap_state_step = vmap(state_step, randomness="same")
def repeat_state_step(params: Dict[str, torch.Tensor], buffers: Dict[str, torch.Tensor]):
state = {**params, **buffers}
state = state_step(state)
return {k: state[k] for k in params.keys()}, {k: state[k] for k in buffers.keys()}

vmap_state_step = (
torch.vmap(
torch.vmap(repeat_state_step, randomness="same"),
randomness="different",
in_dims=(None, 0),
out_dims=(None, 0),
)
if num_repeats > 1
else torch.vmap(state_step, randomness="same")
)
self._init_params, self._init_buffers = torch.func.stack_module_state([workflow] * self.num_instances)
self._workflow_step_ = torch.compile(vmap_state_step)
if num_repeats > 1:
self._init_buffers = {k: torch.stack([v] * num_repeats) for k, v in self._init_buffers.items()}
self._workflow_step_ = vmap_state_step
self._compiled_workflow_step_ = compile(vmap_state_step, fullgraph=True)

if type(workflow).init_step == Workflow.init_step:
# if no init step
print("No init step")
self._workflow_init_step_ = self._workflow_step_
self._compiled_init_step_ = self._compiled_workflow_step_
else:
# otherwise, JIT workflow init step
# otherwise, compile workflow init step
state_init_step = use_state(workflow.init_step)
vmap_state_init_step = vmap(state_init_step, randomness="same")
self._workflow_init_step_ = torch.compile(vmap_state_init_step)

def repeat_state_init_step(params: Dict[str, torch.Tensor], buffers: Dict[str, torch.Tensor]):
state = {**params, **buffers}
state = state_init_step(state)
return {k: state[k] for k in params.keys()}, {k: state[k] for k in buffers.keys()}

vmap_state_init_step = (
torch.vmap(
torch.vmap(repeat_state_init_step, randomness="same"),
randomness="different",
in_dims=(None, 0),
out_dims=(None, 0),
)
if num_repeats > 1
else torch.vmap(state_init_step, randomness="same")
)
self._workflow_init_step_ = vmap_state_init_step
self._compiled_workflow_init_step_ = compile(vmap_state_init_step, fullgraph=True)

self.state_keys = (list(self._init_params.keys()), list(self._init_buffers.keys()))
if self.num_repeats == 1:
self.state_keys = sum(self.state_keys, [])
global __hpo_data__
__hpo_data__[id(self)] = HPOData(
workflow_step=self._workflow_step_,
compiled_workflow_step=self._compiled_workflow_step_,
state_keys=self.state_keys if self.num_repeats == 1 else self.state_keys[0],
buffer_keys=None if self.num_repeats == 1 else self.state_keys[1],
)
self._id_ = id(self)
weakref.finalize(self, __hpo_data__.pop, id(self), None)

self._stateful_tell_fitness = use_state(monitor.tell_fitness)

def evaluate(self, hyper_parameters: Dict[str, nn.Parameter]):
"""
@@ -133,25 +267,48 @@ def evaluate(self, hyper_parameters: Dict[str, nn.Parameter]):
:return: The final fitness of the hyper parameters.
"""
# hyper parameters check
for k, _v in hyper_parameters.items():
for k, _ in hyper_parameters.items():
assert k in self._init_params, (
f"`{k}` should be a hyperparameter of the workflow, available keys are {self.get_params_keys()}"
)

state = self._init_params | self._init_buffers
if self.copy_init_state:
state = copy.deepcopy(state)

# Override with the given hyper parameters
state.update(hyper_parameters)
# run the workflow
state = self._workflow_init_step_(state)
for _ in range(self.iterations - 1):
state = self._workflow_step_(state)
# get final fitness
monitor_state = get_sub_state(state, "monitor")
_monitor_state, fit = vmap(use_state(self.hpo_monitor.tell_fitness), randomness="same")(monitor_state)
return fit
if self.num_repeats > 1:
if self.copy_init_state:
params = {k: v.clone() for k, v in self._init_params.items()}
buffers = {k: v.clone() for k, v in self._init_buffers.items()}
else:
params = self._init_params
buffers = self._init_buffers
params = {**params, **hyper_parameters}
# run the workflow
if torch.compiler.is_compiling():
params, buffers = self._compiled_workflow_init_step_(params, buffers)
else:
params, buffers = self._workflow_init_step_(params, buffers)
state_values = [params[k] for k in self.state_keys[0]] + [buffers[k] for k in self.state_keys[1]]
state_values = _hpo_evaluate_loop(self._id_, self.iterations - 1, state_values)
params = {k: v for k, v in zip(self.state_keys[0], state_values)}
buffers = {k: v for k, v in zip(self.state_keys[1], state_values[len(params) :])}
monitor_state = get_sub_state(buffers, "monitor")
_, fit = vmap(torch.vmap(self._stateful_tell_fitness))(monitor_state)
return fit[0]
else:
state: Dict[str, torch.Tensor] = {**self._init_params, **self._init_buffers}
if self.copy_init_state:
state = {k: v.clone() for k, v in state.items()}
# Override with the given hyper parameters
state.update(hyper_parameters)
# run the workflow
if torch.compiler.is_compiling():
state = self._compiled_workflow_init_step_(state)
else:
state = self._workflow_init_step_(state)
state_values = [state[k] for k in self.state_keys]
state_values = _hpo_evaluate_loop(self._id_, self.iterations - 1, state_values)
state = {k: v for k, v in zip(self.state_keys, state_values)}
monitor_state = get_sub_state(state, "monitor")
_, fit = vmap(self._stateful_tell_fitness)(monitor_state)
return fit

def get_init_params(self):
"""Return the initial hyper-parameters dictionary of the underlying workflow."""
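
The `num_repeats > 1` path in `__init__` above stacks the buffers an extra `num_repeats` times and nests two levels of `torch.vmap`: the inner level maps both params and buffers over the `num_instances` dimension, while the outer level broadcasts the shared params (`in_dims=(None, 0)`) and maps only the repeat-stacked buffers. Below is a minimal sketch of that batching layout with toy tensors and a toy step function; names and shapes are illustrative only.

```python
import torch


def step(param: torch.Tensor, buffer: torch.Tensor) -> torch.Tensor:
    # Toy stand-in for one per-instance workflow step.
    return param * 2.0 + buffer


num_instances, num_repeats = 4, 3
params = torch.randn(num_instances)                # shared across repeats
buffers = torch.randn(num_repeats, num_instances)  # stacked num_repeats times, as in __init__

batched_step = torch.vmap(
    torch.vmap(step, randomness="same"),  # inner vmap: over the instance dimension
    in_dims=(None, 0),                    # outer vmap: broadcast params, map the repeat dim of buffers
    out_dims=0,
    randomness="different",
)

out = batched_step(params, buffers)
print(out.shape)  # torch.Size([3, 4]) == (num_repeats, num_instances)
```

Setting `randomness="different"` on the outer vmap and `"same"` on the inner one, as the PR does, presumably lets each repeat draw independent random numbers while all instances within one repeat see identical random draws.
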