
Commit 8b9bf46

[torchax] Added JittableModule::load_state_dict mechanism
1 parent 0c045cb commit 8b9bf46

2 files changed: +110 -7 lines


torchax/test/test_statedict.py

Lines changed: 33 additions & 7 deletions
@@ -18,19 +18,45 @@ def __init__(self):
 
   def forward(self, x):
     return self.linear(x)
-
+
+mesh = mesh_util.Mesh.fsdp_mesh()
+model = interop.JittableModule(mesh.initialize_model_sharded(Model, ()))
 
 class TestTensorStateDict(unittest.TestCase):
-  def test_load_statedict(self):
-    mesh = mesh_util.Mesh.fsdp_mesh()
-    model = mesh.initialize_model_sharded(Model, ())
-    model = interop.JittableModule(model)
-    state_dict = model.cpu_state_dict()
+
+  def test_get_statedict(self):
+    state_dict_cpu = model.cpu_state_dict()
     is_xla_tensor = pytree.tree_map(
         lambda t: isinstance(t, tensor.Tensor),
-        state_dict
+        state_dict_cpu
     )
     assert not any(is_xla_tensor.values()), "State dict should not contain XLA tensors"
+
+  def test_load_statedict(self):
+    state_dict_cpu = model.cpu_state_dict()
+    state_dict_cpu = pytree.tree_map(
+        torch.zeros_like,
+        state_dict_cpu
+    )
+    model.load_state_dict(state_dict_cpu)
+    is_zeros = pytree.tree_map(
+        lambda t: torch.equal(t, torch.zeros_like(t)),
+        state_dict_cpu
+    )
+    assert all(is_zeros.values()), "State dict should be zeros"
+
+  def test_load_statedict_partial(self):
+    state_dict_cpu = model.cpu_state_dict()
+    del state_dict_cpu['_model.linear.bias']
+    state_dict_cpu = pytree.tree_map(
+        torch.ones_like,
+        state_dict_cpu
+    )
+    key_check = model.load_state_dict(state_dict_cpu, strict=False)
+    assert key_check.missing_keys == ['_model.linear.bias'], "Missing keys should be '_model.linear.bias'"
+    linear_weight = model.state_dict()['_model.linear.weight']
+    assert torch.equal(linear_weight, torch.ones_like(linear_weight)), "Linear weight should be ones"
+
 
 if __name__ == '__main__':
   unittest.main()
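
Taken together, the tests above amount to a checkpoint round-trip. Here is a minimal sketch of that flow, assuming the Model class and imports from the top of the test file (not fully shown in this diff); the checkpoint path is hypothetical:

import torch
from torchax import interop, mesh_util

# Shard the model across the devices of an FSDP mesh, then wrap it for jitting.
mesh = mesh_util.Mesh.fsdp_mesh()
model = interop.JittableModule(mesh.initialize_model_sharded(Model, ()))

# cpu_state_dict() moves every parameter to host memory, so plain torch.save works.
torch.save(model.cpu_state_dict(), "checkpoint.pt")  # hypothetical path

# load_state_dict() accepts the CPU tensors and reshards each one back onto
# the mesh (see torchax/torchax/interop.py below) before loading.
model.load_state_dict(torch.load("checkpoint.pt"))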

torchax/torchax/interop.py

Lines changed: 77 additions & 0 deletions
@@ -1,3 +1,4 @@
+from typing import Mapping, Any
 import collections
 import copy
 import functools
@@ -127,12 +128,88 @@ def call(*args, **kwargs):
     self._jitted[key] = call
 
   def cpu_state_dict(self, *args, **kwargs):
+    """Wrapper around state_dict.
+
+    This function makes sure to transfer all the parameters to CPU,
+    making it easier to save the state dict with torch.save.
+
+    Returns:
+      Mapping[str, Any]: A mapping of parameter names to their values
+      (as torch CPU tensors).
+    """
     state_dict = super().state_dict(*args, **kwargs)
     state_dict = pytree.tree_map(
         lambda t: t.cpu(),
         state_dict
     )
     return state_dict
+
+  def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False):
+    """Wrapper around load_state_dict.
+
+    This function assumes a torch CPU state dict and transfers the parameters
+    to the correct device and dtype before loading them into the model.
+
+    Args:
+      state_dict (Mapping[str, Any]): A mapping of parameter names to their
+        values (as torch CPU tensors).
+      strict (bool, optional): Whether to strictly enforce that the keys
+        in :attr:`state_dict` match the keys returned by this module's
+        :meth:`~torch.nn.Module.state_dict` function. Default: ``True``
+      assign (bool, optional): When ``False``, the properties of the tensors
+        in the current module are preserved; when ``True``, the properties of
+        the tensors in the state dict are preserved. The only exception is
+        the ``requires_grad`` field of :class:`~torch.nn.Parameter`s, for
+        which the value from the module is preserved. Default: ``False``
+
+    Returns:
+      ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
+        * **missing_keys** is a list of str containing any keys that are
+          expected by this module but missing from the provided ``state_dict``.
+        * **unexpected_keys** is a list of str containing the keys that are
+          not expected by this module but present in the provided ``state_dict``.
+    """
+    # Move tensors to JAX to have an easier time extracting sharding information.
+    current_state_dict = super().state_dict()
+    current_state_dict = jax_view(current_state_dict)
+
+    # Create out shardings that either reuse the current state dict's sharding
+    # or replicate the weights.
+    def extract_sharding_or_replicate(name):
+      if name in current_state_dict:
+        return current_state_dict[name].sharding
+      return jax.sharding.PartitionSpec()
+
+    output_shards = {
+        name: extract_sharding_or_replicate(name)
+        for name in state_dict
+    }
+
+    def convert_to_xla_tensor_if_needed(t):
+      is_torch_tensor = isinstance(t, torch.Tensor)
+      is_xla_tensor = isinstance(t, torchax.tensor.Tensor)
+      if is_xla_tensor:
+        t = jax_view(t)
+      elif is_torch_tensor:
+        # Convert to a JAX array.
+        t = tensor.t2j(t)
+      return t
+
+    # Convert the state dict to JAX arrays and shard them.
+    state_dict = pytree.tree_map(
+        tensor.t2j,
+        state_dict,
+    )
+    # Convert the ordered dict to a regular dict for pjit's type-safety checks.
+    state_dict = dict(state_dict)
+    jitted = jax_jit(
+        lambda t: t, kwargs_for_jax_jit={"out_shardings": output_shards}
+    )
+    state_dict = jitted(state_dict)
+    # View the result as torch tensors again, so torch can assign them if needed.
+    state_dict = torch_view(state_dict)
+
+    return super().load_state_dict(state_dict, strict, assign)
 
 
 class CompileMixin:
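
The heart of the new load path is the jitted identity function: jitting `lambda t: t` with out_shardings moves each incoming array onto the requested sharding as a side effect of satisfying the output constraint. Below is a minimal standalone JAX sketch of that trick; the mesh, axis name, and array shapes are illustrative assumptions, not part of torchax:

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Hypothetical 1-D device mesh; the "fsdp" axis name mirrors Mesh.fsdp_mesh().
mesh = Mesh(np.array(jax.devices()), axis_names=("fsdp",))
target = NamedSharding(mesh, PartitionSpec("fsdp"))  # shard dim 0 across devices

# Jitting the identity with out_shardings reshards the input to match:
# a replicated (e.g. freshly loaded from CPU) array comes out sharded.
reshard = jax.jit(lambda t: t, out_shardings={"w": target})
out = reshard({"w": jnp.ones((8 * jax.device_count(), 4))})
print(out["w"].sharding)  # NamedSharding(..., spec=PartitionSpec('fsdp',))

load_state_dict builds its output_shards dict the same way, except that each target sharding is read off the module's existing parameters, falling back to a replicated PartitionSpec() for keys the module does not already have.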
