Hi Flax community, I am new to Flax and am trying to rewrite the VAE example to use jit. I only added @jax.jit to train_step, but I got the following error message:
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  File "/data/flyingcat/flax/examples/vae_multicards/main.py", line 67, in <module>
    app.run(main)
  File "/data/flyingcat/minicoda3/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/data/flyingcat/minicoda3/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/data/flyingcat/flax/examples/vae_multicards/main.py", line 59, in main
    train.train_and_evaluate(FLAGS.config)
  File "/data/flyingcat/flax/examples/vae_multicards/train.py", line 198, in train_and_evaluate
    state = train_step(state, batch, key, config.latents)
  File "/data/flyingcat/minicoda3/lib/python3.10/site-packages/flax/nnx/nnx/graph.py", line 1043, in update_context_manager_wrapper
    return f(*args, **kwargs)
  File "/data/flyingcat/minicoda3/lib/python3.10/site-packages/flax/nnx/nnx/transforms/transforms.py", line 359, in jit_wrapper
    out, output_state, output_graphdef = jitted_fn(
  File "/data/flyingcat/minicoda3/lib/python3.10/site-packages/flax/nnx/nnx/transforms/transforms.py", line 158, in jit_fn
    out = f(*args, **kwargs)
  File "/data/flyingcat/flax/examples/vae_multicards/train.py", line 85, in train_step
    grads = jax.grad(loss_fn)(state.params)
  File "/data/flyingcat/flax/examples/vae_multicards/train.py", line 76, in loss_fn
    recon_x, mean, logvar = models.model(latents).apply(
  File "/data/flyingcat/flax/examples/vae_multicards/models.py", line 58, in __call__
    mean, logvar = self.encoder(x)
  File "/data/flyingcat/flax/examples/vae_multicards/models.py", line 32, in __call__
    mean_x = nn.Dense(self.latents, name='fc2_mean')(x)
  File "/data/flyingcat/minicoda3/lib/python3.10/site-packages/flax/linen/linear.py", line 256, in __call__
    kernel = self.param(
  File "/data/flyingcat/minicoda3/lib/python3.10/site-packages/jax/_src/nn/initializers.py", line 321, in init
    named_shape = core.as_named_shape(shape)
TypeError: Shapes must be 1D sequences of concrete values of integer type, got (500, Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function jit_fn at /data/flyingcat/minicoda3/lib/python3.10/site-packages/flax/nnx/nnx/transforms/transforms.py:139 for jit. This concrete value was not available in Python because it depends on the value of the argument args[3].
How can I fix this error when using jit with the VAE example?
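A possible direction, following the hint in the error message itself: the message says the traced value comes from args[3], which in the call `train_step(state, batch, key, config.latents)` is `latents`. Under jit every positional argument is traced by default, so `latents` is no longer a concrete Python int by the time it is used as a layer size. Marking it static should keep it concrete. The sketch below is not the actual VAE example code; `init_kernel` and this simplified `train_step` are hypothetical stand-ins, written with plain `jax.jit`, that only reproduce the shape problem and one way around it.

```python
from functools import partial

import jax
import jax.numpy as jnp


def init_kernel(key, in_features, latents):
    # Hypothetical stand-in for the kernel init inside nn.Dense:
    # the shape tuple must contain concrete Python ints, not tracers.
    return jax.random.normal(key, (in_features, latents))


# Argument index 2 (`latents`) is marked static, so jit passes it through
# as an ordinary Python int and recompiles if its value changes.
@partial(jax.jit, static_argnums=(2,))
def train_step(x, key, latents):
    # Simplified stand-in for the real train_step; it only builds a
    # kernel whose shape depends on `latents` and applies it.
    kernel = init_kernel(key, x.shape[-1], latents)
    return x @ kernel


x = jnp.ones((4, 500))
out = train_step(x, jax.random.PRNGKey(0), 20)
print(out.shape)
```

In the original four-argument signature the corresponding setting would presumably be `static_argnums=3`. An alternative that avoids passing the size as an argument at all is to close over it, for example by building the model once outside the jitted function.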