Skip to content

Commit

Permalink
Progress on optim docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
merajhashemi committed Feb 10, 2025
1 parent 75760d5 commit 87506c3
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,18 @@ class ConstrainedOptimizer(CooperOptimizer, abc.ABC):
:py:class:`~cooper.optim.UnconstrainedOptimizer` class.
Args:
cmp: The constrained minimization problem to be optimized. Providing the CMP
as an argument for the constructor allows the optimizer to call the
:py:meth:`~cooper.cmp.ConstrainedMinimizationProblem.compute_cmp_state`
method within the :py:meth:`~cooper.optim.cooper_optimizer.CooperOptimizer.roll`
method. Additionally, in the case of a constrained optimizer, the CMP
enables access to the multipliers'
:py:meth:`~cooper.multipliers.Multiplier.post_step_` method which must be
called after the multiplier update.
primal_optimizers: Optimizer(s) for the primal variables (e.g. the weights of
a model). The primal parameters can be partitioned into multiple optimizers,
in this case ``primal_optimizers`` accepts a list of
:py:class:`torch.optim.Optimizer`\s.
dual_optimizers: Optimizer(s) for the dual variables (e.g. the Lagrange
multipliers associated with the constraints). A sequence of
:py:class:`torch.optim.Optimizer`\s can be passed to handle the case of
Expand Down
32 changes: 29 additions & 3 deletions src/cooper/optim/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@


class CooperOptimizerState(TypedDict):
"""Stores the state of a :py:class:`~cooper.optim.cooper_optimizer.CooperOptimizer`."""

primal_optimizer_states: list[dict]
dual_optimizer_states: Optional[list[dict]]

Expand All @@ -30,10 +32,29 @@ class RollOut(NamedTuple):


class CooperOptimizer(abc.ABC):
# TODO(gallego-posada): Write docstring
r"""Base class for :py:class:`~cooper.optim.constrained_optimizer.ConstrainedOptimizer`
and :py:class:`~cooper.optim.UnconstrainedOptimizer`\s.
# TODO(gallego-posada): Why do we need to pass CMP here? Clarify and document
# What is the reason beyond just being able to call compute_cmp_state()?
Args:
cmp: The constrained minimization problem to be optimized. Providing the CMP
as an argument for the constructor allows the optimizer to call the
:py:meth:`~cooper.cmp.ConstrainedMinimizationProblem.compute_cmp_state`
method within the :py:meth:`~cooper.optim.cooper_optimizer.CooperOptimizer.roll`
method. Additionally, in the case of a constrained optimizer, the CMP
enables access to the multipliers'
:py:meth:`~cooper.multipliers.Multiplier.post_step_` method which must be
called after the multiplier update.
primal_optimizers: Optimizer(s) for the primal variables (e.g. the weights of
a model). The primal parameters can be partitioned into multiple optimizers,
in this case ``primal_optimizers`` accepts a list of
:py:class:`torch.optim.Optimizer`\s.
dual_optimizers: Optimizer(s) for the dual variables (e.g. the Lagrange
multipliers associated with the constraints). A sequence of
:py:class:`torch.optim.Optimizer`\s can be passed to handle the case of
several :py:class:`~cooper.constraints.Constraint`\s.
"""

def __init__(
self,
Expand Down Expand Up @@ -65,6 +86,11 @@ def primal_step(self) -> None:
primal_optimizer.step()

def state_dict(self) -> CooperOptimizerState:
r"""Return the state of the optimizer as a
:py:class:`~cooper.optim.cooper_optimizer.CooperOptimizerState`. This method
relies on the internal :py:meth:`~torch.optim.Optimizer.state_dict` method of
the corresponding primal or dual optimizers.
"""
primal_optimizer_states = [optimizer.state_dict() for optimizer in self.primal_optimizers]

dual_optimizer_states = None
Expand Down
8 changes: 6 additions & 2 deletions src/cooper/optim/unconstrained_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,14 @@ class UnconstrainedOptimizer(CooperOptimizer):
:py:class:`~cooper.optim.constrained_optimizers.ConstrainedOptimizer`\s.
Args:
cmp: The constrained minimization problem to optimize.
cmp: The constrained minimization problem to be optimized. Providing the CMP
as an argument for the constructor allows the optimizer to call the
:py:meth:`~cooper.cmp.ConstrainedMinimizationProblem.compute_cmp_state`
method within the :py:meth:`~cooper.optim.cooper_optimizer.CooperOptimizer.roll`
method.
primal_optimizers: Optimizer(s) for the primal variables (e.g. the weights of
a model). The primal parameters can be partitioned into multiple optimizers,
in this case ``primal_optimizers`` accepts a sequence of
in this case ``primal_optimizers`` accepts a list of
:py:class:`torch.optim.Optimizer`\s.
"""

Expand Down

0 comments on commit 87506c3

Please sign in to comment.