Extracting connectivity graph from FLAX or JAX #1718
-
Hi,
-
Is it possible to visualize the computation graph using TensorBoard? Since
both JAX/Flax and TensorFlow come from Google, I suppose it would be much
easier to resolve the incompatible parts.
Actually, I believe I recall someone using jax2tf (which preserves Flax
metadata, because the named_call regions Flax uses become tf.name_scopes)
and loading the resulting graph into the TensorBoard visualization, which
looked pretty good. If you try this out and it works, please report back!
(We can then add a documentation page about it.)
On Wed, Dec 15, 2021 at 10:39 AM jheek <***@***.***> wrote:
TensorBoard doesn't support an XLA computation graph; it only shows a TF
graph. The Haiku visualization module should work on any JAX function, and
I think it will also support the Flax metadata (so it will group ops
belonging to the same Flax Module).
-
I think what you are looking for is `jax.make_jaxpr`; it gives you all the
JAX operations. The resulting jaxpr has a list of operations, `eqns`. Each
equation has `invars`/`outvars`, and a var has an `.aval` field, which will
typically have a `shape`, so you can do things like:
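A minimal sketch of that kind of traversal, assuming a toy two-layer MLP written with `jax.numpy`. The function `mlp` and the helpers `is_literal`, `shape_of`, `connectivity`, and `to_dot` are hypothetical names used only for illustration. The sketch records which equation produced each var, which turns the jaxpr into a connectivity graph: a list of nodes plus (producer, consumer) edges that can be emitted as Graphviz DOT text. It only walks the top-level equations, so sub-jaxprs nested in `eqn.params` (for example from `jit`, `scan`, or `cond`) are not expanded. Detecting literals through a `.val` attribute is an assumption about JAX internals that may change between versions.

```python
import jax
import jax.numpy as jnp


def is_literal(v):
    # Assumption: constants folded into an equation show up as Literals, which
    # carry a `.val` attribute that Vars lack; Literals may also be unhashable.
    return hasattr(v, "val")


def shape_of(v):
    # Array avals (ShapedArray) expose `.shape`; other avals may not.
    return getattr(v.aval, "shape", None)


def connectivity(fn, *args):
    """Trace fn(*args) and return (jaxpr, nodes, edges).

    nodes: list of (label, output_shapes); edges: set of (src, dst) node
    indices. Only top-level equations are visited; sub-jaxprs stored in
    eqn.params (e.g. jit, scan, cond) are not expanded.
    """
    jaxpr = jax.make_jaxpr(fn)(*args).jaxpr
    nodes, producer, edges = [], {}, set()

    # Source nodes for closed-over constants and the flattened inputs.
    for kind, vs in (("const", jaxpr.constvars), ("input", jaxpr.invars)):
        for v in vs:
            producer[v] = len(nodes)
            nodes.append((kind, [shape_of(v)]))

    # One node per equation, with an edge from whichever node produced each input.
    for eqn in jaxpr.eqns:
        idx = len(nodes)
        for v in eqn.invars:
            if not is_literal(v) and v in producer:
                edges.add((producer[v], idx))
        for v in eqn.outvars:
            producer[v] = idx
        nodes.append((eqn.primitive.name, [shape_of(v) for v in eqn.outvars]))
    return jaxpr, nodes, edges


def to_dot(nodes, edges):
    # Emit Graphviz DOT text: one node per source/equation, one arrow per edge.
    lines = ["digraph jaxpr {"]
    for i, (label, shapes) in enumerate(nodes):
        lines.append(f'  n{i} [label="{label} {shapes}"];')
    for src, dst in sorted(edges):
        lines.append(f"  n{src} -> n{dst};")
    lines.append("}")
    return "\n".join(lines)


# Hypothetical two-layer MLP used only for illustration.
def mlp(params, x):
    h = jnp.tanh(x @ params["w1"] + params["b1"])
    return h @ params["w2"] + params["b2"]


params = {
    "w1": jnp.ones((3, 4)), "b1": jnp.zeros(4),
    "w2": jnp.ones((4, 3)), "b2": jnp.zeros(3),
}
x = jnp.ones((2, 3))

jaxpr, nodes, edges = connectivity(mlp, params, x)
for eqn in jaxpr.eqns:
    print(eqn.primitive.name, [shape_of(v) for v in eqn.outvars])
print(to_dot(nodes, edges))
```

Keying the `producer` dict on the Var objects themselves, rather than on their printed names, is what lets a single forward pass recover the edges. The DOT text can be saved to a file and rendered with Graphviz, e.g. `dot -Tpng graph.dot -o graph.png`.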