+from typing import Mapping, Any
 import collections
 import copy
 import functools
@@ -125,14 +126,88 @@ def call(*args, **kwargs):
       return jitted(self.params, self.buffers, *args, **kwargs)
 
     self._jitted[key] = call
 
-
+
   def cpu_state_dict(self, *args, **kwargs):
+    """
+    Wrapper for state_dict.
+
+    This function makes sure to transfer all the parameters to CPU,
+    making it easier to save the state dict with torch.save.
+
+    Returns:
+      Mapping[str, Any]: A mapping of parameter names to their values (as torch CPU tensors)
+    """
     state_dict = super().state_dict(*args, **kwargs)
+    state_dict = pytree.tree_map(lambda t: t.cpu(), state_dict)
+    return state_dict
+
+  def load_state_dict(self,
+                      state_dict: Mapping[str, Any],
+                      strict: bool = True,
+                      assign: bool = False):
+    """
+    Wrapper for load_state_dict.
+
+    This function assumes a torch CPU state dict and will transfer the parameters
+    to the correct device and dtype before loading them into the model.
+
+    Args:
+      state_dict (Mapping[str, Any]): A mapping of parameter names to their values (as torch CPU tensors)
+      strict (bool, optional): whether to strictly enforce that the keys
+        in :attr:`state_dict` match the keys returned by this module's
+        :meth:`~torch.nn.Module.state_dict` function. Default: ``True``
+      assign (bool, optional): When set to ``False``, the properties of the tensors
+        in the current module are preserved, whereas setting it to ``True`` preserves
+        properties of the Tensors in the state dict. The only
+        exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter`s,
+        for which the value from the module is preserved.
+        Default: ``False``
+
+    Returns:
+      ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
+        * **missing_keys** is a list of str containing any keys that are expected
+          by this module but missing from the provided ``state_dict``.
+        * **unexpected_keys** is a list of str containing the keys that are not
+          expected by this module but present in the provided ``state_dict``.
+    """
+    # Move tensors to JAX to have an easier time extracting sharding information
+    current_state_dict = super().state_dict()
+    current_state_dict = jax_view(current_state_dict)
+
+    # Create out shardings that either reuse the current state dict's sharding or replicate the weights
+    def extract_sharding_or_replicate(name):
+      if name in current_state_dict:
+        return current_state_dict[name].sharding
+      return jax.sharding.PartitionSpec()
+
+    output_shards = {
+        name: extract_sharding_or_replicate(name) for name in state_dict
+    }
+
+    def convert_to_xla_tensor_if_needed(t):
+      is_torch_tensor = isinstance(t, torch.Tensor)
+      is_xla_tensor = isinstance(t, torchax.tensor.Tensor)
+      if is_xla_tensor:
+        t = jax_view(t)
+      elif is_torch_tensor:
+        # Convert to a JAX array
+        t = tensor.t2j(t)
+      return t
+
+    # Convert the state dict to JAX and shard it accordingly
     state_dict = pytree.tree_map(
-        lambda t: t.cpu(),
-        state_dict
+        tensor.t2j,
+        state_dict,
     )
-    return state_dict
+    # Convert the ordered dict to a regular dict for pjit's type-safety checks
+    state_dict = dict(state_dict)
+    jitted = jax_jit(
+        lambda t: t, kwargs_for_jax_jit={"out_shardings": output_shards})
+    state_dict = jitted(state_dict)
+    # View the results as torch tensors again, so we can use ``assign`` if needed
+    state_dict = torch_view(state_dict)
+
+    return super().load_state_dict(state_dict, strict, assign)
 
 
 class CompileMixin:
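
Taken together, the two new methods give a plain torch.save / torch.load checkpoint flow. Below is a minimal usage sketch; the wrapper class name (JittableModule), its import path, and the nn.Linear stand-in model are assumptions, since the diff does not show the enclosing class.

# Illustrative sketch only: "JittableModule" and the torchax.interop import
# path are assumed; the diff does not show the enclosing class.
import torch
from torchax.interop import JittableModule  # assumed import path

model = JittableModule(torch.nn.Linear(8, 8))  # stand-in wrapped module

# Saving: cpu_state_dict() moves every parameter to CPU first, so the
# result can be written with plain torch.save.
torch.save(model.cpu_state_dict(), "checkpoint.pt")

# Loading: load_state_dict() expects a torch CPU state dict, converts it
# to JAX, re-shards it to match the current parameters (replicating keys
# it has no sharding for), then delegates to the normal torch loading.
cpu_state = torch.load("checkpoint.pt", map_location="cpu")
result = model.load_state_dict(cpu_state, strict=True)
print(result.missing_keys, result.unexpected_keys)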
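
The re-sharding step in load_state_dict relies on a standard JAX placement pattern: jit an identity function with out_shardings so each incoming array is committed to a target sharding, or fully replicated when no existing sharding is known. The following is a self-contained sketch of that mechanism in plain JAX, independent of torchax; the names "weight" and "bias" and the mesh layout are illustrative only.

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Build a 1-D mesh over whatever devices are available.
devices = jax.devices()
mesh = Mesh(devices, axis_names=("data",))

# Target placement for each entry of a flat "state dict": reuse a known
# sharding for "weight", replicate "bias" (empty PartitionSpec).
out_shardings = {
    "weight": NamedSharding(mesh, PartitionSpec("data")),
    "bias": NamedSharding(mesh, PartitionSpec()),
}

# Identity function jitted with out_shardings: the outputs come back
# placed according to the requested shardings, which is what the
# load_state_dict code above does for the converted checkpoint.
place = jax.jit(lambda tree: tree, out_shardings=out_shardings)

state = {
    "weight": jnp.ones((len(devices) * 2, 4)),
    "bias": jnp.zeros((4,)),
}
placed = place(state)
print(placed["weight"].sharding)  # NamedSharding over the "data" axis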