
Fix manual tests for extrapolation updates
merajhashemi committed Feb 14, 2025
1 parent eed3abd commit e8c0bf7
Showing 2 changed files with 45 additions and 40 deletions.
15 changes: 10 additions & 5 deletions testing/cooper_helpers.py
@@ -150,10 +150,15 @@ def _compute_constraint_states(
lhs_sur: Optional[torch.Tensor],
is_sampled: bool,
observed_constraint_ratio: float,
+        seed: Optional[int],
) -> cooper.ConstraintState:
"""Computes the constraint state."""
num_constraints = rhs.numel()
strict_violation = torch.matmul(lhs, x) - rhs

+        if seed is not None:
+            self.generator.manual_seed(seed)

strict_constraint_features = None
if is_sampled:
strict_constraint_features = torch.randperm(num_constraints, generator=self.generator, device=self.device)
@@ -180,30 +185,30 @@ def _compute_constraint_states(
strict_constraint_features=strict_constraint_features,
)

-    def compute_violations(self, x: torch.Tensor) -> cooper.CMPState:
+    def compute_violations(self, x: torch.Tensor, seed: Optional[int]) -> cooper.CMPState:
"""Computes the constraint violations for the given parameters."""
observed_constraints = {}

if self.has_ineq_constraint:
ineq_state = self._compute_constraint_states(
-                x, self.A, self.b, self.A_sur, self.is_ineq_sampled, self.ineq_observed_constraint_ratio
+                x, self.A, self.b, self.A_sur, self.is_ineq_sampled, self.ineq_observed_constraint_ratio, seed
)
observed_constraints[self.ineq_constraints] = ineq_state

if self.has_eq_constraint:
eq_state = self._compute_constraint_states(
-                x, self.C, self.d, self.C_sur, self.is_eq_sampled, self.eq_observed_constraint_ratio
+                x, self.C, self.d, self.C_sur, self.is_eq_sampled, self.eq_observed_constraint_ratio, seed
)
observed_constraints[self.eq_constraints] = eq_state

return cooper.CMPState(observed_constraints=observed_constraints)

-    def compute_cmp_state(self, x: torch.Tensor) -> cooper.CMPState:
+    def compute_cmp_state(self, x: torch.Tensor, seed: Optional[int] = None) -> cooper.CMPState:
"""Computes the state of the CMP at the current value of the primal parameters
by evaluating the loss and constraints.
"""
loss = torch.sum(x**2)
violation_state = self.compute_violations(x)
violation_state = self.compute_violations(x, seed)
cmp_state = cooper.CMPState(loss=loss, observed_constraints=violation_state.observed_constraints)
return cmp_state

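Note: the `seed` argument threaded through `compute_cmp_state` and `compute_violations` lets a test re-seed `self.generator` right before the `torch.randperm` sampling, so both CMP evaluations of a primal-dual step observe the same constraint subset. A minimal self-contained sketch of the pattern (the helper name is illustrative, not part of the repository):

import torch

generator = torch.Generator()

def sample_constraint_indices(num_constraints: int, ratio: float, seed: int) -> torch.Tensor:
    # Re-seeding before sampling makes calls with the same seed return
    # identical index subsets, while different seeds vary across steps.
    generator.manual_seed(seed)
    perm = torch.randperm(num_constraints, generator=generator)
    return perm[: int(ratio * num_constraints)]

# Two evaluations within the same step observe the same constraints.
assert torch.equal(
    sample_constraint_indices(10, 0.5, seed=0),
    sample_constraint_indices(10, 0.5, seed=0),
)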
70 changes: 35 additions & 35 deletions tests/pipeline/test_manual.py
@@ -38,10 +38,7 @@ def _setup_cmp(

self.expects_multiplier = formulation_type.expects_multiplier
self.expects_penalty_coefficient = formulation_type.expects_penalty_coefficient
-        self.has_penalty_updater = formulation_type in {
-            cooper.formulations.QuadraticPenalty,
-            cooper.formulations.AugmentedLagrangian,
-        }
+        self.use_penalty_updater = formulation_type != cooper.formulations.Lagrangian

self.is_indexed_multiplier = multiplier_type == cooper.multipliers.IndexedMultiplier
self.num_variables = num_variables
@@ -57,9 +54,6 @@ def test_manual_step(self, extrapolation, alternation_type):
The manual implementation assumes Stochastic Gradient Descent (SGD) is used for both the primal
and dual optimizers.
"""
-        if alternation_type == testing.AlternationType.PRIMAL_DUAL and self.is_indexed_multiplier:
-            pytest.skip("Cannot test IndexedMultiplier with PRIMAL_DUAL alternation.")

x = torch.nn.Parameter(torch.ones(self.num_variables, device=self.device))
optimizer_class = cooper.optim.ExtraSGD if extrapolation else torch.optim.SGD
primal_optimizers = optimizer_class([x], lr=PRIMAL_LR)
@@ -74,15 +68,11 @@
)

penalty_updater = None
-        if self.has_penalty_updater:
+        if self.use_penalty_updater:
penalty_updater = cooper.penalty_coefficients.MultiplicativePenaltyCoefficientUpdater(
growth_factor=PENALTY_GROWTH_FACTOR, violation_tolerance=PENALTY_VIOLATION_TOLERANCE
)

-        roll_kwargs = {"compute_cmp_state_kwargs": {"x": x}}
-        if alternation_type == testing.AlternationType.PRIMAL_DUAL:
-            roll_kwargs["compute_violations_kwargs"] = {"x": x}

manual_x = torch.ones(self.num_variables, device=self.device)
manual_multiplier = None
if self.expects_multiplier:
@@ -92,27 +82,30 @@
manual_penalty_coeff = torch.ones(self.num_constraints, device=self.device)

# ----------------------- Iterations -----------------------
-        for _ in range(2):
+        for step in range(2):
+            roll_kwargs = {"compute_cmp_state_kwargs": {"x": x, "seed": step}}
+            if alternation_type == testing.AlternationType.PRIMAL_DUAL:
+                roll_kwargs["compute_violations_kwargs"] = {"x": x, "seed": step}
roll_out = cooper_optimizer.roll(**roll_kwargs)
-            if self.has_penalty_updater:
+            if self.use_penalty_updater:
penalty_updater.step(roll_out.cmp_state.observed_constraints)

-            observed_multipliers = None
+            primal_observed_multipliers = None
+            dual_observed_multipliers = None
if self.expects_multiplier:
-                if alternation_type == testing.AlternationType.PRIMAL_DUAL:
-                    observed_multipliers = torch.cat(list(roll_out.dual_lagrangian_store.observed_multiplier_values()))
-                else:
-                    observed_multipliers = torch.cat(
-                        list(roll_out.primal_lagrangian_store.observed_multiplier_values())
-                    )
+                primal_observed_multipliers = torch.cat(
+                    list(roll_out.primal_lagrangian_store.observed_multiplier_values())
+                )
+                dual_observed_multipliers = torch.cat(list(roll_out.dual_lagrangian_store.observed_multiplier_values()))

# The CMP has only one constraint, so we can use the first element
features = next(iter(roll_out.cmp_state.observed_constraint_features()))
if features is None:
features = torch.arange(self.num_constraints, device=self.device, dtype=torch.long)

strict_features = next(iter(roll_out.cmp_state.observed_strict_constraint_features()))
if strict_features is None:
-                strict_features = torch.arange(self.num_constraints, device=self.device, dtype=torch.long)
+                strict_features = features

manual_x_prev = manual_x.clone()
# Manual step
@@ -121,7 +114,8 @@ def test_manual_step(self, extrapolation, alternation_type):
manual_multiplier,
manual_primal_lagrangian,
manual_dual_lagrangian,
-                manual_observed_multipliers,
+                manual_primal_observed_multipliers,
+                manual_dual_observed_multipliers,
) = self.manual_roll(
manual_x,
manual_multiplier,
@@ -132,7 +126,7 @@ def test_manual_step(self, extrapolation, alternation_type):
extrapolation,
)

-            if self.has_penalty_updater:
+            if self.use_penalty_updater:
# Update penalty coefficients for the Quadratic Penalty formulation
self._update_penalty_coefficients(
manual_x, manual_x_prev, strict_features, alternation_type, manual_penalty_coeff
@@ -142,7 +136,8 @@ def test_manual_step(self, extrapolation, alternation_type):
assert torch.allclose(x, manual_x)
assert torch.allclose(roll_out.primal_lagrangian_store.lagrangian, manual_primal_lagrangian)
if self.expects_multiplier:
-                assert torch.allclose(observed_multipliers, manual_observed_multipliers[features])
+                assert torch.allclose(primal_observed_multipliers, manual_primal_observed_multipliers[features])
+                assert torch.allclose(dual_observed_multipliers, manual_dual_observed_multipliers[strict_features])
assert torch.allclose(roll_out.dual_lagrangian_store.lagrangian, manual_dual_lagrangian)

def _violation(self, x, strict=False):
@@ -239,49 +234,54 @@ def _simultaneous_roll(self, x, multiplier, features, strict_features, penalty_coeff):
self._dual_step(x, multiplier, strict_features),
)

-        return x, multiplier, primal_lagrangian, dual_lagrangian, observed_multipliers
+        # primal and dual observed multipliers are the same
+        return x, multiplier, primal_lagrangian, dual_lagrangian, observed_multipliers, observed_multipliers

def _dual_primal_roll(self, x, multiplier, features, strict_features, penalty_coeff):
# Dual step
+        dual_observed_multipliers = multiplier.clone()
dual_lagrangian = self._dual_lagrangian(x, multiplier, strict_features)
multiplier = self._dual_step(x, multiplier, strict_features)

# Primal step
primal_lagrangian = self._primal_lagrangian(x, multiplier, features, penalty_coeff)
-        observed_multipliers = multiplier.clone()
+        primal_observed_multipliers = multiplier.clone()
x = self._primal_step(x, multiplier, features, penalty_coeff)

-        return x, multiplier, primal_lagrangian, dual_lagrangian, observed_multipliers
+        return x, multiplier, primal_lagrangian, dual_lagrangian, primal_observed_multipliers, dual_observed_multipliers

def _primal_dual_roll(self, x, multiplier, features, strict_features, penalty_coeff):
# Primal step
primal_lagrangian = self._primal_lagrangian(x, multiplier, features, penalty_coeff)
-        observed_multipliers = multiplier.clone()
+        primal_observed_multipliers = multiplier.clone()
x = self._primal_step(x, multiplier, features, penalty_coeff)

# Dual step
+        dual_observed_multipliers = multiplier.clone()
dual_lagrangian = self._dual_lagrangian(x, multiplier, strict_features)
multiplier = self._dual_step(x, multiplier, strict_features)

-        return x, multiplier, primal_lagrangian, dual_lagrangian, observed_multipliers
+        return x, multiplier, primal_lagrangian, dual_lagrangian, primal_observed_multipliers, dual_observed_multipliers

def _extragradient_roll(self, x, multiplier, features, strict_features, penalty_coeff):
x_copy = x.clone() # x_t
multiplier_copy = multiplier.clone()

-        # Compute the Lagrangians
-        primal_lagrangian = self._primal_lagrangian(x, multiplier, features, penalty_coeff)
-        observed_multipliers = multiplier.clone()
-        dual_lagrangian = self._dual_lagrangian(x, multiplier, strict_features)

# Extrapolation step
x = self._primal_step(x_copy, multiplier_copy, features, penalty_coeff) # x_{t+1/2}
multiplier = self._dual_step(x_copy, multiplier_copy.clone(), strict_features)

+        # Update step
+        primal_lagrangian = self._primal_lagrangian(x, multiplier, features, penalty_coeff)
+        observed_multipliers = multiplier.clone()
+        dual_lagrangian = self._dual_lagrangian(x, multiplier, strict_features)

x_grad = self._primal_gradient(x, multiplier, features, penalty_coeff)
x, multiplier = x_copy - PRIMAL_LR * x_grad, self._dual_step(x, multiplier_copy, strict_features)

-        return x, multiplier, primal_lagrangian, dual_lagrangian, observed_multipliers
+        # primal and dual observed multipliers are the same
+        return x, multiplier, primal_lagrangian, dual_lagrangian, observed_multipliers, observed_multipliers

@torch.inference_mode()
def manual_roll(self, x, multiplier, features, strict_features, alternation_type, penalty_coeff, extrapolation):
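For reference, `_extragradient_roll` mirrors the extragradient scheme: gradients are evaluated at the extrapolated half-iterate and then applied at the original point, which is why the Lagrangians and observed multipliers are now recorded at the update step rather than before extrapolation. A toy sketch for minimizing f subject to a single inequality constraint g(x) <= 0 (the function names are illustrative assumptions, not cooper APIs):

import torch

def extragradient_step(x, m, f_grad, g, g_grad, primal_lr, dual_lr):
    # Extrapolation: tentative step from (x_t, m_t) to the half-iterate.
    x_half = x - primal_lr * (f_grad(x) + m * g_grad(x))
    m_half = torch.clamp(m + dual_lr * g(x), min=0.0)  # projected dual ascent

    # Update: apply the gradients evaluated at the half-iterate to (x_t, m_t).
    x_next = x - primal_lr * (f_grad(x_half) + m_half * g_grad(x_half))
    m_next = torch.clamp(m + dual_lr * g(x_half), min=0.0)
    return x_next, m_next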

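The `_update_penalty_coefficients` helper the test calls is checked against cooper's `MultiplicativePenaltyCoefficientUpdater`; below is a compact sketch of the multiplicative rule, assuming its documented semantics of growing a coefficient while its constraint stays violated beyond the tolerance:

import torch

def multiplicative_penalty_update(
    coeff: torch.Tensor, violation: torch.Tensor, growth_factor: float, tolerance: float
) -> torch.Tensor:
    # Grow only the coefficients of constraints still violated beyond the
    # tolerance; satisfied constraints keep their current coefficient.
    return torch.where(violation.abs() > tolerance, coeff * growth_factor, coeff)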