Skip to content

Commit

Permalink
Allow nnx.bridge.variables.nnx_attrs_to_linen_vars take `nnx.Variab…
Browse files Browse the repository at this point in the history
…leState` as argument.

PiperOrigin-RevId: 713442378
  • Loading branch information
IvyZX authored and Flax Authors committed Jan 8, 2025
1 parent e2134af commit 22f3c1d
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions flax/nnx/bridge/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ def _recursive_merge(dict1, dict2):


def linen_vars_to_nnx_attrs(variables: tp.Mapping[str, Any]) -> dict[str, Any]:
"""Convert a dict of Linen-style variables to NNX variables."""
nnx_vars = jax.tree_util.tree_map_with_path(
lambda kp, x: to_nnx_var(get_col_name(kp), x),
variables, is_leaf=lambda x: isinstance(x, meta.AxisMetadata))
Expand All @@ -190,19 +191,22 @@ def linen_vars_to_nnx_attrs(variables: tp.Mapping[str, Any]) -> dict[str, Any]:


def nnx_attrs_to_linen_vars(nnx_attrs: dict) -> dict:
"""Convert a dict of NNX variables (or variable states) to Linen-style variables."""
linen_structured = {}
for kp, v in traversals.flatten_mapping(
nnx_attrs,
is_leaf=lambda _, x: isinstance(x, variableslib.Variable | GraphDef),
).items():
is_leaf=lambda _, x: isinstance(
x, variableslib.Variable | variableslib.VariableState | GraphDef
),
).items():
if isinstance(v, variableslib.Variable):
col_name = variable_type_name(type(v))
v = to_linen_var(v.to_state())
elif isinstance(v, variableslib.VariableState):
col_name = variable_type_name(v.type)
v = to_linen_var(v)
else:
col_name = 'nnx' # it must be an nnx.GraphDef, for some ToLinen submodule
linen_structured[(col_name, *kp)] = v
variables = traversals.unflatten_mapping(linen_structured)
variables = jax.tree.map(lambda x: to_linen_var(x.to_state()),
variables,
is_leaf=lambda x: isinstance(x, variableslib.Variable))
return variables

0 comments on commit 22f3c1d

Please sign in to comment.