diff --git a/flax/core/tracers.py b/flax/core/tracers.py index fe2ff874c0..9d4393e967 100644 --- a/flax/core/tracers.py +++ b/flax/core/tracers.py @@ -31,6 +31,7 @@ def current_trace(): return jax.core.get_opaque_trace_state(convention="flax") def check_trace_level(base_level): - level = current_trace() - if level != base_level: - raise errors.JaxTransformError() + pass + # level = current_trace() + # if level != base_level: + # raise errors.JaxTransformError()