diff --git a/flax/experimental/nnx/nnx/spmd.py b/flax/experimental/nnx/nnx/spmd.py index 63abd405bc..57ca6dbd06 100644 --- a/flax/experimental/nnx/nnx/spmd.py +++ b/flax/experimental/nnx/nnx/spmd.py @@ -195,31 +195,28 @@ def with_partitioning( initializer: F, sharding: Sharding, mesh: tp.Optional[jax.sharding.Mesh] = None, - set_value_hooks: tp.Union[ - variables.SetValueHook[A], tp.Sequence[variables.SetValueHook[A]] - ] = (), get_value_hooks: tp.Union[ variables.GetValueHook[A], tp.Sequence[variables.GetValueHook[A]] ] = (), + create_value_hooks: tp.Union[ + variables.CreateValueHook[A], tp.Sequence[variables.CreateValueHook[A]] + ] = (), **metadata: tp.Any, ) -> F: - # TODO: turn this into a create_value_hook - @functools.wraps(initializer) - def maybe_constrain(*args, **kwargs): - y = initializer(*args, **kwargs) - if _global_mesh_defined() or (mesh is not None): - return with_sharding_constraint(y, sharding, mesh=mesh) - return y - if callable(get_value_hooks): get_value_hooks = (get_value_hooks, sharding_hook) else: get_value_hooks = (*get_value_hooks, sharding_hook) + if callable(create_value_hooks): + create_value_hooks = (create_value_hooks, sharding_hook) + else: + create_value_hooks = (*create_value_hooks, sharding_hook) + return variables.with_metadata( - tp.cast(F, maybe_constrain), - set_value_hooks=set_value_hooks, + initializer, get_value_hooks=get_value_hooks, + create_value_hooks=create_value_hooks, sharding=sharding, mesh=mesh, **metadata,