Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Differentiation of klujax primitives #12

Closed
wants to merge 2 commits into from

Conversation

frankschae
Copy link

Hi!

@ferranAX and I are trying to fix the derivative issues in SAX (flaport/sax#42). We spotted that differentiation in klujax also fails at the moment.

import klujax
import jax.numpy as jnp
import jax

b = jnp.array([8, 45, -3, 3, 19])
A_dense = jnp.array(
    [
        [2, 3, 0, 0, 0],
        [3, 0, 4, 0, 6],
        [0, -1, -3, 2, 0],
        [0, 0, 1, 0, 0],
        [0, 4, 2, 0, 1],
    ]
)
Ai, Aj = jnp.where(jnp.abs(A_dense) > 0)
Ax = A_dense[Ai, Aj]

def myf(p):
    return jnp.sin(p)

def f1(p):
    result = jnp.linalg.inv(A_dense * myf(p)) @ b
    return jnp.sum(result**2)

def f2(p): 
    result = klujax.solve(Ai, Aj, Ax*myf(p), b)
    return jnp.sum(result**2)

p = 2.0
print(jnp.abs(f1(p) - f2(p)) < 1e-12)
print(f1(p), f2(p))

jax.grad(f1)(p), jax.jacfwd(f1)(p), jax.jacrev(f1)(p)
jax.grad(f2)(p) # works
jax.jacfwd(f2)(p) # works
jax.jacrev(f2)(p) # assertion error
---------------------------------------------------------------------------
JaxStackTraceBeforeTransformation         
Cell In[50], line 32
     31 p = 2.0
---> 32 print(jnp.abs(f1(p) - f2(p)) < 1e-12)
     33 print(f1(p), f2(p))

Cell In[50], line 28, in f2()
     27 def f2(p): 
---> 28     result = klujax.solve(Ai, Aj, Ax*myf(p), b/myf(p))
     29     return jnp.sum(result**2)

File ~/GitRepos/klujax/klujax.py:56, in solve()
     55 else:
---> 56     result = solve_f64.bind(
     57         Ai.astype(jnp.int32),
     58         Aj.astype(jnp.int32),
     59         Ax.astype(jnp.float64),
     60         b.astype(jnp.float64),
     61     )
     63 return result

JaxStackTraceBeforeTransformation: AssertionError: (-1, 1)

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

AssertionError                            Traceback (most recent call last)
Cell In[54], line 1
----> 1 jax.jacrev(f2)(p)

File /opt/anaconda3/envs/sax-debug/lib/python3.12/site-packages/jax/_src/api.py:673, in jacrev.<locals>.jacfun(*args, **kwargs)
    671   y, pullback, aux = _vjp(f_partial, *dyn_args, has_aux=True)
    672 tree_map(partial(_check_output_dtype_jacrev, holomorphic), y)
--> 673 jac = vmap(pullback)(_std_basis(y))
    674 jac = jac[0] if isinstance(argnums, int) else jac
    675 example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args

    [... skipping hidden 33 frame]

File /opt/anaconda3/envs/sax-debug/lib/python3.12/site-packages/jax/_src/util.py:445, in tuple_insert(t, idx, val)
    444 def tuple_insert(t, idx, val):
--> 445   assert 0 <= idx <= len(t), (idx, len(t))
    446   return t[:idx] + (val,) + t[idx:]

AssertionError: (-1, 1)

The error happens due to the use of vmap when computing the Jacobian. The fix was also suggested in #11 for vmap in general.

@flaport
Copy link
Owner

flaport commented Feb 22, 2025

Hi @frankschae @ferranAX

I ran the tests and they're passing for jax<0.5. Is this ready to merge?

On that note... I'm not sure why the tests are failing for jax>=0.5. (They're failing on master too)

@flaport
Copy link
Owner

flaport commented Feb 24, 2025

Hey thanks for the PR! However, I did end up doing a huge refactor of klujax. Gradients and vmap should be working properly now.

@frankschae
Copy link
Author

Awesome! Thanks a lot!!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants