Differentiation of klujax primitives #12

Closed · wants to merge 2 commits
klujax.py: 15 changes (6 additions, 9 deletions)
@@ -15,7 +15,7 @@
 import jax.extend
 import jax.numpy as jnp
 import numpy as np
-from jax import core, lax
+from jax import lax
 from jax.core import ShapedArray
 from jax.interpreters import ad, batching, mlir
 from jaxtyping import Array
@@ -105,10 +105,10 @@ def coo_mul_vec(Ai: Array, Aj: Array, Ax: Array, x: Array) -> Array:

 # Primitives ==========================================================================

-solve_f64 = core.Primitive("solve_f64")
-solve_c128 = core.Primitive("solve_c128")
-coo_mul_vec_f64 = core.Primitive("coo_mul_vec_f64")
-coo_mul_vec_c128 = core.Primitive("coo_mul_vec_c128")
+solve_f64 = jax.extend.core.Primitive("solve_f64")
+solve_c128 = jax.extend.core.Primitive("solve_c128")
+coo_mul_vec_f64 = jax.extend.core.Primitive("coo_mul_vec_f64")
+coo_mul_vec_c128 = jax.extend.core.Primitive("coo_mul_vec_c128")

 # Register XLA extensions ==============================================================

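Note (reviewer aside, not part of the diff): this hunk tracks JAX's deprecation of `jax.core.Primitive` in favor of `jax.extend.core.Primitive`. A minimal sketch of defining a primitive through the new path, assuming a current JAX version and mirroring the imports the module keeps above; the primitive here is a hypothetical toy, not one of klujax's:

```python
import jax.extend
import jax.numpy as jnp
from jax.core import ShapedArray  # same import the module retains above

my_prim = jax.extend.core.Primitive("my_prim")  # toy primitive for illustration

@my_prim.def_impl
def _my_prim_impl(x):
    return x * 2  # concrete (eager) evaluation

@my_prim.def_abstract_eval
def _my_prim_abstract(x):
    return ShapedArray(x.shape, x.dtype)  # shape/dtype rule used during tracing

print(my_prim.bind(jnp.ones(3)))  # [2. 2. 2.]
```
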
@@ -395,8 +395,6 @@ def coo_mul_vec_value_and_jvp(arg_values, arg_tangents):


 # Backward Gradients through Transposition ============================================
-
-
 @transpose_register(solve_f64)
 @transpose_register(solve_c128)
 def solve_transpose(ct, Ai, Aj, Ax, b):
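Aside on the rule registered here: for `x = A⁻¹ b`, reverse mode propagates the cotangent through a solve against the transposed (conjugated, in the complex case) matrix, and in COO form transposing costs nothing: swap `Ai` and `Aj` and conjugate `Ax`. That is presumably the same `(Aj, Ai, Ax.conj(), ct)` trick visible in `coo_mul_vec_transpose` below. A tiny check of the index-swap identity (illustrative, not klujax code):

```python
import jax.numpy as jnp

Ai = jnp.array([0, 1, 2, 2])
Aj = jnp.array([1, 0, 2, 0])
Ax = jnp.array([1.0, 2.0, 3.0, 4.0])

A = jnp.zeros((3, 3)).at[Ai, Aj].add(Ax)
At = jnp.zeros((3, 3)).at[Aj, Ai].add(Ax)  # swapped index arrays
assert jnp.allclose(A.T, At)
```
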
@@ -419,7 +417,6 @@ def coo_mul_vec_transpose(ct, Ai, Aj, Ax, b):
         return None, None, None, coo_mul_vec(Aj, Ai, Ax.conj(), ct)  # = A.T@ct [= ct@A]
     else:
         dA = ct[Ai] * b[Aj]
-        dA = dA.reshape(dA.shape[0], -1).sum(-1)  # not sure about this...
         return None, None, dA, None


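The surviving `dA = ct[Ai] * b[Aj]` follows directly from the COO product: `y[Ai[k]] += Ax[k] * b[Aj[k]]`, so `dL/dAx[k] = ct[Ai[k]] * b[Aj[k]]`. The deleted reshape/sum (flagged "not sure about this..." by its own author) appears to have been a leftover batching guess; this PR drops it and adds the 1-D derivative test below. A quick dense cross-check of the formula (illustrative only; unique index pairs assumed):

```python
import jax
import jax.numpy as jnp

Ai = jnp.array([0, 1, 2, 2])
Aj = jnp.array([1, 0, 2, 0])
Ax = jnp.array([1.0, 2.0, 3.0, 4.0])
b = jnp.array([0.5, -1.0, 2.0])
ct = jnp.array([1.0, 2.0, 3.0])

# gradient of ct @ (A @ b) w.r.t. the dense A, gathered at the nonzeros
A0 = jnp.zeros((3, 3)).at[Ai, Aj].add(Ax)
dense = jax.grad(lambda A: ct @ (A @ b))(A0)
assert jnp.allclose(dense[Ai, Aj], ct[Ai] * b[Aj])
```
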
@@ -472,6 +469,6 @@ def coo_vec_operation_vmap(operation, vector_arg_values, batch_axes):
         # b is now guaranteed to have shape (n_lhs, n_col, n_rhs)
         result = operation(Ai, Aj, Ax, b)
         result = result.reshape(*shape)
-        return result, -1
+        return result, result.ndim - 1

     raise ValueError("invalid arguments for vmap")
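
On `return result, result.ndim - 1`: a batching rule returns the batched output together with the index of its batch axis, and reporting the explicit non-negative index rather than `-1` avoids depending on negative-axis handling inside JAX's batching machinery (my reading of the change; the PR does not state a rationale). A hypothetical toy rule showing the same `(output, batch_axis)` contract:

```python
import jax.numpy as jnp

# toy rule for illustration, not klujax's actual vmap rule
def toy_batch_rule(batched_args, batch_dims):
    (x,), (bdim,) = batched_args, batch_dims
    out = jnp.moveaxis(x, bdim, -1) * 2.0  # move the batch axis last, apply the op
    return out, out.ndim - 1               # report the batch axis as a positive index
```
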
tests.py: 29 changes (29 additions, 0 deletions)
@@ -132,6 +132,35 @@ def test_vmap_fail(dtype, op):
     jax.vmap(op, in_axes=(None, None, None, 0), out_axes=0)(Ai, Aj, Ax, b)


+@log_test_name
+@pytest.mark.parametrize("dtype", [np.float64])
+@pytest.mark.parametrize("ops", [(klujax.solve, jsp.linalg.solve), (klujax.coo_mul_vec, lax.dot)])  # fmt: skip
+def test_1d_derivative(dtype, ops):
+    op_sparse, op_dense = ops
+    Ai, Aj, Ax, b = _get_rand_arrs_1d(8, (n_col := 5), dtype=dtype)
+
+    p = 0.35778278
+
+    def fsparse(p):
+        x_sp = op_sparse(Ai, Aj, p * Ax, p * b)
+        return jnp.linalg.norm(x_sp)
+
+    def fdense(p):
+        A = jnp.zeros((n_col, n_col), dtype=Ax.dtype).at[Ai, Aj].add(Ax)
+        x = op_dense(p * A, p * b)
+        return jnp.linalg.norm(x)
+
+    J1 = jax.jacfwd(fsparse)(p)
+    J2 = jax.jacfwd(fdense)(p)
+
+    _log_and_test_equality(J1, J2)
+
+    J1 = jax.jacrev(fsparse)(p)
+    J2 = jax.jacrev(fdense)(p)
+
+    _log_and_test_equality(J1, J2)
+
+
 def _get_rand_arrs_1d(n_nz, n_col, *, dtype, seed=33):
     Axkey, Aikey, Ajkey, bkey = jax.random.split(jax.random.PRNGKey(seed), 4)
     Ai = jax.random.randint(Aikey, (n_nz,), 0, n_col, jnp.int32)
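
The new test runs the same scalar function through `jax.jacfwd` (exercising the JVP rules) and `jax.jacrev` (exercising the transpose rules fixed above) against a dense reference. When such a comparison disagrees, a central finite difference is a handy third opinion; a self-contained sketch with a toy function standing in for `fsparse`/`fdense`:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # the tests above run in float64

def fd_grad(f, p, eps=1e-6):
    # central difference, O(eps^2) truncation error
    return (f(p + eps) - f(p - eps)) / (2 * eps)

f = lambda p: jnp.linalg.norm(p * jnp.array([1.0, 2.0, 3.0]))  # toy stand-in
p = 0.35778278
assert jnp.allclose(jax.grad(f)(p), fd_grad(f, p), atol=1e-6)
```
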