Introduce annotate_custom_sharding binding #9203
Open
+119
−10
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This PR adds a new binding API
annotate_custom_sharding
that allows annotating an existing tensor with a custom sharding IR node without modifying its data layout. This is useful for cases where a tensor has already been sharded withmark_sharding
but needs additional sharding annotations for compiler optimizations.Unlike the existing
mark_sharding
function,annotate_custom_sharding
only adds the annotation to the XLA IR without changing the underlying data distribution, enabling more flexible sharding strategies to be provided to XLA. This is particularly useful for introducing resharding annotations on already-sharded tensors.Use Case
There are instances where we want to provide an explicit annotation hint around a kernel with manual sharding. In this case, we are limited to introducing custom sharding hints to XLA prior to the manual resharding. For instance, if we have FSDP + TP, and we wish to gather all weights across the FSDP dimension prior to the kernel, this is not possible. This PR allows us to introduce such functionality and flexibility, by redefining the sharding spec associated with the IR prior to the manual sharding.