diff --git a/.gitignore b/.gitignore index 894a44c..3f69acf 100644 --- a/.gitignore +++ b/.gitignore @@ -102,3 +102,8 @@ venv.bak/ # mypy .mypy_cache/ + +# Python data files +*.npy +*.pkl +*.pickle diff --git a/fax/competitive/cmd/__init__.py b/fax/competitive/cmd/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/fax/competitive/cmd/cmd.py b/fax/competitive/cmd/cmd.py new file mode 100644 index 0000000..534254d --- /dev/null +++ b/fax/competitive/cmd/cmd.py @@ -0,0 +1,477 @@ +import collections +from fax import math +import jax +import jax.numpy as jnp +from jax import tree_util, jacfwd, random, grad, jvp +from jax.scipy.sparse import linalg +from jax.scipy import linalg as scipy_linalg +from functools import partial +from jax.config import config + +config.update("jax_enable_x64", True) +BregmanPotential = collections.namedtuple("BregmanPotential", ["DP", "DP_inv", "D2P", "inv_D2P"]) + + +# AugmentedDP = collections.namedtuple("AugmentedDP", ["DP_primal", "DP_eq", "DP_ineq"]) +# AugmentedDPinv = collections.namedtuple("AugmentedDPinv", ["DPinv_primal","DPinv_eq","DPinv_ineq"]) +# AugmentedD2P = collections.namedtuple("AugmentedD2P", ["D2P_primal", "D2P_eq", "D2P_ineq"]) +# AugmentedD2Pinv = collections.namedtuple("AugmentedD2Pinv", ["D2Pinv_primal","D2Pinv_eq","D2Pinv_ineq"]) + + +def id_func(x): + return lambda u: jnp.dot(jnp.identity(x.shape[0]), u) + + +def hvp(f, primals, tangents): + return jvp(grad(f), primals, tangents)[1] + + +def breg_bound(vec, lb=-1.0, ub=1.0, *args, **kwargs): + return jnp.sum((- vec + ub) * jnp.log(- vec + ub) + (vec - lb) * jnp.log(vec - lb)) + + +DP_bound = jax.grad(breg_bound, 0) + + +def DP_inv_bound(vec, lb=-1.0, ub=1.0): + return (ub * jnp.exp(vec) + lb) / (1 + jnp.exp(vec)) + + +def D2P_bound(vec, lb=-1.0, ub=1.0): + def out(u): + return jvp(lambda x: DP_bound(x, lb, ub), (vec,), (u,))[1] + + return out + + +def inv_D2P_bound(vec, lb=-1.0, ub=1.0): + if len(jnp.shape(vec)) <= 1: + def out(u): + return jnp.dot(jnp.diag(1 / ((1 / (ub - vec)) + (1 / (vec - lb)))), u) + else: + def out(u): + return (1 / ((1 / (ub - vec)) + (1 / (vec - lb)))) * u + return out + + +bound_breg = BregmanPotential(DP_bound, DP_inv_bound, D2P_bound, inv_D2P_bound) + + +def DP_hand(vec, nx): + temp = jnp.reshape(vec, (nx, nx)) + return (-jnp.linalg.inv(temp).T + temp).reshape(nx ** 2, 1) + + +def matrix_DP_pd(M): + return -jnp.linalg.slogdet(M)[1] + + +def vector_DP_pd(v): + return jnp.dot(v, jnp.log(v)) + + +def DP_pd(v): + m = len(jnp.shape(v)) + if m <= 1: + out = grad(lambda x: jnp.dot(x, jnp.log(x)))(v) + else: + out = grad(lambda M: -jnp.linalg.slogdet(M)[1])(v) + return out + + +def vector_DP_inv_pd(v): + return jnp.exp(v - jnp.ones_like(v)) + + +def DP_inv_pd(v): + m = len(jnp.shape(v)) + if m <= 1: + out = vector_DP_inv_pd(v) + else: + out = -scipy_linalg.inv(v).T + return out + + +def D2P_pd(v): + m = len(jnp.shape(v)) + if m <= 1: + def out(u): + return hvp(vector_DP_pd, (v,), (u,)) + else: + def out(u): + return hvp(matrix_DP_pd, (v,), (u,)) + return out + + +def inv_D2P_pd(v): + m = len(jnp.shape(v)) + if m <= 1: + def out(u): + return jnp.dot(jnp.diag(v), u) + else: + def out(u): + return jnp.dot(jnp.linalg.matrix_power(v, 2).T, u) + return out + +def D2P_l2(v): + return lambda x: x + + +def default_func(x, *args, **kwargs): + return None + + +# TODO: bake in step size to the bregman potential definitions instead of in cmd loop +default_breg = BregmanPotential(lambda v: jax.tree_map(lambda x: x, v), + lambda v: jax.tree_map(lambda x: x, v), 
D2P_l2, D2P_l2)
+
+
+def make_pd_bregman(step_size=1e-5):
+    # NOTE: step_size is currently unused; per the TODO above, step-size scaling
+    # still happens in the cmd loop rather than being baked into the potential.
+    return BregmanPotential(DP_pd, DP_inv_pd, D2P_pd, inv_D2P_pd)
+
+
+def make_bound_breg(lb=jnp.array([-1.0]), ub=jnp.array([1.0]), step_size=jnp.array([1e-4])):
+
+    def breg_bound_internal(vec):
+        assert vec.shape == lb.shape, "lower bound shape does not match variable shape!"
+        assert vec.shape == ub.shape, "upper bound shape does not match variable shape!"
+        assert step_size.shape == vec.shape, "step_size shape does not match variable shape!"
+        return jnp.sum(1. / step_size * ((- vec + ub) * jnp.log(- vec + ub) + (vec - lb) * jnp.log(vec - lb)), axis=-1)
+
+    DP_bound_internal = jax.grad(breg_bound_internal)
+
+    def DP_inv_bound_internal(vec):
+        vec = step_size * vec
+        return (ub * jnp.exp(vec) + lb) / (1 + jnp.exp(vec))
+
+    def D2P_bound_internal(vec):
+        return lambda u: jvp(DP_bound_internal, (vec,), (u,))[1]
+
+    def inv_D2P_bound_internal(vec):
+        if len(jnp.shape(vec)) >= 1:
+            return lambda u: jnp.dot(jnp.diag(step_size * 1 / ((1 / (ub - vec)) + (1 / (vec - lb)))), u)
+        else:
+            return lambda u: (step_size * 1 / ((1 / (ub - vec)) + (1 / (vec - lb)))) * u
+
+    return BregmanPotential(DP_bound_internal, DP_inv_bound_internal,
+                            D2P_bound_internal, inv_D2P_bound_internal)
+
+
+# usage: hessian_xy((min_P,max_P))(max_P)
+def make_mixed_jvp(f, first_args, second_args, opposite=False):
+    """Make a mixed jacobian-vector product function
+
+    Args:
+        f (callable): Binary callable with signature f(x,y)
+        first_args (numpy.ndarray): First arguments to f
+        second_args (numpy.ndarray): Second arguments to f
+        opposite (bool, optional): Take Dyx if False, Dxy if True. Defaults to
+            False.
+
+    Returns:
+        callable: Unary callable 'jvp(v)' taking a numpy.ndarray as input.
+    """
+    if opposite is not True:
+        given = second_args
+        gradfun = jax.grad(f, 0)
+
+        def frozen_grad(y):
+            return gradfun(first_args, y)
+    else:
+        given = first_args
+        gradfun = jax.grad(f, 1)
+
+        def frozen_grad(x):
+            return gradfun(x, second_args)
+
+    return jax.linearize(frozen_grad, given)[1]
+
+
+def make_lagrangian(obj_func, breg_min=default_breg, breg_max=default_breg,
+                    min_inequality_constraints=default_func,
+                    min_equality_constraints=default_func, max_inequality_constraints=default_func,
+                    max_equality_constraints=default_func):
+    """Transform the original constrained minimax problem with parametric inequalities into another minimax problem with only set constraints
+
+    Args:
+        obj_func (callable): multivariate callable with signature `f(x,y, *args, **kwargs)`
+        breg_min (Named tuple of callables): Tuple of unary callables with signature
+            'BregmanPotential = collections.namedtuple("BregmanPotential", ["DP", "DP_inv", "D2P", "inv_D2P"])'
+            where DP and DP_inv are unary callables with signatures
+            `DP(x,*args, **kwargs)`, 'DP_inv(x,*arg,**kwarg)' and
+            D2P, inv_D2P are functions of functions
+            (given an x, they return a linear transformation that takes in
+            another vector and outputs a hessian-vector product).
+        breg_max (Named tuple of callables): Tuple of unary callables as 'breg_min'.
+        min_inequality_constraints (callable): Unary callable with signature `h(x, *args, **kwargs)`
+        min_equality_constraints (callable): Unary callable with signature `h(x, *args, **kwargs)`
+        max_inequality_constraints (callable): Unary callable with signature `g(y, *args, **kwargs)`
+        max_equality_constraints (callable): Unary callable with signature `g(y, *args, **kwargs)`
+
+    Returns:
+        tuple: callables (init_multipliers, lagrangian, breg_min_aug, breg_max_aug)
+    """
+
+    def init_multipliers(params_min, params_max=None, key=random.PRNGKey(1), *args, **kwargs):
+        """Initialize multipliers for equality and inequality constraints for both players
+
+        Args:
+            params_min: initialized input to the equality and inequality constraints for min player, 'x'
+            params_max: initialized input to the equality and inequality constraints for max player, 'y'
+
+        Returns:
+            min_augmented (tuple): initialized min player with signature (original_min_param, multipliers_eq_max, multipliers_ineq_max)
+            max_augmented (tuple): initialized max player with signature (original_max_param, multipliers_eq_min, multipliers_ineq_min)
+        """
+
+        # Equality constraints' shape for both min and max player
+        h_min = jax.eval_shape(min_equality_constraints, params_min, *args, **kwargs)
+        h_max = jax.eval_shape(max_equality_constraints, params_max, *args, **kwargs)
+        # Make multipliers for equality constraints associated with both players, ie, '\lambda_x' and '\lambda_y'
+        multipliers_eq_min = tree_util.tree_map(lambda x: random.normal(key, x.shape), h_min)
+        multipliers_eq_max = tree_util.tree_map(lambda x: random.normal(key, x.shape), h_max)
+
+        # Inequality constraints' shape for both min and max player
+        g_min = jax.eval_shape(min_inequality_constraints, params_min, *args,
+                               **kwargs)  # should be a tuple
+        g_max = jax.eval_shape(max_inequality_constraints, params_max, *args,
+                               **kwargs)  # should be a tuple
+        # Make multipliers for the inequality constraints associated with both players, ie, '\mu_x' and '\mu_y'
+        multipliers_ineq_min = tree_util.tree_map(lambda x: random.normal(key, x.shape), g_min)
+        multipliers_ineq_max = tree_util.tree_map(lambda x: random.normal(key, x.shape), g_max)
+
+        min_augmented = (params_min, multipliers_eq_max, multipliers_ineq_max)
+        max_augmented = (params_max, multipliers_eq_min, multipliers_ineq_min)
+
+        return min_augmented, max_augmented
+
+    def lagrangian(minPlayer, maxPlayer):
+        obj_portion = obj_func(*[x for x in [minPlayer[0], maxPlayer[0]] if x is not None])
+        min_eq_portion = 0
+        max_eq_portion = 0
+        min_ineq_portion = 0
+        max_ineq_portion = 0
+
+        if maxPlayer[1] is not None:
+            min_eq_portion = math.pytree_dot(min_equality_constraints(minPlayer[0]), maxPlayer[1])
+
+        if minPlayer[1] is not None:
+            max_eq_portion = math.pytree_dot(max_equality_constraints(maxPlayer[0]), minPlayer[1])
+
+        if maxPlayer[2] is not None:
+            min_ineq_portion = math.pytree_dot(min_inequality_constraints(minPlayer[0]),
+                                               maxPlayer[2])
+
+        if minPlayer[2] is not None:
+            max_ineq_portion = math.pytree_dot(max_inequality_constraints(maxPlayer[0]),
+                                               minPlayer[2])
+
+        # out = obj_func(minPlayer[0], maxPlayer[0]) +\
+        #       math.pytree_dot(min_equality_constraints(minPlayer[0]), maxPlayer[1]) +\
+        #       math.pytree_dot(max_equality_constraints(maxPlayer[0]), minPlayer[1]) +\
+        #       math.pytree_dot(min_inequality_constraints(minPlayer[0]), maxPlayer[2]) +\
+        #       math.pytree_dot(max_inequality_constraints(maxPlayer[0]), minPlayer[2])
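+        # Any multiplier left as None (its constraint is the no-op default_func)
+        # contributes zero, so only the active <constraint, multiplier> pairings
+        # enter the sum below.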
+        return obj_portion + min_eq_portion + max_eq_portion + min_ineq_portion + max_ineq_portion
+
+    # Augment each player's potential with potentials for the opponent's multipliers:
+    # an L2 (identity) potential for the equality multipliers and the positive-definite
+    # potential for the inequality multipliers.
+    min_augmented_DP = (breg_min.DP, lambda v: jax.tree_map(lambda x: x, v),
+                        lambda v: jax.tree_map(DP_pd, v))  # [breg_min.DP, DP_eq_min, DP_ineq_min]
+
+    min_augmented_DP_inv = (breg_min.DP_inv, lambda v: jax.tree_map(lambda x: x, v),
+                            lambda v: jax.tree_map(DP_inv_pd, v))  # [breg_min.DP_inv, DP_inv_eq_min, DP_inv_ineq_min]
+
+    min_augmented_D2P = (breg_min.D2P, D2P_l2,
+                         lambda v: jax.tree_map(D2P_pd, v))  # [breg_min.D2P, D2P_eq_min, D2P_ineq_min]
+
+    min_augmented_D2P_inv = (breg_min.inv_D2P, D2P_l2,
+                             lambda v: jax.tree_map(inv_D2P_pd, v))  # [breg_min.inv_D2P, inv_D2P_eq_min, inv_D2P_ineq_min]
+
+    max_augmented_DP = (breg_max.DP, lambda v: jax.tree_map(lambda x: x, v),
+                        lambda v: jax.tree_map(DP_pd, v))  # [breg_max.DP, DP_eq_max, DP_ineq_max]
+
+    max_augmented_DP_inv = (breg_max.DP_inv, lambda v: jax.tree_map(lambda x: x, v),
+                            lambda v: jax.tree_map(DP_inv_pd, v))  # [breg_max.DP_inv, DP_inv_eq_max, DP_inv_ineq_max]
+
+    max_augmented_D2P = (breg_max.D2P, D2P_l2,
+                         lambda v: jax.tree_map(D2P_pd, v))  # [breg_max.D2P, D2P_eq_max, D2P_ineq_max]
+
+    max_augmented_D2P_inv = (breg_max.inv_D2P, D2P_l2,
+                             lambda v: jax.tree_map(inv_D2P_pd, v))  # [breg_max.inv_D2P, inv_D2P_eq_max, inv_D2P_ineq_max]
+
+    return init_multipliers, lagrangian, BregmanPotential(min_augmented_DP, min_augmented_DP_inv,
+                                                          min_augmented_D2P,
+                                                          min_augmented_D2P_inv), BregmanPotential(
+        max_augmented_DP, max_augmented_DP_inv, max_augmented_D2P, max_augmented_D2P_inv)
+
+
+CMDState = collections.namedtuple("CMDState", "minPlayer maxPlayer minPlayer_dual maxPlayer_dual")
+UpdateState = collections.namedtuple("UpdateState", "del_min del_max")
+_tree_apply = partial(jax.tree_multimap, lambda f, x: f(x))
+
+
+def updates(prev_state, hessian_xy=None, hessian_yx=None, grad_min=None,
+            grad_max=None, breg_min=default_breg, breg_max=default_breg, eta_min=1., eta_max=1., objective_func=None,
+            precond_b_min=False, precond_b_max=False):
+    """Equation (4). Given the current position (prev_state), compute the updates (del_x, del_y) to the players in the cmd algorithm for the next position.
+
+    Args:
+        prev_state (Named tuple of vectors): The current position of the players given by a tuple
+            with signature 'CMDState(minPlayer maxPlayer minPlayer_dual maxPlayer_dual)'
+        breg_min (Named tuple of callables): Tuple of unary callables with signature
+            'BregmanPotential = collections.namedtuple("BregmanPotential", ["DP", "DP_inv", "D2P", "inv_D2P"])'
+            where DP and DP_inv are unary callables with signatures
+            `DP(x,*args, **kwargs)`, 'DP_inv(x,*arg,**kwarg)' and
+            D2P, inv_D2P are functions of functions
+            (given an x, they return a linear transformation that takes in
+            another vector and outputs a hessian-vector product).
+        breg_max (Named tuple of callables): Tuple of unary callables as 'breg_min'.
+        eta_min (scalar): User specified step size for min player. Default 1.
+        eta_max (scalar): User specified step size for max player. Default 1.
+        hessian_xy (callable): The (estimated) mixed hessian of the objective at the current positions of the players, represented as a matrix-vector operator from jax.jvp
+        hessian_yx (callable): The (estimated) mixed hessian of the objective at the current positions of the players, represented as a matrix-vector operator from jax.jvp
+        grad_min (vector): The (estimated) gradient of the cost function w.r.t. the min player parameters at the current position.
+        grad_max (vector): The (estimated) gradient of the cost function w.r.t. the max player parameters at the current position.
+
+    Returns:
+        UpdateState(del_min, del_max), a named tuple for the updates
+    """
+    if objective_func is not None:
+        # grad_min_func = jit(jacfwd(objective_func, 0))
+        # grad_max_func = jit(jacfwd(objective_func, 1))
+        # H_xy_func = jit(jacfwd(grad_min, 1))
+        # H_yx_func = jit(jacfwd(grad_max, 0))
+
+        # Compute current gradient for min and max players
+        grad_min = jacfwd(objective_func, 0)(prev_state.minPlayer, prev_state.maxPlayer)
+        grad_max = jacfwd(objective_func, 1)(prev_state.minPlayer, prev_state.maxPlayer)
+
+        # Define the mixed hessian-vector product linear operators at the current position
+        def hessian_xy(tangent):
+            return make_mixed_jvp(objective_func, prev_state.minPlayer, prev_state.maxPlayer)(
+                tangent)
+
+        def hessian_yx(tangent):
+            return make_mixed_jvp(objective_func, prev_state.minPlayer, prev_state.maxPlayer, True)(
+                tangent)
+
+    def linear_opt_min(min_tree):
+        temp = hessian_yx(min_tree)  # returns max_tree type
+        temp1 = _tree_apply(_tree_apply(breg_max.inv_D2P, prev_state.maxPlayer),
+                            temp)  # returns max_tree type
+        temp2 = hessian_xy(temp1)  # returns min_tree type
+        temp3 = tree_util.tree_map(lambda x: eta_max * x, temp2)  # still min_tree type
+        temp4 = _tree_apply(_tree_apply(breg_min.D2P, prev_state.minPlayer),
+                            min_tree)  # also returns min_tree type
+        temp5 = tree_util.tree_map(lambda x: 1 / eta_min * x, temp4)
+        # print("linear operator being called! - min")
+        out = tree_util.tree_multimap(lambda x, y: x + y, temp3, temp5)
+        return out  # min_tree type
+
+    def linear_opt_max(max_tree):
+        temp = hessian_xy(max_tree)
+        temp1 = _tree_apply(_tree_apply(breg_min.inv_D2P, prev_state.minPlayer), temp)
+        temp2 = hessian_yx(temp1)  # returns max_tree type
+        temp3 = tree_util.tree_map(lambda x: eta_min * x, temp2)  # max_tree type
+        temp4 = _tree_apply(_tree_apply(breg_max.D2P, prev_state.maxPlayer), max_tree)
+        temp5 = tree_util.tree_map(lambda x: 1 / eta_max * x, temp4)  # max_tree type
+        # print("linear operator being called!
- max") + out = tree_util.tree_multimap(lambda x, y: x + y, temp3, temp5) # max_tree type + return out + + # calculate the vectors in equation (4) + temp = hessian_xy(_tree_apply(_tree_apply(breg_max.inv_D2P, prev_state.maxPlayer), grad_max)) + temp2 = tree_util.tree_map(lambda x: eta_max * x, temp) + vec_min = tree_util.tree_multimap(lambda arr1, arr2: arr1 + arr2, grad_min, temp2) + + if precond_b_min: + # vec_min_tree, min_tree_def = tree_util.tree_flatten(vec_min) + # cond_min = tree_util.tree_unflatten(min_tree_def, + # jax.tree_map(lambda x: jnp.linalg.norm(x, jnp.inf), vec_min_tree)) + # vec_min = tree_util.tree_multimap(lambda x, y: x / y, vec_min, cond_min) + cond_min = max(jax.tree_map(lambda x: jnp.linalg.norm(x, jnp.inf), tree_util.tree_flatten(vec_min)[0])) + vec_min = tree_util.tree_map(lambda x: x / cond_min, vec_min) + + # temp = _tree_apply(hessian_yx, _tree_apply(_tree_apply(breg_min.inv_D2P, prev_state.minPlayer), grad_min)) + temp = hessian_yx(_tree_apply(_tree_apply(breg_min.inv_D2P, prev_state.minPlayer), grad_min)) + temp2 = tree_util.tree_map(lambda x: eta_min * x, temp) + vec_max = tree_util.tree_multimap(lambda x, y: x - y, grad_max, temp2) + + if precond_b_max: + # vec_max_tree, max_tree_def = tree_util.tree_flatten(vec_max) + # cond_max = tree_util.tree_unflatten(max_tree_def, + # jax.tree_map(lambda x: jnp.linalg.norm(x), vec_max_tree)) + # vec_max = tree_util.tree_multimap(lambda x, y: x / y, vec_max,cond_max) + cond_max = max(jax.tree_map(lambda x: jnp.linalg.norm(x), tree_util.tree_flatten(vec_max)[0])) + vec_max = tree_util.tree_map(lambda x: x / cond_max, vec_max) + + + update_min, status_min = linalg.cg(linear_opt_min, vec_min, maxiter=1000) + + if precond_b_min: + update_min = tree_util.tree_map(lambda x: cond_min * x, update_min) + # update_min = tree_util.tree_multimap(lambda x, y: y * x, cond_min, update_min) + + update_min = tree_util.tree_map(lambda x: -x, update_min) # negation here! + + update_max, status_max = linalg.cg(linear_opt_max, vec_max, maxiter=1000) + if precond_b_max: + update_max = tree_util.tree_map(lambda x: cond_max * x, update_max) + # update_max = tree_util.tree_multimap(lambda x, y: y * x, cond_max, update_max) + + return UpdateState(update_min, update_max) + + +def cmd_step(prev_state, updates, breg_min=default_breg, breg_max=default_breg): + """Equation (2). Take in the previous player positions and update to the next player position. Return a 1-step cmd update. + + Args: + prev_state (Named tuples of vectors): The current position of the players given by tuple + with signature 'CMDState(minPlayer maxPlayer minPlayer_dual maxPlayer_dual)' + updates (Named tuples of vectors): The updates del_x,del_y computed from updates(...) with signature 'UpdateState(del_min, del_max)' + breg_min (Named tuples of callable): Tuple of unary callables with signature + 'BregmanPotential = collections.namedtuple("BregmanPotential", ["DP", "DP_inv", "D2P","D2P_inv"])' + where DP and DP_inv are unary callables with signatures + `DP(x,*args, **kwargs)`,'DP_inv(x,*arg,**kwarg)' and + D2P, D2P_inv are function of functions + (Given an x, returning linear transformation function + that can take in another vector to output hessian-vector product). + breg_max (Named tuples of callable): Tuple of unary callables as 'breg_min'. 
+ + Returns: + Named tuple: the states of the players at current iteration - CMDState + """ + temp_min = _tree_apply(_tree_apply(breg_min.D2P, prev_state.minPlayer), updates.del_min) + temp_max = _tree_apply(_tree_apply(breg_max.D2P, prev_state.maxPlayer), updates.del_max) + + dual_min = tree_util.tree_multimap(lambda x, y: x + y, prev_state.minPlayer_dual, temp_min) + dual_max = tree_util.tree_multimap(lambda x, y: x + y, prev_state.maxPlayer_dual, temp_max) + + minP = _tree_apply(breg_min.DP_inv, dual_min) + maxP = _tree_apply(breg_max.DP_inv, dual_max) + + return CMDState(minP, maxP, dual_min, dual_max) diff --git a/fax/competitive/cmd/cmd_helper.py b/fax/competitive/cmd/cmd_helper.py new file mode 100644 index 0000000..55d0d52 --- /dev/null +++ b/fax/competitive/cmd/cmd_helper.py @@ -0,0 +1,136 @@ +import jax.numpy as np +import jax +from jax import grad, jvp +from jax.scipy import linalg +from jax import random +from functools import partial +from jax.config import config + +config.update("jax_enable_x64", True) + +# DP helper functions +def DP_hand(vec, nx): + temp = np.reshape(vec, (nx, nx)) + return (-np.linalg.inv(temp).T + temp).reshape(nx**2, 1) + + +# matrix PD Potential: P(M) = -logdet(M) + 1/2*norm(M)^2 +# Vector PD Potential: P(x) = xlog(x) +def matrix_DP_pd(M): + return -np.linalg.slogdet(M)[1] + + +def vector_DP_pd(v): + return np.dot(v, np.log(v)) + + +def DP_pd(v): + m = len(np.shape(v)) + if m == 1: + out = grad(lambda x: np.dot(x, np.log(x)))(v) + else: + out = grad(lambda M: -np.linalg.slogdet(M)[1])(v) + return out + + +# DP_inv helper functions + +def vector_DP_inv_pd(v): + return np.exp(v - np.ones_like(v)) + + +def DP_inv_pd(v): + m = len(np.shape(v)) + if m == 1: + out = vector_DP_inv_pd(v) + else: + out = -linalg.inv(v).T + return out + +# def matrix_DP_inv_pd(Y): +# # Y = (Y+Y.T)/2 +# nx,_ = np.shape(Y) +# s,U = linalg.eigh(Y) +# s_x = np.empty_like(s) +# for i in range(len(s)): +# y = float(np.real(np.roots([1,-s[i],-1])[np.roots([1,-s[i],-1])>0])) +# print(y) +# s_x = index_update(s_x, i, y) +# print(s_x) +# TT = (U @ np.diag(s_x) @ linalg.inv(U)) +# b, _ = np.linalg.eig(TT.reshape((nx, nx))) +# print(b) +# return TT + + +# D2P helper functions +def id_func(x): + return lambda u: np.dot(np.identity(x.shape[0]), u) + + +def hvp(f, primals, tangents): + return jvp(grad(f), primals, tangents)[1] + + +def D2P_pd(v): + m = len(np.shape(v)) + if m == 1: + def out(u): + return hvp(vector_DP_pd, (v,), (u,)) + else: + def out(u): + return hvp(matrix_DP_pd, (v,), (u,)) + return out + + +# inv_D2P helper functions + +def inv_D2P_pd(v): + m = len(np.shape(v)) + if m == 1: + def out(u): + return np.dot(np.diag(v), u) + else: + def out(u): + return np.dot(np.linalg.matrix_power(v, 2).T, u) + return out + +# Testing # +# +# DP_inv_eq_min = lambda v: jax.tree_map(lambda x: x, v) +# DP_inv_ineq_min = lambda v: jax.tree_map(DP_inv_pd, v) +# +# min_augmented_DP = (lambda x: x, lambda v: jax.tree_map(DP_pd, v)) # [breg_min.DP, DP_eq_min, DP_ineq_min] +# min_augmented_DP_inv = (DP_inv_eq_min, DP_inv_ineq_min) # [breg_min.DP_inv, DP_inv_eq_min, DP_inv_ineq_min] +# +# D2P_eq_min = lambda v: jax.tree_map(id_func, v) +# D2P_ineq_min = lambda v: jax.tree_map(D2P_pd, v) +# min_augmented_D2P = ( D2P_eq_min, D2P_ineq_min) +# +# # inv_D2P_eq_min = lambda v: jax.tree_map(lambda x: x, v) +# inv_D2P_eq_min = lambda v: jax.tree_map(id_func, v) +# inv_D2P_ineq_min = lambda v: jax.tree_map(inv_D2P_pd, v) +# min_augmented_D2P_inv = (inv_D2P_eq_min, inv_D2P_ineq_min) +# +# +# key1 = 
random.PRNGKey(0)
# key = random.PRNGKey(1)
# x1 = np.array([1., 2., 3.,4., 5.])
# x2 = random.normal(key1, (5,))
# W1 = random.normal(key, (3,3))
# W2 = random.normal(key1, (3,3))
#
# x = ((x1,x2), (x1,x1,x2))
# W = ((W1,W2), (W1,W1,W2))
#
# print(jax.tree_multimap(lambda f, x: f(x), min_augmented_DP, x))
# print(DP_pd(x1))
# print(DP_pd(x2))
#
#
# # Check if the inv(D2P) matches the closed form.
# print(inv_D2P_pd(W1)(np.identity(W1.shape[0])))
# print(np.linalg.matrix_power(W1,2).T)
#
# print(inv_D2P_pd(x2)(x1))
# print(np.dot(np.diag(x2),x1))
diff --git a/fax/competitive/cmd/cmd_tester.py b/fax/competitive/cmd/cmd_tester.py
new file mode 100644
index 0000000..e8acec7
--- /dev/null
+++ b/fax/competitive/cmd/cmd_tester.py
@@ -0,0 +1,650 @@
+import jax.numpy as np
+from jax import random, grad, jacfwd
+from cmd import make_pd_bregman
+import jaxlib
+from jax.scipy import linalg
+from cmd_helper import DP_pd, DP_inv_pd, D2P_pd, inv_D2P_pd, id_func  # id_func is needed below
+from cmd import make_lagrangian, updates, cmd_step, _tree_apply, make_bound_breg
+import collections
+from lq_game_helper import proj, gradient, Df_lambda, Df_L, Df_lambda_L, Df_L_lambda
+import jax.ops
+import pickle
+from scipy.optimize import minimize
+import matplotlib.pyplot as plt
+from jax import grad, jacrev
+import jax
+from jax import jvp
+import jax.numpy as jnp
+from jax import jit, vmap, lax
+print(jax.__version__)
+print(jaxlib.__version__)
+from jax.config import config
+config.update("jax_enable_x64", True)
+
+BregmanPotential = collections.namedtuple("BregmanPotential", ["DP", "DP_inv", "D2P", "inv_D2P"])
+CMDState = collections.namedtuple("CMDState", "minPlayer maxPlayer minPlayer_dual maxPlayer_dual")
+UpdateState = collections.namedtuple("UpdateState", "del_min del_max")
+
+key1 = random.PRNGKey(0)
+key = random.PRNGKey(1)
+x1 = jnp.array([1., 2., 3., 4., 5.])
+x2 = random.normal(key1, (5, ))
+W1 = random.normal(key, (3, 3))
+W2 = random.normal(key1, (3, 3))
+
+x = [(x1, x2), (x1, x1, x2)]
+W = [(W1, W2), (W1, W1, W2)]
+
+DP_inv_eq_min = lambda v: jax.tree_map(lambda x: x, v)
+DP_inv_ineq_min = lambda v: jax.tree_map(DP_inv_pd, v)
+
+min_augmented_DP = (lambda x: x, lambda v: jax.tree_map(DP_pd, v))  # [breg_min.DP, DP_eq_min, DP_ineq_min]
+min_augmented_DP_inv = (DP_inv_eq_min, DP_inv_ineq_min)  # [breg_min.DP_inv, DP_inv_eq_min, DP_inv_ineq_min]
+
+D2P_eq_min = lambda v: jax.tree_map(id_func, v)
+D2P_ineq_min = lambda v: jax.tree_map(D2P_pd, v)
+min_augmented_D2P = (D2P_eq_min, D2P_ineq_min)
+
+# inv_D2P_eq_min = lambda v: jax.tree_map(lambda x: x, v)
+inv_D2P_eq_min = lambda v: jax.tree_map(id_func, v)
+inv_D2P_ineq_min = lambda v: jax.tree_map(inv_D2P_pd, v)
+min_augmented_D2P_inv = (inv_D2P_eq_min, inv_D2P_ineq_min)
+
+key1 = random.PRNGKey(0)
+key = random.PRNGKey(1)
+x1 = jnp.array([1., 2., 3., 4., 5.])
+x2 = random.normal(key1, (5,))
+W1 = random.normal(key, (3,3))
+W2 = random.normal(key1, (3,3))
+
+x = ((x1,x2), (x1,x1,x2))
+W = ((W1,W2), (W1,W1,W2))
+
+print(jax.tree_multimap(lambda f, x: f(x), min_augmented_DP, x))
+print(DP_pd(x1))
+print(DP_pd(x2))
+
+
+# Check if the inv(D2P) matches the closed form.
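+# (For the vector potential x*log(x) the Hessian is diag(1/x), so its inverse is
+# diag(x); the matrix branch of inv_D2P_pd applies (V^2)^T as its inverse-Hessian
+# surrogate. The prints below compare both branches against those forms.)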
+print(inv_D2P_pd(W1)(jnp.identity(W1.shape[0]))) +print(jnp.linalg.matrix_power(W1,2).T) + +print(inv_D2P_pd(x2)(x1)) +print(jnp.dot(jnp.diag(x2),x1)) + + + + + +def make_bound_breg_original(lb=-1.0, ub=1.0): + + def breg_bound_internal(lb, ub, *args, **kwargs): + return lambda vec: jnp.sum( + (- vec + ub) * jnp.log(- vec + ub) + (vec - lb) * jnp.log(vec - lb)) + + def DP_bound_internal(lb, ub): + return jax.grad(breg_bound_internal(lb, ub)) + + def DP_inv_bound_internal(lb, ub): + return lambda vec: (ub * jnp.exp(vec) + lb) / (1 + jnp.exp(vec)) + + def D2P_bound_internal(lb, ub): + def out(vec): + return lambda u: jvp(DP_bound_internal(lb, ub), (vec,), (u,))[1] + + return out + + def inv_D2P_bound_internal(lb, ub): + def out(vec): + if len(jnp.shape(vec)) <= 1: + return lambda u: jnp.dot(jnp.diag(1 / ((1 / (ub - vec)) + (1 / (vec - lb)))), u) + else: + return lambda u: (1 / ((1 / (ub - vec)) + (1 / (vec - lb)))) * u + return out + + return BregmanPotential(DP_bound_internal(lb, ub), DP_inv_bound_internal(lb, ub), + D2P_bound_internal(lb, ub), inv_D2P_bound_internal(lb, ub)) + + + +def make_bound_breg_vector_original_unscaled(lb=jnp.array([-1.0]), ub=jnp.array([1.0])): + def breg_bound_internal(vec): + assert vec.shape == lb.shape, "lower bound shape does not match variable shape!" + assert vec.shape == ub.shape, "upper bound shape does not match variable shape!" + return jnp.sum((- vec + ub) * jnp.log(- vec + ub) + (vec - lb) * jnp.log(vec - lb), axis=-1) + + DP_bound_internal = jax.grad(breg_bound_internal) + + def DP_inv_bound_internal(vec): + return (ub * jnp.exp(vec) + lb) / (1 + jnp.exp(vec)) + + def D2P_bound_internal(vec): + return lambda u: jvp(DP_bound_internal, (vec,), (u,))[1] + + def inv_D2P_bound_internal(vec): + if len(jnp.shape(vec)) >= 1: + return lambda u: jnp.dot(jnp.diag(1 / ((1 / (ub - vec)) + (1 / (vec - lb)))), u) + else: + return lambda u: (1 / ((1 / (ub - vec)) + (1 / (vec - lb)))) * u + + + return BregmanPotential(DP_bound_internal, DP_inv_bound_internal, + D2P_bound_internal, inv_D2P_bound_internal) + + +x = random.normal(key1) +eta = jnp.array([1e-2]) + +bregman_potential = make_bound_breg(step_size=eta) +bregman_potential_original = make_bound_breg_vector_original_unscaled() #make_bound_breg_original() +assert bregman_potential.DP(jnp.array([x])) == 1/eta* bregman_potential_original.DP(jnp.array([x])), "DP not matching!" +assert bregman_potential.DP_inv(jnp.array([x])) == bregman_potential_original.DP_inv(eta* jnp.array([x])), "DP_inv not matching!" +assert jnp.isclose(bregman_potential.D2P(jnp.array([x]))(jnp.array([1.2345])) , 1/eta * bregman_potential_original.D2P(jnp.array([x]))(jnp.array([1.2345]))), "D2P not matching!" +assert jnp.isclose(bregman_potential.inv_D2P(jnp.array([x]))(jnp.array([1.2345])) , eta* bregman_potential_original.inv_D2P(jnp.array([x]))(jnp.array([1.2345]))), "D2P not matching!" + +lb = -1.0 * jnp.ones((5,)) +ub = 1.0 * jnp.ones((5,)) +# eta = jnp.array([1e-2,2e-3,7e-1,4e-4,3e-2]) +eta = 1.234e-2 +vector_eta = eta * jnp.ones((5,)) +bregman_potential = make_bound_breg(lb,ub, vector_eta) +bregman_potential_original = make_bound_breg_original(lb,ub) + +assert jnp.isclose(bregman_potential.DP(x2) , 1/eta * bregman_potential_original.DP(x2)).all(), "DP not matching!" +assert jnp.isclose(bregman_potential.DP_inv(x2) , bregman_potential_original.DP_inv(eta* x2)).all(), "DP_inv not matching!" +assert jnp.isclose(bregman_potential.D2P(x2)(x2) , 1/eta * bregman_potential_original.D2P(x2)(x2)).all(), "D2P not matching!" 
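+# The identities exercised by these asserts, for the scaled potential P_eta = P / eta:
+# DP_eta = (1/eta) * DP, D2P_eta = (1/eta) * D2P, inv_D2P_eta = eta * inv_D2P,
+# and DP_inv_eta(v) = DP_inv(eta * v).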
+assert jnp.isclose(bregman_potential.inv_D2P(x2)(x2) , eta* bregman_potential_original.inv_D2P(x2)(x2)).all(), "D2P not matching!" + + +print(1 / eta * bregman_potential_original.DP(x2)) +print(bregman_potential.DP(x2)) + +# cmd main algorithm test with single variables, scalar example from cmd paper +"""" +def obj_func(x, y): + return 2 * x * y - (1 - y) ** 2 + +breg_min = pd_bregman() +breg_max = pd_bregman() +# initialize states +x_init = -random.normal(key1, (1, )) +y_init = -random.normal(key, (1, )) +x_init_dual = DP_pd(x_init) +y_init_dual = DP_pd(y_init) +prev_state = CMDState(x_init, y_init, x_init_dual, y_init_dual) + +# Compute hessians and gradients +grad_min = jacfwd(obj_func,0) +grad_max = jacfwd(obj_func,1) +H_xy = jacfwd(grad_min,1) +H_yx =jacfwd(grad_max,0) + +gradient_min = grad_min(prev_state.minPlayer, prev_state.maxPlayer).flatten() +gradient_max = grad_max(prev_state.minPlayer, prev_state.maxPlayer).flatten() +hessian_xy = lambda v: H_xy(prev_state.minPlayer, prev_state.maxPlayer).flatten()*v +hessian_yx = lambda v: H_yx(prev_state.minPlayer, prev_state.maxPlayer).flatten()*v + +# Main cmd algorithm +for t in range(1000): + delta = updates(prev_state, 0.001, 0.001, hessian_xy, hessian_yx, gradient_min, gradient_max, breg_min, breg_max) + print(prev_state) + new_state = cmd_step(prev_state, delta, breg_min, breg_max) + print(new_state) + prev_state = new_state + + gradient_min = grad_min(prev_state.minPlayer, prev_state.maxPlayer).flatten() + gradient_max = grad_max(prev_state.minPlayer, prev_state.maxPlayer).flatten() + hessian_xy = lambda v: H_xy(prev_state.minPlayer, prev_state.maxPlayer).flatten() * v + hessian_yx = lambda v: H_yx(prev_state.minPlayer, prev_state.maxPlayer).flatten() * v + +print(prev_state) +print(new_state) +""" + + +# cmd main algorithm test with single variables, vector example from cmd paper +""" +n = 3 +A = random.normal(key, (n, n)) +b = (A[:,0] + A[:,1]) / 2 + 0.01 + +def obj_func(x, y): + return np.linalg.norm(np.dot(A,x)-b) ** 2 + y * (np.dot(np.ones_like(b), x) - 1) + +breg_min = BregmanPotential(DP_pd, DP_inv_pd, D2P_pd, inv_D2P_pd) + +# initialize states +x_init = random.normal(key1, (n, )) +y_init = random.normal(key, (1, )) +x_init = jax.ops.index_update(x_init, x_init<0.01,0.01) +x_init_dual = DP_pd(x_init) +y_init_dual = y_init +prev_state = CMDState(x_init, y_init, x_init_dual, y_init_dual) + +# Compute hessians and gradients +grad_min = jacfwd(obj_func,0) +grad_max = jacfwd(obj_func,1) +H_xy = jacfwd(grad_min,1) +H_yx =jacfwd(grad_max,0) + +grad_min = grad_min(prev_state.minPlayer, prev_state.maxPlayer).reshape(n,) +grad_max = grad_max(prev_state.minPlayer, prev_state.maxPlayer).reshape(1,) +hessian_xy = lambda v: np.dot(H_xy(prev_state.minPlayer, prev_state.maxPlayer).reshape(n,1), v) +hessian_yx = lambda v: np.dot(H_yx(prev_state.minPlayer, prev_state.maxPlayer).reshape(1,n), v) + +# Main cmd algorithm +for t in range(1000): + delta = updates(prev_state, 0.01, 0.1, hessian_xy, hessian_yx, grad_min, grad_max, breg_min) + print(prev_state) + new_state = cmd_step(prev_state, delta, breg_min) + print(new_state) + prev_state = new_state + + grad_min = grad_min(prev_state.minPlayer, prev_state.maxPlayer).reshape(n,) + grad_max = grad_max(prev_state.minPlayer, prev_state.maxPlayer).reshape(1,) + hessian_xy = lambda v: np.dot(H_xy(prev_state.minPlayer, prev_state.maxPlayer).reshape(n, 1), v) + hessian_yx = lambda v: np.dot(H_yx(prev_state.minPlayer, prev_state.maxPlayer).reshape(1, n), v) + +print(prev_state) +print(new_state) +""" + 
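+
+
+# A minimal end-to-end sketch of the unconstrained API, kept disabled like the
+# examples above. Assumptions: `bilinear_game`, `x0` and `y0` are illustrative;
+# passing `objective_func` lets updates() derive the gradients and the mixed
+# Hessian-vector products itself, and with the default L2 potentials the dual
+# variables coincide with the primal ones.
+"""
+def bilinear_game(x, y):
+    return jnp.dot(x, y)
+
+x0 = jnp.array([1.0])
+y0 = jnp.array([-1.0])
+prev_state = CMDState(x0, y0, x0, y0)  # L2 potential: dual == primal
+for t in range(100):
+    delta = updates(prev_state, eta_min=0.1, eta_max=0.1, objective_func=bilinear_game)
+    prev_state = cmd_step(prev_state, delta)
+print(prev_state.minPlayer, prev_state.maxPlayer)  # should approach the saddle at (0, 0)
+"""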
+ +# cmd main algorithm test with structured variables, vector example from RRL paper +# using the sampling functions from RRL repository. +# The min player has two variables, K and Lambda. The max player has a single variable L. +""" +# Problem Parameters +A = jnp.array([[1,1],[0,1]]) +B = jnp.array([[0],[1]]) +C = jnp.array([[0.5],[1]]) +nx,nu = B.shape +_,nw = C.shape +T = 15 +Ru = jnp.eye(nu) +Rw = 20*jnp.eye(nw) +Q = jnp.eye(nx) +q = 0.01 +e,_ = jnp.linalg.eig(Q) +l_max = (jnp.min(e) - q) / Rw +safeguard = 2 + + +# Bregmen Potential definitions +def D2P_l2(v): + return lambda x: x +breg_min = BregmanPotential((lambda x: x, DP_pd), (lambda x: x, DP_inv_pd), (D2P_l2, D2P_pd), (D2P_l2, inv_D2P_pd)) + +# Initialization of variables +K = 0.001 * random.normal(key1, (nu, nx)) +Lambda = 0.001 * random.normal(key1, (nx, nx)) +Lambda = jnp.eye(nx)+ Lambda + Lambda.T +Lambda = proj(Lambda,2) + +x = (K, Lambda) # min player +y = 0.01 * random.normal(key, (nu, nx)) # max player L +dual_x = (K, DP_pd(Lambda)) +dual_y = y +prev_state = CMDState(x, y, dual_x, dual_y) + +# # Get gradients and hessians +# print("-- about to compute gradients --") +# DK,DL,DKL = gradient(50,100,A,B,C,Q,Ru,Rw,prev_state.minPlayer[0],y,T) +# print("--done with gradient computation--") +# DfLambda = Df_lambda(prev_state.minPlayer[1],y,Q,q,Rw,nx) +# DfL = Df_L(prev_state.minPlayer[1],y,Q,q,Rw,nx) +# DfLambdaL = Df_lambda_L(prev_state.minPlayer[1],y,Q,q,Rw,nx) +# DfLLambda = Df_L_lambda(prev_state.minPlayer[1],y,Q,q,Rw,nx) +# +# grad_max = DL + DfL +# grad_min = (DK, DfLambda) +# # hessian_xy = lambda v: [np.matmul(DKL, v.T).T, np.tensordot(DfLambdaL, DL)] +def hessian_xy_generator(DKL,DfLambdaL): + def hessian_xy(max_tree): + return (np.matmul(DKL, max_tree.T).T, np.tensordot(DfLambdaL, max_tree)) # returns minPlayer structure + return hessian_xy + +# hessian_yx = lambda K_var,Lambda_var : np.matmul(DKL.T,K_var.T).T + np.tensordot(DfLLambda,Lambda_var) +def hessian_yx_generator(DKL,DfLLambda): + def hessian_yx(min_tree): + K_var = min_tree[0] + Lambda_var = min_tree[1] + return np.matmul(DKL.T,K_var.T).T + np.tensordot(DfLLambda,Lambda_var) # returns maxPlayer structure + return hessian_yx + +# Main cmd algorithm +state_list = [] +state_list.append(prev_state) +minPlayer_list_1 = [prev_state.minPlayer[0][0][0]] +minPlayer_list_2 = [prev_state.minPlayer[0][0][1]] +maxPlayer_list_1 = [prev_state.maxPlayer[0][0]] +maxPlayer_list_2 = [prev_state.maxPlayer[0][1]] + +# @jit +# def jit_cmd(prev_state): +# DK, DL, DKL= gradient(50, 100, A, B, C, Q, Ru, Rw, prev_state.minPlayer[0], prev_state.maxPlayer, T) +# DfLambda = Df_lambda(prev_state.minPlayer[1], prev_state.maxPlayer, Q, q, Rw, nx) +# DfL = Df_L(prev_state.minPlayer[1], prev_state.maxPlayer, Q, q, Rw, nx) +# DfLambdaL = Df_lambda_L(prev_state.minPlayer[1], prev_state.maxPlayer, Q, q, Rw, nx) +# DfLLambda = Df_L_lambda(prev_state.minPlayer[1], prev_state.maxPlayer, Q, q, Rw, nx) +# grad_max = DL + DfL +# grad_min = (DK, DfLambda) +# +# delta = updates(prev_state, 2e-4, 4e-3, hessian_xy_generator(DKL, DfLambdaL), +# hessian_yx_generator(DKL, DfLLambda), grad_min, grad_max, breg_min) +# return cmd_step(prev_state, delta, breg_min) + +eta_x = 1e-4 +eta_y = 1e-3 + +@jit +def jit_cmd(prev_state, DK, DL, DKL): + def hessian_xy_generator(DKL, DfLambdaL): + def hessian_xy(max_tree): + # print('hessian_xy called!') + return (np.matmul(DKL, max_tree.T).T, + np.tensordot(DfLambdaL, max_tree)) # returns minPlayer structure + + return hessian_xy + + # hessian_yx = lambda K_var,Lambda_var 
: np.matmul(DKL.T,K_var.T).T + np.tensordot(DfLLambda,Lambda_var) + def hessian_yx_generator(DKL, DfLLambda): + def hessian_yx(min_tree): + K_var = min_tree[0] + Lambda_var = min_tree[1] + # print('hessian_yx called!') + return np.matmul(DKL.T, K_var.T).T + np.tensordot(DfLLambda, + Lambda_var) # returns maxPlayer structure + + return hessian_yx + + DfLambda = Df_lambda(prev_state.minPlayer[1], prev_state.maxPlayer, Q, q, Rw, nx) + DfL = Df_L(prev_state.minPlayer[1], prev_state.maxPlayer, Q, q, Rw, nx) + DfLambdaL = Df_lambda_L(prev_state.minPlayer[1], prev_state.maxPlayer, Q, q, Rw, nx) + DfLLambda = Df_L_lambda(prev_state.minPlayer[1], prev_state.maxPlayer, Q, q, Rw, nx) + grad_max = DL + DfL + grad_min = (DK, DfLambda) + + delta = updates(prev_state, eta_x, eta_y, hessian_xy_generator(DKL, DfLambdaL), + hessian_yx_generator(DKL, DfLLambda), grad_min, grad_max, breg_min=breg_min, + precond_b_min=False, precond_b_max=False) + return cmd_step(prev_state, delta, breg_min=breg_min), delta, grad_min + + + +def P(Lambda, nx): + Lambda = np.reshape(Lambda, (nx, nx)) + # return np.trace(np.dot(Lambda_stack,LA.logm(Lambda_stack))) + sign, logdet = jax.numpy.linalg.slogdet(Lambda) + return -logdet + 0.5 * (jax.numpy.linalg.norm(Lambda, 'fro')) ** 2 # + 0.5*(np.linalg.norm(K,'fro'))**2 + + +DP = grad(P) +D2P = jacrev(DP) +# ---------------------------------------------- + +# infile = open('state_list.pkl','rb') +# state_list = pickle.load(infile) +# infile.close() +# prev_state = state_list[-1] +# prev_state = jax.tree_map(lambda x: jnp.float64(x), prev_state) + +print('------begin-------- starting with: ', prev_state) +for t in range(2000): + # print("-- about to compute gradients --") + # DK, DL, DKL = jit_gradient(prev_state) #gradient(50, 100, A, B, C, Q, Ru, Rw, prev_state.minPlayer[0], y, T) + # print("--done with gradient computation--") + # DfLambda = Df_lambda(prev_state.minPlayer[1], prev_state.maxPlayer, Q, q, Rw, nx) + # DfL = Df_L(prev_state.minPlayer[1], prev_state.maxPlayer, Q, q, Rw, nx) + # DfLambdaL = Df_lambda_L(prev_state.minPlayer[1], prev_state.maxPlayer, Q, q, Rw, nx) + # DfLLambda = Df_L_lambda(prev_state.minPlayer[1], prev_state.maxPlayer, Q, q, Rw, nx) + # grad_max = DL + DfL + # grad_min = (DK, DfLambda) + # delta = jit_updates(prev_state, DKL, DfLambdaL, DfLLambda, grad_min, grad_max) #updates(prev_state, 2e-5, 2e-5, hessian_xy_generator(DKL,DfLambdaL), hessian_yx_generator(DKL,DfLLambda), grad_min, grad_max, breg_min) + + + DK, DL, DKL= gradient(50, 250, A, B, C, Q, Ru, Rw, prev_state.minPlayer[0], prev_state.maxPlayer, T) + # new_state,del_, grad_min = jit_cmd(prev_state, DK, DL, DKL) + + # ------------without CMD package, hand compute everything-------- + DfLambda = Df_lambda(prev_state.minPlayer[1], prev_state.maxPlayer, Q, q, Rw, nx) + DfL = Df_L(prev_state.minPlayer[1], prev_state.maxPlayer, Q, q, Rw, nx) + DfLambdaL = Df_lambda_L(prev_state.minPlayer[1], prev_state.maxPlayer, Q, q, Rw, nx) + DfLLambda = Df_L_lambda(prev_state.minPlayer[1], prev_state.maxPlayer, Q, q, Rw, nx) + grad_max = DL + DfL + grad_min = (DK, DfLambda) + + + + x = jnp.vstack((prev_state.minPlayer[0].T, prev_state.minPlayer_dual[1].reshape(4, 1))) + y = prev_state.maxPlayer + + DflL = DfLambdaL.reshape((nx ** 2, nx)) + Dlambda = DfLambda + + Dx = np.vstack((DK.reshape((nx, nu)), Dlambda.reshape(nx ** 2, 1))) + Dy = DL + DfL + Dy = Dy.T + + grad_min = (DK.reshape((nx, nu)), Dlambda) # todo: check nu==1 + grad_max = Dy + + + + + Dxy = np.vstack((DKL, DflL)) + Dyx = Dxy.T + hessian = 
jax.scipy.linalg.block_diag(np.eye(nx), D2P(Lambda, nx).reshape((nx ** 2, nx ** 2))) + hessian_inv = np.linalg.inv(hessian) + + Jx = np.linalg.inv(1 / eta_x * hessian + eta_y * np.matmul(Dxy, Dyx)) + Jy = np.linalg.inv(1 / eta_y * np.eye(nx) + eta_x * np.matmul(np.matmul(Dyx, hessian_inv), Dxy)) + del_x = -np.matmul(Jx, (Dx + eta_y * np.matmul(np.matmul(Dxy, np.eye(nx)), Dy))) + del_y = np.matmul(Jy, (Dy - eta_x * np.matmul(np.matmul(Dyx, hessian_inv), Dx))) + + x = x + hessian @ del_x + y = y + del_y.T + + # Reassignment from dual variable to primal variables PRIMAL + K = x[0:nx, 0].T # first nx elements are K + K = np.minimum(np.maximum(K, -safeguard), safeguard) + K = K.reshape((nu, nx)) + dual_l = x[nx:, 0].reshape((nx, nx)) + dual_l = (dual_l + dual_l.T) / 2 + # dual_l = dual_l.reshape((nx ** 2, 1)) + Lambda = make_pd_bregman().DP_inv(dual_l) # get primal Lambda from the dual update + + + # -------------------------- safe guard ----------------- + # K = jnp.minimum(jnp.maximum(new_state.minPlayer[0], -safeguard), safeguard) + # new_state = CMDState((K, new_state.minPlayer[1]), new_state.maxPlayer, + # _tree_apply(breg_min.DP,(K, new_state.minPlayer[1])) ,new_state.maxPlayer_dual) + + K = jnp.minimum(jnp.maximum(K, -safeguard), safeguard) + new_state = CMDState((K, Lambda), y ,(K,dual_l) ,y) + + + + + + + + # Saving data + state_list.append(new_state) + minPlayer_list_1.append(new_state.minPlayer[0][0][0]) + minPlayer_list_2.append(new_state.minPlayer[0][0][1]) + maxPlayer_list_1.append( new_state.maxPlayer[0][0]) + maxPlayer_list_2.append(new_state.maxPlayer[0][1]) + + if t%20 ==0: + print("-------------------",t,"---------------") + print("K ",new_state.minPlayer[0]) + # print('grad_min ', grad_min) + # print('delta_min ', del_.del_min) + print("L ", new_state.maxPlayer) + # print('delta_max ', del_.del_max) + np.save("minPlayer_list_1",minPlayer_list_1) + np.save("minPlayer_list_2",minPlayer_list_2) + np.save("maxPlayer_list_1",maxPlayer_list_1) + np.save("maxPlayer_list_2",maxPlayer_list_2) + + with open('state_list_hand.pkl', 'wb') as f: + pickle.dump(state_list, f) + + + prev_state = new_state + + + +print(new_state) + + + +p1 = plt.figure(1) +plt.subplot(121) +plt.plot(minPlayer_list_1,label = 'CMD') +plt.legend() + +plt.subplot(122) +plt.plot(minPlayer_list_2,label = 'CMD') +plt.legend() +plt.title('K') + +p2 = plt.figure(2) +plt.subplot(121) +plt.plot(maxPlayer_list_1, label = 'CMD') +plt.legend() + + +plt.subplot(122) +plt.plot(maxPlayer_list_2,label = 'CMD') +plt.title('L') +plt.legend() +plt.show() +""" + +# # Test lagrangian making portion, works! 
+# def obj_func(x, y): +# return (2 * x * y - (1 - y) ** 2)[0] +# +# +# breg_min = BregmanPotential(DP_pd, DP_inv_pd, D2P_pd, inv_D2P_pd) +# breg_max = BregmanPotential(DP_pd, DP_inv_pd, D2P_pd, inv_D2P_pd) +# +# init_multipliers, lagrangian, breg_min_aug, breg_max_aug = make_lagrangian(obj_func, breg_min, breg_max) +# min_P, max_P = init_multipliers(np.array([1.,]),np.array([2.,])) +# dual_min = _tree_apply(breg_min_aug.DP,min_P) +# dual_max = _tree_apply(breg_max_aug.DP,max_P) +# prev_state = CMDState(min_P,max_P, dual_min, dual_max) +# updates(prev_state,1e-4, 1e-4, breg_min=breg_min_aug, breg_max = breg_max_aug, objective_func=lagrangian) + + + +""" +horizon = 10 # how many unit time we simulate for +num_control_intervals = 20 # how many intervals of control +step_size = horizon/num_control_intervals # how long to hold each control value + +control_bounds = np.empty((num_control_intervals, 2)) +control_bounds[:] = [-0.75, 1.0] +# (^ this can stay an onp array) + +x0 = jnp.array([0., 1.]) # start state +xf = jnp.array([0., 0.]) # end state + +# Dynamics function +@jit +def f(x, u): + x0 = x[0] + x1 = x[1] + return jnp.asarray([(1. - x1**2) * x0 - x1 + u, x0]) + +# Instantaneous cost +@jit +def c(x, u): + return jnp.dot(x, x) + u**2 + +vector_c = jit(vmap(c)) + +# Integrate from the start state, using controls, to the final state +@jit +def integrate_fwd(us): + def rk4_step(x, u): + k1 = f(x, u) + k2 = f(x + step_size * k1/2, u) + k3 = f(x + step_size * k2/2, u) + k4 = f(x + step_size * k3 , u) + return x + (step_size/6)*(k1 + 2*k2 + 2*k3 + k4) + + def fn(carried_state, u): + one_step_forward = rk4_step(carried_state, u) + return one_step_forward, one_step_forward # (carry, y) + + last_state_and_all_xs = lax.scan(fn, x0, us) + return last_state_and_all_xs + +# Calculate cost over entire trajectory +@jit +def objective(us): + _, xs = integrate_fwd(us) + all_costs = vector_c(xs, us) + return jnp.sum(all_costs) + jnp.dot(x0, x0) # add in cost of start state (will make no difference) + +# Calculate defect of final state +@jit +def equality_constraints(us): + final_state, _ = integrate_fwd(us) + return final_state - xf + +rng = jax.random.PRNGKey(42) +# rng, rng_input = jax.random.split(rng) +initial_controls_guess = jax.random.uniform(rng, shape=(num_control_intervals,), minval=-0.76, maxval=0.9) + +constraints = ({'type': 'eq', + 'fun': equality_constraints, + 'jac': jax.jit(jax.jacrev(equality_constraints)) + }) + +options = {'maxiter': 500, 'ftol': 1e-6} + + + + +# Make Lagrangian out of the original OCP +key = jax.random.PRNGKey(1) + +# Generate Lagrangian-related functions and augmented Bregman divergence +init_multipliers, lagrangian, breg_min_aug, breg_max_aug = make_lagrangian(objective, breg_min = make_bound_breg(lb=-0.75, ub=1.0), min_equality_constraints=equality_constraints) + +# Initialize the augmented min player and max player +min_P, max_P = init_multipliers(initial_controls_guess,None,key) +dual_min = _tree_apply(breg_min_aug.DP,min_P) +dual_max = _tree_apply(breg_max_aug.DP,max_P) + +# Construct a CMD state +init_state = CMDState(min_P, max_P, dual_min, dual_max ) + + +# Testing tge Lagrangian funciton and the Bregman potentials +L = lagrangian(min_P, max_P) # L = objective(initial_controls_guess) + max_P[1] @ equality_constraints(initial_controls_guess) + +_tree_apply( breg_max_aug.DP_inv,max_P) +_tree_apply(_tree_apply( breg_max_aug.D2P,max_P),max_P) +_tree_apply(_tree_apply( breg_max_aug.inv_D2P,max_P),max_P) + +_tree_apply( breg_min_aug.DP_inv,min_P) 
+_tree_apply(_tree_apply( breg_min_aug.D2P,min_P),min_P) +_tree_apply(_tree_apply( breg_min_aug.inv_D2P,min_P),min_P) + +prev_state = init_state +for i in range(200): + delta = updates(prev_state,1e-3, 1e-3, breg_min=breg_min_aug, breg_max = breg_max_aug, objective_func=lagrangian) + new_state = cmd_step(prev_state, delta, breg_min_aug, breg_max_aug) + prev_state = new_state + if i%1 ==0: + print("---------------",i,"------------") + print(lagrangian(new_state.minPlayer, new_state.maxPlayer)) +print(new_state.minPlayer) +""" diff --git a/fax/competitive/cmd/lq_game_helper.py b/fax/competitive/cmd/lq_game_helper.py new file mode 100644 index 0000000..97e77b8 --- /dev/null +++ b/fax/competitive/cmd/lq_game_helper.py @@ -0,0 +1,143 @@ +import jax.scipy.linalg as LA +import jax.numpy as np +from jax import jacfwd, grad, random, ops +from jax.config import config + +config.update("jax_enable_x64", True) + +key = random.PRNGKey(0) + +# Input L as a ROW VECTOR +def f(Lambda,L,Q,q,Rw,nx): + return -np.trace(np.matmul(Lambda,(Rw*np.matmul(L.T,L)-Q-q*np.eye(nx)))) + +Df_lambda = grad(f,0) +Df_L = grad(f,1) +Df_lambda_L = jacfwd(Df_lambda,1) +Df_L_lambda = jacfwd(Df_L,0) + + +# def opt_LQR(A,B,Q,R,s): +# R[1,1] = s +# R[0,0] = 1.5-s +# X = LA.solve_discrete_are(A,B,R,Q) +# K = -np.matmul(np.matmul(np.matmul(LA.inv(Q+np.matmul(B.T,np.matmul(X,B))),B.T), X),A) +# return K + +#Done +def inf_cost(A,B,C,Q,Ru,Rv,K,L): + d,p = B.shape + a,b = C.shape + K = K.reshape((p,d)) + L = L.reshape((b,a)) + R = np.linalg.diag((Ru,Rv)) + B_t = np.hstack((B,C)) + cl_map = A + np.matmul(B,K) + np.matmul(C,L) + if np.amax(np.abs(LA.eigvals(cl_map))) < (1.0 - 1.0e-6): + cost = np.trace(LA.solve_discrete_are(A,B_t,Q,R)) + else: + #cost = float("inf") + cost = -20 + return cost + + +def get_g(batch_size,A,B,C,Q,Ru,Rv,K,L,T,baseline = 0): + # mini_batch is a single gradient(log sum derivative of pi), avg of this is ordinary gradient + # but here it is equivalent to g. 
+ sigma_K = 5e-1 + sigma_L = 5e-1 + sigma_x = 1e-4 + nx, nu = B.shape + _, nw = C.shape + K = K.reshape((nu,nx)) + L = L.reshape((nw,nx)) + Q = np.kron(np.eye(T,dtype=int), Q) + Rv = np.kron(np.eye(T,dtype=int), Rv) + Ru = np.kron(np.eye(T, dtype=int), Ru) + + X = np.zeros((nx*(T+1),batch_size)) + # X[0:nx,:] = 0.2 * random.normal(key, shape=(nx,batch_size)) + X = ops.index_update(X, ops.index[0:nx,:], 0.2 * random.normal(key, shape=(nx,batch_size)) ) + + + U = np.zeros((nu*T,batch_size)) + W = np.zeros((nw*T,batch_size)) + Vu = sigma_K * random.normal(key, shape = (nu*T, batch_size)) # noise for U + Vw = sigma_L * random.normal(key, shape = (nw*T, batch_size)) # noise for W + + for t in range(T): + # U[t*nu:(t+1)*nu,:] = np.matmul(K,X[nx*t:nx*(t+1),:]) + Vu[t*nu:(t+1)*nu,:] + U = ops.index_update(U, ops.index[t*nu:(t+1)*nu,:], + np.matmul(K,X[nx*t:nx*(t+1),:]) + Vu[t*nu:(t+1)*nu,:]) + # W[t*nw:(t + 1) * nw, :] = np.matmul(L, X[nx * t:nx * (t + 1), :]) + Vw[t * nw:(t + 1) * nw, :] + W = ops.index_update(W, ops.index[t*nw:(t + 1) * nw, :], + np.matmul(L, X[nx * t:nx * (t + 1), :]) + Vw[t * nw:(t + 1) * nw, :]) + # X[nx*(t+1):nx*(t+2),:] = np.matmul(A,X[nx*t:nx*(t+1),:]) + np.matmul(B,U[t*nu:(t+1)*nu,:]).reshape((nx,batch_size)) +\ + # + np.matmul(C,W[t*nw:(t+1)*nw,:]).reshape((nx,batch_size)) + sigma_x * random.normal(key, shape=(nx, batch_size)) + X = ops.index_update(X, ops.index[nx*(t+1):nx*(t+2),:], + np.matmul(A,X[nx*t:nx*(t+1),:]) + np.matmul(B,U[t*nu:(t+1)*nu,:]).reshape((nx,batch_size)) + np.matmul(C,W[t*nw:(t+1)*nw,:]).reshape((nx,batch_size)) + sigma_x * random.normal(key, shape=(nx, batch_size))) + + X_cost = X[nx:,:] + reward = np.diagonal(np.matmul(X_cost.T,Q.dot(X_cost))) + np.diagonal(np.matmul(U.T,Ru.dot(U))) - np.diagonal(np.matmul(W.T,Rv.dot(W))) + new_baseline = np.mean(reward) + reward = reward.reshape((len(reward),1)) + + #DK portion + X_hat = X[:-nx,:] #taking only T = 0:T-1 for X for log gradient computation + outer_grad_log_K = np.einsum("ik, jk -> ijk",Vu,X_hat) # shape (a,b,c) means there are a of the (b,c) blocks. access (b,c) blocks via C[0,:,:] + outer_grad_log_L = np.einsum("ik, jk -> ijk", Vw, X_hat) + sum_grad_log_K =0 + sum_grad_log_L = 0 + for t in range(T): + sum_grad_log_K += outer_grad_log_K[nu * t:nu * (t + 1), nx * t:nx * (t + 1),:] # Summing all diagonal blocks. 
gives p by d by batch_size + sum_grad_log_L += outer_grad_log_L[nw * t:nw * (t + 1), nx * t:nx * (t + 1), :] + + + mini_batch_K = (1/sigma_K)**2 * ((reward-new_baseline).T*sum_grad_log_K) #mini_batch is p by d, same size as K + mini_batch_L = (1 /sigma_L) ** 2 * ((reward - new_baseline).T * sum_grad_log_L) # mini_batch is b by a/d, same size as K + # mini_batch_K = 2 * ((reward-new_baseline).T*sum_grad_log_K) #mini_batch is p by d, same size as K + # mini_batch_L = 2 * ((reward - new_baseline).T * sum_grad_log_L) # mini_batch is b by a/d, same size as K + # print(mini_batch_K[0,0,:]) + + temp = np.einsum('mnr,ndr->mdr', sum_grad_log_K.swapaxes(0,1),sum_grad_log_L) + batch_mixed_KL = (1/(sigma_K*sigma_L))**2 * ((reward-new_baseline).T*temp) + # print('---new---',sum_grad_log_K[:,:,10][0,0]) + + return np.mean(mini_batch_K,axis = 2),np.mean(mini_batch_L,axis = 2),np.mean(batch_mixed_KL,axis = 2),new_baseline + + + + + +def gradient( num_sample,batch_size,A,B,C,Q,Ru,Rv,K,L,T,baseline = 0): + nu,nx = K.shape + nw,nx = L.shape + DK_samples = np.zeros(shape = (nx,num_sample)) + DL_samples = np.zeros(shape=(nx, num_sample)) + Dxy_all = np.zeros(shape = (num_sample,nx,nx)) + for i in range(num_sample): + g,f,mixed,baseline= get_g(batch_size,A,B,C,Q,Ru,Rv,K,L,T,baseline) + # DK_samples[:, i] = g.flatten() + DK_samples = ops.index_update(DK_samples, ops.index[:, i], g.flatten()) + # DL_samples[:, i] = f.flatten() + DL_samples = ops.index_update(DL_samples, ops.index[:, i], f.flatten()) + # Dxy_all[i,:,:] = mixed + Dxy_all = ops.index_update(Dxy_all, ops.index[i,:,:], mixed) + + return np.mean(DK_samples,axis = 1).reshape((nu,nx)),\ + np.mean(DL_samples,axis = 1).reshape((nw,nx)),\ + np.mean(Dxy_all, axis = 0) + + +def proj(L,temp,lower = 0.5): + s,v = np.linalg.eig(L.T@L) + s = np.minimum(np.maximum(s,lower),temp) + return np.real((v@np.diag(s))@ v.T) + + +def proj_sgd(L,temp): + temp = np.sqrt(temp) + s = np.linalg.norm(L,2) + s = np.minimum(np.maximum(s,-temp),temp) + + return L/np.linalg.norm(L,2) * s \ No newline at end of file