Commit 121cc00

[torchax] Added JittableModule::load_state_dict mechanism
1 parent: 0c045cb

File tree: 2 files changed (+121, -28 lines)
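
Taken together, the two files below add a save/load round trip for sharded models: `cpu_state_dict` pulls every parameter back as a plain CPU torch tensor, and the new `load_state_dict` pushes a CPU state dict onto the correct devices and shardings. A minimal usage sketch based on the tests in this commit (any `torch.nn.Module` can stand in for `Model`; the checkpoint path is illustrative):

    import torch
    from torchax import interop, mesh_util

    # Shard the model across devices, then wrap it for jitted execution.
    mesh = mesh_util.Mesh.fsdp_mesh()
    model = interop.JittableModule(mesh.initialize_model_sharded(Model, ()))

    # All-CPU state dict, safe to serialize with torch.save.
    torch.save(model.cpu_state_dict(), 'ckpt.pt')

    # Later: tensors are moved back to the right device/sharding on load.
    model.load_state_dict(torch.load('ckpt.pt'))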

torchax/test/test_statedict.py

Lines changed: 42 additions & 24 deletions
@@ -1,36 +1,54 @@
-
-
 import unittest
 import torch
 from torch.utils import _pytree as pytree
 
-from torchax import (
-    interop,
-    mesh_util,
-    tensor
-)
+from torchax import (interop, mesh_util, tensor)
 
 
 class Model(torch.nn.Module):
-    def __init__(self):
-        super(Model, self).__init__()
-        self.linear = torch.nn.Linear(10, 5)
 
-    def forward(self, x):
-        return self.linear(x)
-
+  def __init__(self):
+    super(Model, self).__init__()
+    self.linear = torch.nn.Linear(10, 5)
+
+  def forward(self, x):
+    return self.linear(x)
+
+
+mesh = mesh_util.Mesh.fsdp_mesh()
+model = interop.JittableModule(mesh.initialize_model_sharded(Model, ()))
+
 
 class TestTensorStateDict(unittest.TestCase):
-    def test_load_statedict(self):
-        mesh = mesh_util.Mesh.fsdp_mesh()
-        model = mesh.initialize_model_sharded(Model, ())
-        model = interop.JittableModule(model)
-        state_dict = model.cpu_state_dict()
-        is_xla_tensor = pytree.tree_map(
-            lambda t: isinstance(t, tensor.Tensor),
-            state_dict
-        )
-        assert not any(is_xla_tensor.values()), "State dict should not contain XLA tensors"
+
+  def test_get_statedict(self):
+    state_dict_cpu = model.cpu_state_dict()
+    is_xla_tensor = pytree.tree_map(lambda t: isinstance(t, tensor.Tensor),
+                                    state_dict_cpu)
+    assert not any(
+        is_xla_tensor.values()), "State dict should not contain XLA tensors"
+
+  def test_load_statedict(self):
+    state_dict_cpu = model.cpu_state_dict()
+    state_dict_cpu = pytree.tree_map(torch.zeros_like, state_dict_cpu)
+    model.load_state_dict(state_dict_cpu)
+    is_zeros = pytree.tree_map(lambda t: torch.equal(t, torch.zeros_like(t)),
+                               state_dict_cpu)
+    assert all(is_zeros.values()), "State dict should be zeros"
+
+  def test_load_statedict_partial(self):
+    state_dict_cpu = model.cpu_state_dict()
+    del state_dict_cpu['_model.linear.bias']
+    state_dict_cpu = pytree.tree_map(torch.ones_like, state_dict_cpu)
+    key_check = model.load_state_dict(state_dict_cpu, strict=False)
+    assert key_check.missing_keys == [
+        '_model.linear.bias'
+    ], "Missing keys should be '_model.linear.bias'"
+    linear_weight = model.state_dict()['_model.linear.weight']
+    assert torch.equal(
+        linear_weight,
+        torch.ones_like(linear_weight)), "Linear weight should be ones"
+
 
 if __name__ == '__main__':
-    unittest.main()
+  unittest.main()
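
Note the `_model.` prefix on every key (`_model.linear.weight`, `_model.linear.bias`): `JittableModule` holds the wrapped module as an attribute, so its state dict keys are namespaced accordingly. A checkpoint saved from the bare, unwrapped module would therefore need its keys re-prefixed before loading. A hypothetical helper for that (the `_model.` prefix is taken from the tests above):

    def add_wrapper_prefix(state_dict, prefix='_model.'):
      # Remap bare-module keys, e.g. 'linear.weight' -> '_model.linear.weight'.
      return {prefix + k: v for k, v in state_dict.items()}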

torchax/torchax/interop.py

Lines changed: 79 additions & 4 deletions
@@ -1,3 +1,4 @@
+from typing import Mapping, Any
 import collections
 import copy
 import functools

@@ -125,14 +126,88 @@ def call(*args, **kwargs):
       return jitted(self.params, self.buffers, *args, **kwargs)
 
     self._jitted[key] = call
-
+
   def cpu_state_dict(self, *args, **kwargs):
+    """
+    Wrapper for state_dict
+
+    This function will make sure to transfer all the parameters to CPU,
+    making it easier to save the state dict with torch.save.
+
+    Returns:
+      Mapping[str, Any]: A mapping of parameter names to their values (in torch CPU)
+    """
     state_dict = super().state_dict(*args, **kwargs)
+    state_dict = pytree.tree_map(lambda t: t.cpu(), state_dict)
+    return state_dict
+
+  def load_state_dict(self,
+                      state_dict: Mapping[str, Any],
+                      strict: bool = True,
+                      assign: bool = False):
+    """
+    Wrapper for load_state_dict
+
+    This function assumes a torch CPU state dict and will transfer the parameters
+    to the correct device and dtype before loading them into the model.
+
+    Args:
+      state_dict (Mapping[str, Any]): A mapping of parameter names to their values (in torch CPU)
+      strict (bool, optional): whether to strictly enforce that the keys
+        in :attr:`state_dict` match the keys returned by this module's
+        :meth:`~torch.nn.Module.state_dict` function. Default: ``True``
+      assign (bool, optional): When set to ``False``, the properties of the tensors
+        in the current module are preserved, whereas setting it to ``True`` preserves
+        properties of the Tensors in the state dict. The only
+        exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter`s,
+        for which the value from the module is preserved.
+        Default: ``False``
+
+    Returns:
+      ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
+        * **missing_keys** is a list of str containing any keys that are expected
+          by this module but missing from the provided ``state_dict``.
+        * **unexpected_keys** is a list of str containing the keys that are not
+          expected by this module but present in the provided ``state_dict``.
+    """
+    # Move tensors to JAX to have an easier time extracting sharding information
+    current_state_dict = super().state_dict()
+    current_state_dict = jax_view(current_state_dict)
+
+    # Create out shardings that either reuse the current state dict sharding
+    # or replicate the weights
+    def extract_sharding_or_replicate(name):
+      if name in current_state_dict:
+        return current_state_dict[name].sharding
+      return jax.sharding.PartitionSpec()
+
+    output_shards = {
+        name: extract_sharding_or_replicate(name) for name in state_dict
+    }
+
+    def convert_to_xla_tensor_if_needed(t):
+      is_torch_tensor = isinstance(t, torch.Tensor)
+      is_xla_tensor = isinstance(t, torchax.tensor.Tensor)
+      if is_xla_tensor:
+        t = jax_view(t)
+      elif is_torch_tensor:
+        # convert to a JAX array
+        t = tensor.t2j(t)
+      return t
+
+    # Convert the state dict to JAX arrays and shard them
     state_dict = pytree.tree_map(
-        lambda t: t.cpu(),
-        state_dict
+        tensor.t2j,
+        state_dict,
     )
-    return state_dict
+    # Convert the ordered dict to a regular dict for pjit type-safety checks
+    state_dict = dict(state_dict)
+    jitted = jax_jit(
+        lambda t: t, kwargs_for_jax_jit={"out_shardings": output_shards})
+    state_dict = jitted(state_dict)
+    # Re-view as torch tensors, so that load_state_dict's assign can be used if needed
+    state_dict = torch_view(state_dict)
+
+    return super().load_state_dict(state_dict, strict, assign)
 
 
 class CompileMixin:
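
The resharding step at the heart of `load_state_dict` is a general JAX pattern: jitting an identity function with `out_shardings` asks XLA to transfer and lay out each output according to the requested sharding. A standalone sketch in plain JAX, independent of torchax (mesh shape and axis name are illustrative):

    import jax
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec

    # One-dimensional mesh over all local devices.
    mesh = Mesh(np.array(jax.devices()), axis_names=('fsdp',))

    host_array = np.ones((8, 128), dtype=np.float32)

    # Identity function jitted with an explicit output sharding: the
    # compiled computation moves the data onto the mesh for us.
    sharding = NamedSharding(mesh, PartitionSpec('fsdp'))
    reshard = jax.jit(lambda x: x, out_shardings=sharding)

    sharded = reshard(host_array)
    print(sharded.sharding)  # row-sharded over the 'fsdp' mesh axis

`load_state_dict` applies exactly this over a whole dict of arrays, using each existing parameter's `.sharding` as the output sharding, or `PartitionSpec()` (fully replicated) for keys the model does not currently have.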
