Can't get loss from a jitted function #1735
Replies: 6 comments 3 replies
-
If you want to output the loss, you can return it from the jitted function and output it afterwards. This is done in all our examples; I would recommend taking a look at the MNIST example and the Annotated MNIST, which contains more explanation.
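For reference, a minimal sketch of that pattern (the `TinyModel`, data, and MSE loss are made up for illustration, not taken from the question): the jitted train step returns the loss together with the updated state, and the caller prints it outside the jitted call.

```python
import jax
import jax.numpy as jnp
import optax
from flax import linen as nn
from flax.training import train_state

# Hypothetical tiny model, just to keep the sketch self-contained.
class TinyModel(nn.Module):
    @nn.compact
    def __call__(self, x):
        return nn.Dense(1)(x)

@jax.jit
def train_step(state, x, y):
    def loss_fn(params):
        preds = state.apply_fn({"params": params}, x)
        return jnp.mean((preds - y) ** 2)

    # Compute loss and grads inside the jitted function ...
    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    # ... and return the loss alongside the updated state so the caller can log it.
    return state.apply_gradients(grads=grads), loss

rng = jax.random.PRNGKey(0)
x = jnp.ones((4, 3))
y = jnp.zeros((4, 1))
state = train_state.TrainState.create(
    apply_fn=TinyModel().apply,
    params=TinyModel().init(rng, x)["params"],
    tx=optax.sgd(0.01),
)
state, loss = train_step(state, x, y)
print(float(loss))  # concrete value, printable outside the jitted call
```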
-
I looked over the Flax examples, but with my code:

```python
def loss_fn(params):
    data = dataset[0]
    #data = dataset[0][0]
    rec, out = Autoencoder().apply({"params": params}, data)
    #globals()["p"] = rec, out
    loss = compute_metrics(rec, out)["loss"]
    return jnp.array(loss)

grads = jax.grad(loss_fn)(state.params)
```

I never actually return the loss, so I don't have a way to get the loss itself. Am I using the wrong JAX/Flax functions?
-
If you have taken a look at the examples, then you must have noticed that they don't use `jax.grad` but `jax.value_and_grad`, which returns the loss together with the gradients.
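For example, with a toy loss (illustrative only, not the autoencoder from the question):

```python
import jax
import jax.numpy as jnp

def loss_fn(params):
    return jnp.sum(params ** 2)

params = jnp.arange(3.0)

grads = jax.grad(loss_fn)(params)                   # gradients only
loss, grads = jax.value_and_grad(loss_fn)(params)   # loss and gradients
print(loss, grads)  # 5.0 [0. 2. 4.]
```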
-
Thanks, that seems to get the loss and the grad. However, this code

```python
state = train_state.TrainState.create(
    apply_fn=Autoencoder().apply,
    params=Autoencoder().init(rng, dataset[0])["params"],
    tx=optax.sgd(0.01)
)
```

which was in the main code (not in a function), was causing an OOM, so I wrapped that in a jitted `main`. However, now I can't print out the loss because I'm still inside jitted code. I tried creating the state in a jitted function and passing the value to a non-jitted function, but that still causes OOMs. Am I going about this the wrong way, or is there just no way to do this?
-
I tried moving various things in and out of jit, but it seems the actual OOM comes from when this code isn't jitted: `loss, grads = jax.value_and_grad(loss_fn)(state.params)`. Because of this, I can't get the actual value of the loss (because it needs to be jitted).
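One way to see the distinction, sketched with a toy loss (illustrative only, not the thread's code): the `value_and_grad` call itself can be wrapped in `jax.jit`, and the loss it returns is still a concrete, printable value on the outside.

```python
import jax
import jax.numpy as jnp

def loss_fn(params):
    return jnp.sum(params ** 2)

# The compiled function computes loss and grads together,
# but what it returns are ordinary concrete arrays.
value_and_grad_fn = jax.jit(jax.value_and_grad(loss_fn))

loss, grads = value_and_grad_fn(jnp.arange(3.0))
print(float(loss), grads)  # 5.0 [0. 2. 4.]
```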
-
Like this (this isn't all the code, just the relevant part)?

```python
@jit
def update_model(state, grads):
    #loss, grads = jax.value_and_grad(loss_fn)(state.params)
    return state.apply_gradients(grads=grads)

@jit
def train_step(state):
    def loss_fn(params):
        data = dataset[0]
        rec, out = Autoencoder().apply({"params": params}, data)
        loss = compute_metrics(rec, out)["loss"]
        return jnp.array(loss)
    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    return loss, update_model(state, grads)

optimizer = optim.sgd.GradientDescent(learning_rate=0.01)
EPOCHS = 10

@jit
def create_state():
    return train_state.TrainState.create(
        apply_fn=Autoencoder().apply,
        params=Autoencoder().init(rng, dataset[0])["params"],
        tx=optax.sgd(0.01)
    )

def main():
    state = create_state()
    for epoch in range(EPOCHS):
        bar = tqdm(range(len(dataset)))
        for i in bar:
            loss, state = train_step(state)
            print(loss)
            bar.set_description(f"Epoch {epoch}")

main()
```

I jitted as much as possible, but the OOM is still coming from the …
-
I jitted all the big code blocks in my code and it runs fine. However, when I do this, I also want to track the loss after each batch. The issue is that with jit, I can't coerce the Tracers to floats, so I can't print them. Is there a way I can do this?
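A sketch of the distinction involved here (toy loss, purely illustrative): values are only Tracers while the jitted function is being traced; whatever the function returns comes back as a concrete array and can be coerced to a float.

```python
import jax
import jax.numpy as jnp

@jax.jit
def train_step(x):
    loss = jnp.mean(x ** 2)
    # Inside the traced function, `loss` is a Tracer: float(loss) here would
    # raise a concretization error, and print(loss) only shows the abstract value.
    return loss

loss = train_step(jnp.arange(4.0))
# Outside the jitted call, the result is a concrete array again.
print(float(loss))  # 3.5
```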