
Extracting connectivity graph from FLAX or JAX #1718

Answered by jheek
mattiasmar asked this question in Q&A

I think what you are looking for is jax.make_jaxpr, which gives you all the JAX operations. The jaxpr it returns has a list of operations, eqns. Each equation has invars/outvars, and each var has an .aval field that will typically have a shape, so you can do things like:

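import jax
jaxpr = jax.make_jaxpr(fn)(*inputs)  # fn, inputs: your function and its example arguments
eqn_inputs = jaxpr.eqns[i].invars    # input variables of the i-th equation
print([x.aval.shape for x in eqn_inputs])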
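Building on that, a connectivity graph can be read off the jaxpr by recording which equation defines each variable and linking it to every later equation that consumes that variable. The following is a minimal sketch under a few assumptions: fn and connectivity_graph are hypothetical names chosen for illustration; literals (inline constants such as 2.0) are detected by duck typing on a .val attribute rather than through a JAX type check; and equations are assumed to appear in jaxpr.eqns in dependency order, so every producer comes before its consumers.

```python
import jax
import jax.numpy as jnp


def fn(x, w):
    # Hypothetical example function; any traceable function works here.
    h = jnp.tanh(x @ w)
    return jnp.sum(h * 2.0)


def connectivity_graph(closed_jaxpr):
    """Return (nodes, edges) for the top-level equations of a jaxpr.

    nodes[i] = (primitive name, list of output shapes) of equation i.
    edges    = (producer_index, consumer_index, shape) tuples, one per
               variable that flows from one equation into another.
    """
    jaxpr = closed_jaxpr.jaxpr
    producer = {}  # Var -> index of the equation that defines it
    nodes, edges = [], []
    for i, eqn in enumerate(jaxpr.eqns):
        for v in eqn.invars:
            # Assumption: literals carry a .val attribute and have no
            # producing equation, so they contribute no edge.
            if hasattr(v, "val"):
                continue
            # Function inputs (jaxpr.invars) have no producer; skip them.
            if v in producer:
                edges.append((producer[v], i, tuple(v.aval.shape)))
        for v in eqn.outvars:
            producer[v] = i
        nodes.append((eqn.primitive.name,
                      [tuple(v.aval.shape) for v in eqn.outvars]))
    return nodes, edges


x = jnp.ones((4, 3))
w = jnp.ones((3, 2))
closed = jax.make_jaxpr(fn)(x, w)
nodes, edges = connectivity_graph(closed)
for i, (prim, shapes) in enumerate(nodes):
    print(i, prim, shapes)
print(edges)
```

The exact primitive names printed depend on the JAX version. Higher-order primitives such as pjit, scan, or cond carry nested jaxprs in eqn.params; this sketch treats each of them as a single node rather than recursing into them.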

Replies: 2 comments, 24 replies, with participants including @mattiasmar and @marcvanzee. Answer selected by mattiasmar.

Category: Q&A. 5 participants.