Skip to content

feat: abstraction of xla::OpSharding proto using wrapper class #9467

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: master
Choose a base branch
from

Conversation

kvshbg-aws
Copy link

This PR includes the changes related to abstracting xla::OpSharidng proto object into a torch_xla::OpSharding wrapper class.

This new class object will not have the requirements of xla::OpSharding (however, it will be an extension xla::OpSharding proto defined over here).
We have defined the wrapper class in torch/xla which will construct an xla::OpSharding object with additional fields such as global_device_ids/global_tile_assignment and will have forwarded/proxy functions to xla::OpSharding . These forwarded functions will help user still make use of the same xla::OpSharding APIs as they normally would. We can also define torch_xla specific functions in this wrapper class to further use the extra fields that were stored during the initialization of the OpSharding object. This approach also allows the flexibility of converting the torch_xla::OpSharding object back to xla::OpSharding while lowering into HLO, thus, giving user the flexibility to use the abstracted class (and other additional fields stored) anywhere in the code base as needed, this is particularly useful since the XLA's HLOs are 0th indexed, hence we need to use the normalized_device_ids (starting from index 0) when lowering the program into the HLO, whereas we can still use the denormalized/global_device_ids in other places such as inside pjrt client to set the device_assignment using the user specified device_ids.

Component diagram for reference -
Image (1)

Ref issue - #9390

@kvshbg-aws kvshbg-aws force-pushed the kvshbg-aws/local-spmd-abstraction branch from d0502ab to 7fc15ea Compare July 10, 2025 18:34
@qihqi qihqi requested review from rpsilva-aws and pgmoka July 11, 2025 04:20
@kvshbg-aws kvshbg-aws force-pushed the kvshbg-aws/local-spmd-abstraction branch 4 times, most recently from 7c4a3cd to 1d55ae9 Compare July 16, 2025 23:57
@kvshbg-aws kvshbg-aws force-pushed the kvshbg-aws/local-spmd-abstraction branch from 2756c1a to 1ddbb1b Compare July 23, 2025 21:20
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant