Skip to content

Commit

Permalink
Change RollOut for extrapolation and document
Browse files Browse the repository at this point in the history
  • Loading branch information
gallego-posada committed Feb 14, 2025
1 parent ac5a29c commit 487f305
Showing 1 changed file with 27 additions and 2 deletions.
29 changes: 27 additions & 2 deletions src/cooper/optim/constrained_optimizers/extrapolation_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def dual_extrapolation_step(self) -> None:
multiplier.post_step_()

def roll(self, compute_cmp_state_kwargs: Optional[dict] = None) -> RollOut:
"""Performs a full update step on the primal and dual variables.
r"""Performs a full update step on the primal and dual variables.
Note that the forward and backward computations are carried out
*twice*, as part of the
Expand All @@ -142,6 +142,28 @@ def roll(self, compute_cmp_state_kwargs: Optional[dict] = None) -> RollOut:
compute_cmp_state_kwargs: Keyword arguments to pass to the
:py:meth:`~cooper.ConstrainedMinimizationProblem.compute_cmp_state()`
method.
Returns:
:py:class:`~cooper.optim.optimizer.RollOut`: A named tuple containing the
following objects:
- loss (:py:class:`~torch.Tensor`):
The loss value computed after the extrapolation step :math:`f(\vx_{t})`.
- cmp_state (:py:class:`~cooper.CMPState`):
The CMP state at :math:`\vx_{t}`.
- primal_lagrangian_store (:py:class:`~cooper.LagrangianStore`):
The primal Lagrangian store at :math:`\vx_{t}`,
:math:`\vlambda_{t}` and :math:`\vmu_{t}`.
- dual_lagrangian_store (:py:class:`~cooper.LagrangianStore`):
The dual Lagrangian store at :math:`\vx_{t}`, :math:`\vlambda_t` and
:math:`\vmu_t`.
.. admonition::
:class: note
The `RollOut` for this scheme returns the loss and `CMPState` values at the
original point :math:`(\vx_t, \vlambda_t)`, *before* any of the updates are
performed.
"""
if compute_cmp_state_kwargs is None:
compute_cmp_state_kwargs = {}
Expand All @@ -152,6 +174,9 @@ def roll(self, compute_cmp_state_kwargs: Optional[dict] = None) -> RollOut:
primal_lagrangian_store = cmp_state.compute_primal_lagrangian()
dual_lagrangian_store = cmp_state.compute_dual_lagrangian()

if call_extrapolation:
roll_out = RollOut(cmp_state.loss, cmp_state, primal_lagrangian_store, dual_lagrangian_store)

primal_lagrangian_store.backward()
dual_lagrangian_store.backward()

Expand All @@ -162,4 +187,4 @@ def roll(self, compute_cmp_state_kwargs: Optional[dict] = None) -> RollOut:
self.primal_step()
self.dual_step()

return RollOut(cmp_state.loss, cmp_state, primal_lagrangian_store, dual_lagrangian_store)
return roll_out

0 comments on commit 487f305

Please sign in to comment.