Introduce annotate_custom_sharding binding #9203

Open: rpsilva-aws wants to merge 1 commit into base: master from rpsilva_resharding_v2

rpsilva-aws
Collaborator

@rpsilva-aws rpsilva-aws commented May 19, 2025

This PR adds a new binding, annotate_custom_sharding, that annotates an existing tensor with a custom sharding IR node without modifying its data layout. This is useful when a tensor has already been sharded with mark_sharding but needs additional sharding annotations for compiler optimizations.

Unlike the existing mark_sharding function, annotate_custom_sharding only adds the annotation to the XLA IR without changing the underlying data distribution, enabling more flexible sharding strategies to be provided to XLA. This is particularly useful for introducing resharding annotations on already-sharded tensors.
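
The distinction can be illustrated with a toy model in plain Python. Nothing here is torch_xla code: `ToyTensor`, `toy_mark_sharding`, and `toy_annotate_custom_sharding` are hypothetical stand-ins, and the sketch assumes only the behavior described above (one call redistributes the data and records an annotation, the other records the annotation alone).

```python
from dataclasses import dataclass, field


@dataclass
class ToyTensor:
    """Hypothetical stand-in for a sharded tensor; not a torch_xla type."""
    data_sharding: tuple = ()  # how the data is laid out; () means not yet sharded
    ir_annotations: list = field(default_factory=list)  # sharding hints recorded in the IR


def toy_mark_sharding(t, spec):
    # Models mark_sharding as described above: the data distribution changes
    # and the IR carries the matching annotation.
    t.data_sharding = spec
    t.ir_annotations.append(spec)
    return t


def toy_annotate_custom_sharding(t, spec):
    # Models annotate_custom_sharding as described above: only the IR
    # annotation is added; the data distribution is left untouched.
    t.ir_annotations.append(spec)
    return t


w = ToyTensor()
toy_mark_sharding(w, ("fsdp", "tp"))           # data is now laid out as (fsdp, tp)
toy_annotate_custom_sharding(w, (None, "tp"))  # extra hint; layout unchanged
```

After both calls, `w.data_sharding` still reflects the original `mark_sharding` layout, while `w.ir_annotations` holds both specs for the compiler to consume.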

Use Case

There are cases where we want to provide an explicit annotation hint around a kernel that uses manual sharding. Today, we are limited in our ability to introduce custom sharding hints to XLA prior to the manual resharding. For instance, with FSDP + TP, gathering all weights across the FSDP dimension prior to the kernel is not possible. This PR adds that functionality and flexibility by redefining the sharding spec associated with the IR prior to the manual sharding.
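
To make the FSDP + TP case concrete, here is a small arithmetic sketch of per-device shard shapes before and after gathering across the FSDP dimension. The mesh sizes, the weight shape, the axis names, and the `local_shard_shape` helper are all assumptions chosen for illustration; none of them come from the PR.

```python
def local_shard_shape(global_shape, mesh_axes, partition_spec):
    """Per-device shard shape for a tensor partitioned over a named mesh.

    partition_spec has one entry per tensor dimension: a mesh axis name to
    shard that dimension over, or None to leave it unsharded (replicated).
    """
    shape = []
    for dim, axis in zip(global_shape, partition_spec):
        parts = 1 if axis is None else mesh_axes[axis]
        if dim % parts != 0:
            raise ValueError(f"dimension {dim} is not divisible by {parts}")
        shape.append(dim // parts)
    return tuple(shape)


# Assumed mesh: 4-way FSDP and 2-way TP (8 devices in total).
mesh_axes = {"fsdp": 4, "tp": 2}
weight_shape = (1024, 4096)  # assumed weight shape

# Weight sharded over both axes, as it would be after mark_sharding.
sharded = local_shard_shape(weight_shape, mesh_axes, ("fsdp", "tp"))

# Hint that gathers across FSDP before the kernel while keeping TP sharding.
gathered = local_shard_shape(weight_shape, mesh_axes, (None, "tp"))

print(sharded, gathered)
```

Under these assumptions, each device holds a (256, 2048) shard before the gather and a (1024, 2048) shard after it, which is the layout the kernel would see if the compiler honors the hint.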

@rpsilva-aws rpsilva-aws force-pushed the rpsilva_resharding_v2 branch from 32344f3 to a16b53e Compare May 20, 2025 05:01
@rpsilva-aws rpsilva-aws changed the title Part 1: Disambiguate custom sharding op for DeviceData IR nodes Part 1: Add annotate_custom_sharding API May 20, 2025
@rpsilva-aws rpsilva-aws force-pushed the rpsilva_resharding_v2 branch from a16b53e to 93e8632 Compare May 20, 2025 05:06
@rpsilva-aws rpsilva-aws changed the title Part 1: Add annotate_custom_sharding API Part 1: Introduce annotate_custom_sharding binding May 20, 2025
@rpsilva-aws rpsilva-aws force-pushed the rpsilva_resharding_v2 branch from 93e8632 to da5885d Compare May 21, 2025 05:51
@rpsilva-aws rpsilva-aws requested review from tengyifei and bhavya01 May 21, 2025 17:50
@rpsilva-aws rpsilva-aws marked this pull request as ready for review May 21, 2025 18:20
@rpsilva-aws rpsilva-aws changed the title Part 1: Introduce annotate_custom_sharding binding Introduce annotate_custom_sharding binding May 21, 2025
@rpsilva-aws
Collaborator Author

Hey @tengyifei! Let me know if you're able to review this, or feel free to add anyone else. We need this to unblock the use case above, and we might consider cherry-picking it into 2.7.1.
