Can we use TrainState with pad_shard_unpad? #2411
-
I have a `train_step` that takes a `TrainState` as one of its arguments, and I am using `pad_shard_unpad` like this:
My `state: TrainState` has an optax optimizer in it. When I run my code, I get the following error:
-
Are you using any other JAX transformations in your code as well? This error occurs when you pass a JAX array to a jitted function but mark that argument as static (for example via `static_argnums`). Static arguments are treated as compile-time constants, so they must be hashable, and JAX arrays cannot be hashed. Here's a minimal example of what I mean:

```python
import jax

# Fails: the array is marked static, so jit tries (and fails) to hash it
jax.jit(lambda x: x + 1, static_argnums=(0,))(jax.numpy.array([0, 1]))  # same error
# Works: the array is passed as a regular (traced) argument
jax.jit(lambda x: x + 1)(jax.numpy.array([0, 1]))
```
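To make the hashability requirement concrete, here is a minimal sketch contrasting a valid static argument (a Python `int`) with a JAX array that must stay a regular traced argument. The function `repeat_add` and the helper `is_hashable` are hypothetical, written only for illustration.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical example: `n` is a Python int, which is hashable and can be a
# compile-time constant. `x` is a JAX array and must remain a traced argument.
@partial(jax.jit, static_argnums=(1,))
def repeat_add(x, n):
    # This Python loop is unrolled at trace time, which only works because
    # `n` is static (a concrete Python int, not a tracer).
    for _ in range(n):
        x = x + 1
    return x


def is_hashable(value):
    """Hypothetical helper: True if `value` could serve as a jit static argument."""
    try:
        hash(value)
        return True
    except TypeError:
        return False


print(repeat_add(jnp.array([0, 1]), 2))
print(is_hashable(3), is_hashable(jnp.array([0, 1])))
```

Changing `n` triggers a recompilation (one compiled version per distinct static value), which is another reason only small, hashable Python values belong in `static_argnums`.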
-
Thanks @marcvanzee
I think I had a bug in my code: I was mistakenly passing `rng`, which is `jax.random.PRNGKey(seed)` and therefore a DeviceArray, while marking it as static. DeviceArrays cannot be passed as static arguments. I will close this issue now.
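A minimal sketch of the fix, assuming the key is consumed inside a jitted step: keep the PRNG key out of `static_argnums`, pass it as a regular argument, and split a fresh key each step. The function `noisy_step` is hypothetical and only stands in for a real `train_step` (no `TrainState` or optax is involved here); only the Python bool `train` is marked static.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical step function: `train` (index 3) is a Python bool, so it is a
# valid static argument. `rng` is a JAX array and is passed as a normal,
# traced argument rather than through static_argnums.
@partial(jax.jit, static_argnums=(3,))
def noisy_step(params, x, rng, train):
    if train:  # Python `if` is allowed because `train` is a compile-time constant
        x = x + 0.1 * jax.random.normal(rng, x.shape)
    return params * x


rng = jax.random.PRNGKey(0)
params = jnp.ones(3)
x = jnp.arange(3.0)
for step in range(2):
    rng, step_rng = jax.random.split(rng)  # fresh key for every step
    y = noisy_step(params, x, step_rng, True)
    print(step, y)
```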