You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Move jax/_src/interpreters/batching.py into its own BUILD rule
Creating smaller build rules enforces better organized dependency graphs in the JAX project, helps pytype propagate annotations correctly, and leads to improved build and iteration times.
Unfortunately this is not a clean build refactor, because batching depends on jax.lax, which in turn depends on batching. However, the problematic functions are only called within contexts where jax.lax is available for import.
We have a few options here:
1. Continue to bundle the batching.py source with the main build.
2. Build separately, but do the local import workaround in this CL (a pattern we use elsewhere).
3. Build this separately, but move some batching definitions into jax.lax for a more strict dependency graph. Or pass the `lax` namespace explicitly to the function at the call site.
I opted for (2) here because I judged the benefits of a refactored build to be worth the cost of localized impure dependencies, and the kind of refactoring in (3) would affect some downstream users.
PiperOrigin-RevId: 762110930
0 commit comments