Skip to content

[torchax]: JittableModule statedict handling #9195

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from

Conversation

zmelumian972
Copy link
Contributor

torchax aims to improve seamless interoperability between torch and jax

one of the parts in torch training pipeline revolves around storing and loading statedict (checkpoints)

most of the objects revolving torch checkpoints expect a (non nested) dict containing weight name and it's value (in either CPU or GPU device)

since torchax tensors are held in jax container, torch checkpointers cannot easily handle it

this changes forces JittableModule to convert state_dict() functions (both load and get), making it seamless to the user when he wants to extract the statdict prior to saving it as a checkpoint

@zmelumian972 zmelumian972 changed the title Torchax: JittableModule statedict handling [Torchax]: JittableModule statedict handling May 19, 2025
@zmelumian972 zmelumian972 changed the title [Torchax]: JittableModule statedict handling [torchax]: JittableModule statedict handling May 19, 2025
@qihqi qihqi self-requested a review May 19, 2025 18:33
@zmelumian972 zmelumian972 force-pushed the torchax/statedict_ascpu branch from 43404e5 to 8b9bf46 Compare May 20, 2025 11:15
@qihqi qihqi requested review from qihqi May 20, 2025 22:29
Copy link
Collaborator

@qihqi qihqi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Please fix the lint with

yapf -i -r *.py test/ scripts/ torch_xla/ benchmarks/ torchax/

Thanks

@zmelumian972 zmelumian972 force-pushed the torchax/statedict_ascpu branch from 8b9bf46 to 121cc00 Compare May 21, 2025 06:09
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants