diff --git a/diffcp/cone_program.py b/diffcp/cone_program.py index f9b6f9a..40e0ade 100644 --- a/diffcp/cone_program.py +++ b/diffcp/cone_program.py @@ -10,6 +10,45 @@ import diffcp.cones as cone_lib +def permute_psd_rows(A: sparse.csc_matrix, b: np.ndarray, n: int, row_offset: int) -> sparse.csc_matrix: + """ + Permutes rows of a sparse CSC constraint matrix A to switch from lower + triangular order (SCS) to upper triangular order (Clarabel) for a PSD constraint. + + Args: + A (csc_matrix): Constraint matrix in CSC format. + n (int): Size of the PSD constraint matrix (n x n). + row_offset (int): Row index where the PSD block starts. + + Returns: + csc_matrix: New CSC matrix with permuted rows. + """ + + tril_rows, tril_cols = np.tril_indices(n) + triu_rows, triu_cols = np.triu_indices(n) + + # Compute the permutation mapping + tril_multi_index = np.ravel_multi_index((tril_cols, tril_rows), (n, n)) + triu_multi_index = np.ravel_multi_index((triu_cols, triu_rows), (n, n)) + postshuffle_from_preshuffle_perm = np.argsort(tril_multi_index) + row_offset + preshuffle_from_postshuffle_perm = np.argsort(triu_multi_index) + row_offset + n_rows = len(postshuffle_from_preshuffle_perm) + + # Apply row permutation + data, rows, cols = A.data, A.indices, A.indptr + new_rows = np.copy(rows) # Create a new row index array + # Identify affected rows + mask = (rows >= row_offset) & (rows < (row_offset + n_rows)) + new_rows[mask] = postshuffle_from_preshuffle_perm[rows[mask] - row_offset] + + new_A = sparse.csc_matrix((data, new_rows, cols), shape=A.shape) + + new_b = np.copy(b) + + new_b[row_offset:row_offset+n_rows] = new_b[preshuffle_from_postshuffle_perm] + + return new_A, new_b + def pi(z, cones): """Projection onto R^n x K^* x R_+ @@ -443,24 +482,33 @@ def solve_internal(A, b, c, cone_dict, solve_method=None, P = sparse.csc_matrix((c.size, c.size)) cones = [] + # Given the difference in convention between SCS (lower triangluar columnwise) + # and Clarabel (upper triangular columnwise), the rows of A may need to be permuted + start_row = 0 if "z" in cone_dict: v = cone_dict["z"] if v > 0: cones.append(clarabel.ZeroConeT(v)) + start_row += v if "f" in cone_dict: v = cone_dict["f"] if v > 0: cones.append(clarabel.ZeroConeT(v)) + start_row += v if "l" in cone_dict: v = cone_dict["l"] if v > 0: cones.append(clarabel.NonnegativeConeT(v)) + start_row += v if "q" in cone_dict: for v in cone_dict["q"]: cones.append(clarabel.SecondOrderConeT(v)) + start_row += v if "s" in cone_dict: for v in cone_dict["s"]: cones.append(clarabel.PSDTriangleConeT(v)) + A, b = permute_psd_rows(A, b, v, start_row) + start_row += v if "ep" in cone_dict: v = cone_dict["ep"] cones += [clarabel.ExponentialConeT()] * v diff --git a/pyproject.toml b/pyproject.toml index 1cb177e..da92ba3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,6 +8,7 @@ requires-python = ">= 3.10" readme = "README.md" dependencies = [ + "cvxpy>=1.6.3", "threadpoolctl >= 1.1", ] urls = {Homepage = "https://github.com/cvxgrp/diffcp/"} @@ -31,7 +32,9 @@ test = [ "clarabel >= 0.5.1", "ecos >= 2.0.10", "scs >= 3.0.0", + "pytest>=8.3.5", ] [tool.scikit-build] wheel.expand-macos-universal-tags = true + diff --git a/tests/test_clarabel.py b/tests/test_clarabel.py index 822e5ce..c2f2046 100644 --- a/tests/test_clarabel.py +++ b/tests/test_clarabel.py @@ -2,6 +2,7 @@ import numpy as np import diffcp.cone_program as cone_prog +from diffcp.cones import unvec_symm import diffcp.utils as utils @@ -120,3 +121,20 @@ def test_expcone(): np.testing.assert_allclose(x_pert - x, dx, atol=1e-8) np.testing.assert_allclose(y_pert - y, dy, atol=1e-8) np.testing.assert_allclose(s_pert - s, ds, atol=1e-8) + +def test_psdcone(): + DIM = 5 + X = cp.Variable(shape=(DIM, DIM), PSD=True) + C = np.zeros((DIM, DIM)) + C[0, 0] = 1 + C[4, 4] = -1 + objective = cp.Minimize(cp.trace(C @ X)) + constraint = cp.trace(X) == 1 + problem = cp.Problem(objective, [constraint]) + A, b, c, cone_dims = utils.scs_data_from_cvxpy_problem(problem) + sol_vec, _, _, _, _ = cone_prog.solve_and_derivative(A, b, c, cone_dims, solve_method='Clarabel') + + sol = unvec_symm(sol_vec, DIM) + + assert np.abs(np.trace(sol) - 1.0) < 1e-6 + assert (np.linalg.eigvals(sol) >= -1e-6).all()