Skip to content

Commit 0c045cb

Browse files
committed
[torchax] Support for JittableModule::state_dict()
1 parent edc1a88 commit 0c045cb

File tree

2 files changed

+44
-0
lines changed

2 files changed

+44
-0
lines changed

torchax/test/test_statedict.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
2+
3+
import unittest
4+
import torch
5+
from torch.utils import _pytree as pytree
6+
7+
from torchax import (
8+
interop,
9+
mesh_util,
10+
tensor
11+
)
12+
13+
14+
class Model(torch.nn.Module):
    """Minimal module used by the state-dict test: one 10 -> 5 linear layer."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 5)

    def forward(self, x):
        """Apply the linear layer to ``x`` and return the result."""
        return self.linear(x)
21+
22+
23+
class TestTensorStateDict(unittest.TestCase):
    """Tests for ``interop.JittableModule.cpu_state_dict()``."""

    def test_load_statedict(self):
        """No value in the CPU state dict may be a ``tensor.Tensor``."""
        mesh = mesh_util.Mesh.fsdp_mesh()
        model = mesh.initialize_model_sharded(Model, ())
        model = interop.JittableModule(model)
        state_dict = model.cpu_state_dict()
        is_xla_tensor = pytree.tree_map(
            lambda t: isinstance(t, tensor.Tensor),
            state_dict,
        )
        # Use a unittest assertion instead of a bare ``assert`` so the check
        # is not stripped when Python runs with -O.
        self.assertFalse(
            any(is_xla_tensor.values()),
            "State dict should not contain XLA tensors",
        )
34+
35+
# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()

torchax/torchax/interop.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,14 @@ def call(*args, **kwargs):
125125
return jitted(self.params, self.buffers, *args, **kwargs)
126126

127127
self._jitted[key] = call
128+
129+
def cpu_state_dict(self, *args, **kwargs):
    """Return this module's state dict with each leaf converted via ``.cpu()``.

    All positional and keyword arguments are forwarded unchanged to the
    parent class's ``state_dict``.

    Leaves that do not expose a callable ``cpu`` attribute (for example
    plain Python values stored as extra state) are returned as-is instead
    of raising ``AttributeError``.

    Returns:
        A pytree with the same structure as the parent ``state_dict``
        result, with ``.cpu()`` applied to every leaf that supports it.
    """

    def _to_cpu(leaf):
        # Duck-typed so non-tensor leaves pass through untouched.
        cpu = getattr(leaf, "cpu", None)
        return cpu() if callable(cpu) else leaf

    state_dict = super().state_dict(*args, **kwargs)
    return pytree.tree_map(_to_cpu, state_dict)
128136

129137

130138
class CompileMixin:

0 commit comments

Comments
 (0)