Skip to content

Move jax/_src/interpreters/batching.py into its own BUILD rule #28957

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 23, 2025

Conversation

copybara-service[bot]
Copy link

Move jax/_src/interpreters/batching.py into its own BUILD rule

Creating smaller build rules enforces better organized dependency graphs in the JAX project, helps pytype propagate annotations correctly, and leads to improved build and iteration times.

Unfortunately this is not a clean build refactor, because batching depends on jax.lax, which in turn depends on batching. However, the problematic functions are only called within contexts where jax.lax is available for import.

We have a few options here:

  1. Continue to bundle the batching.py source with the main build.
  2. Build separately, but do the local import workaround in this CL (a pattern we use elsewhere).
  3. Build this separately, but move some batching definitions into jax.lax for a more strict dependency graph. Or pass the lax namespace explicitly to the function at the call site.

I opted for (2) here because I judged the benefits of a refactored build to be worth the cost of localized impure dependencies, and the kind of refactoring in (3) would affect some downstream users.

@copybara-service copybara-service bot force-pushed the test_762110930 branch 2 times, most recently from 617afc3 to 86d6105 Compare May 23, 2025 15:53
Creating smaller build rules enforces better organized dependency graphs in the JAX project, helps pytype propagate annotations correctly, and leads to improved build and iteration times.

Unfortunately this is not a clean build refactor, because batching depends on jax.lax, which in turn depends on batching. However, the problematic functions are only called within contexts where jax.lax is available for import.

We have a few options here:

1. Continue to bundle the batching.py source with the main build.
2. Build separately, but do the local import workaround in this CL (a pattern we use elsewhere).
3. Build this separately, but move some batching definitions into jax.lax for a more strict dependency graph. Or pass the `lax` namespace explicitly to the function at the call site.

I opted for (2) here because I judged the benefits of a refactored build to be worth the cost of localized impure dependencies, and the kind of refactoring in (3) would affect some downstream users.

PiperOrigin-RevId: 762447323
@copybara-service copybara-service bot merged commit c2c55ae into main May 23, 2025
@copybara-service copybara-service bot deleted the test_762110930 branch May 23, 2025 16:09
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant