
How to dump the operations of all the methods of a jitted class object #1906

Answered by ronghongbo
ronghongbo asked this question in Q&A

Thanks to help from Bojan Nikolic, I was able to dump the IR as follows:

```python
import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn
import jax.tree_util as tree_util
import jax.tools.jax_to_ir
import functools
from typing import Sequence
from functools import partial
from jax.lib import xla_client

# The original code
class MLP(nn.Module):
  features: Sequence[int]

  @nn.compact
  def __call__(self, x):
    for feat in self.features[:-1]:
      x = nn.relu(nn.Dense(feat)(x))
    x = nn.Dense(self.features[-1])(x)
    return x

model = MLP([12, 8, 4])
batch = jnp.ones((32, 10))
variables = model.init(jax.random.PRNGKey(0), batch)
output = model.apply(variables, batch…
```

Replies: 1 comment

Answer selected by marcvanzee