
Conversation

@gnecula (Collaborator) commented on Feb 11, 2026

The previous code did not handle the case of a symbolic batch dimension (one of the most common uses of shape polymorphism).
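A minimal sketch of the scenario, assuming a trivial add kernel (the names add_kernel and add_vectors mirror the new test described below, but the exact test code may differ):

```python
# Hedged sketch, not the PR's exact test: export a vmap-ped pallas_call
# whose batch dimension "b" is symbolic. Before this fix, the symbolic
# batch dimension crashed pallas_call_batching.
import jax
import jax.numpy as jnp
from jax import export
from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
  o_ref[...] = x_ref[...] + y_ref[...]

def add_vectors(x, y):
  return pl.pallas_call(
      add_kernel,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
  )(x, y)

batched_add = jax.vmap(add_vectors)  # introduces the batch grid dimension

b, = export.symbolic_shape("b")  # symbolic batch size
x_info = jax.ShapeDtypeStruct((b, 128, 128), jnp.float32)
# Depending on the backend, export may need an explicit platforms=[...]
# argument (e.g. for TPU-only kernels).
exp = export.export(jax.jit(batched_add))(x_info, x_info)  # no crash with this fix
```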

@gemini-code-assist (Contributor) commented:

Summary of Changes

Hello @gnecula, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a critical issue in JAX's Pallas library where pallas_call_batching failed to correctly handle symbolic batch dimensions, a common use case for shape polymorphism. The changes ensure that dynamic grid dimensions are properly recognized and integrated into the grid mapping and cost estimation, preventing crashes and enabling robust export of Pallas operations with symbolic shapes.

Highlights

  • Enhanced Shape Polymorphism Support: Modified Pallas core logic to correctly identify and handle dynamic grid dimensions, ensuring proper behavior with symbolic shapes.
  • Improved Batching with Symbolic Dimensions: Updated pallas_call_batching to correctly incorporate symbolic batch dimensions into the grid mapping and adjust cost estimation logic accordingly.
  • Added Vmap Export Test: Introduced a new test case to validate the export of vmap-ped Pallas calls with symbolic shapes, specifically targeting the previously failing scenario.


Changelog
  • jax/_src/pallas/core.py
    • Updated get_grid_mapping to treat a grid dimension as dynamic when it is None or not a constant (not dim_check(d)); a sketch of this check follows the Activity note below.
  • jax/_src/pallas/pallas_call.py
    • Introduced axis_size_is_dynamic boolean to determine if axis_size is a dynamic dimension.
    • Modified batched_grid_mapping to use pallas_core.dynamic_grid_dim for the batch axis if axis_size_is_dynamic.
    • Updated cost estimation logic to avoid scaling by axis_size if it's dynamic.
    • Conditionally prepended axis_size to dynamic_grid_args if axis_size_is_dynamic.
  • tests/pallas/export_pallas_test.py
    • Added test_export_vmap to verify vmap functionality with symbolic shapes in Pallas calls, including a kernel and a batched add_vectors function.
Activity
  • No specific activity (comments, reviews, progress updates) was provided in the context.
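
For concreteness, here is a rough sketch of the dynamic-dimension classification described in the core.py changelog entry above. dim_check is internal to Pallas; jax._src.core.is_constant_dim is used here as a stand-in, which is an assumption about JAX internals:

```python
# Hypothetical illustration only; the real check lives in
# jax/_src/pallas/core.py and may use different helpers.
from jax import export
from jax._src import core as jax_core  # internal API, for illustration

b, = export.symbolic_shape("b")  # symbolic batch size
grid = (b, 8)

# is_constant_dim(d) is True for concrete Python ints and False for
# symbolic dimension expressions, so "None or not constant" marks a
# grid dimension as dynamic, matching the changelog entry.
dynamic_axes = [i for i, d in enumerate(grid)
                if d is None or not jax_core.is_constant_dim(d)]
assert dynamic_axes == [0]
```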

@gemini-code-assist bot left a comment:

Code Review

This pull request addresses an issue with shape polymorphism in pallas_call_batching when dealing with symbolic batch dimensions. The changes introduced are logical and correct. They properly detect dynamic grid dimensions and adjust the grid mapping, cost estimation, and dynamic grid arguments. The addition of a new test case for vmap with symbolic shapes effectively validates the fix. Overall, the implementation is solid and enhances the shape polymorphism capabilities of Pallas.

@gnecula force-pushed the fix_pallas_vmap_shape_poly branch from 0ad8d7f to a07bba9 on February 11, 2026 at 12:04
@gnecula self-assigned this on Feb 11, 2026
@gnecula requested a review from @bchetioui on February 11, 2026 at 12:07
@gnecula force-pushed the fix_pallas_vmap_shape_poly branch from a07bba9 to 979b73c on February 11, 2026 at 12:09
Excerpt from tests/pallas/export_pallas_test.py under review:

```python
exp = exporter(x_info, x_info)  # No crash

if jtu.device_under_test() == "tpu":
  x = y = jnp.ones((4, 128, 128))
```
A Member commented on the excerpt above:
Should we use values of m, n that are larger than the block_size? I'm not sure how well this is covered elsewhere; just checking.

@gnecula (Collaborator, Author) replied:
I do not think that we need that extra coverage for this test specifically. But I am now wondering what happens if we use values that are not divisible by 128! Let me try.
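
For reference, a hypothetical sketch of the arithmetic behind that question, assuming a 128x128 block shape (the values m = n = 200 are made up for illustration):

```python
# With sizes not divisible by the block size, the grid is a ceiling
# division; the last program instance in each axis covers a partial block.
from jax.experimental import pallas as pl

m, n, block = 200, 200, 128
grid = (pl.cdiv(m, block), pl.cdiv(n, block))  # -> (2, 2)
# Whether the out-of-bounds remainder is masked automatically or must be
# handled in the kernel depends on the platform and indexing mode.
```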

@google-ml-butler bot added the kokoro:force-run and pull ready (Ready for copybara import and testing) labels on Feb 11, 2026
@gnecula force-pushed the fix_pallas_vmap_shape_poly branch from 979b73c to e113ecb on February 11, 2026 at 13:27
@gnecula force-pushed the fix_pallas_vmap_shape_poly branch from e113ecb to 52e58b9 on February 12, 2026 at 07:38