Skip to content

Commit

Permalink
use set_value_hook in with_partitioning
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Oct 28, 2023
1 parent 76fb81d commit 491211f
Showing 1 changed file with 10 additions and 13 deletions.
23 changes: 10 additions & 13 deletions flax/experimental/nnx/nnx/spmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,31 +195,28 @@ def with_partitioning(
initializer: F,
sharding: Sharding,
mesh: tp.Optional[jax.sharding.Mesh] = None,
set_value_hooks: tp.Union[
variables.SetValueHook[A], tp.Sequence[variables.SetValueHook[A]]
] = (),
get_value_hooks: tp.Union[
variables.GetValueHook[A], tp.Sequence[variables.GetValueHook[A]]
] = (),
create_value_hooks: tp.Union[
variables.CreateValueHook[A], tp.Sequence[variables.CreateValueHook[A]]
] = (),
**metadata: tp.Any,
) -> F:
# TODO: turn this into a create_value_hook
@functools.wraps(initializer)
def maybe_constrain(*args, **kwargs):
y = initializer(*args, **kwargs)
if _global_mesh_defined() or (mesh is not None):
return with_sharding_constraint(y, sharding, mesh=mesh)
return y

if callable(get_value_hooks):
get_value_hooks = (get_value_hooks, sharding_hook)
else:
get_value_hooks = (*get_value_hooks, sharding_hook)

if callable(create_value_hooks):
create_value_hooks = (create_value_hooks, sharding_hook)
else:
create_value_hooks = (*create_value_hooks, sharding_hook)

return variables.with_metadata(
tp.cast(F, maybe_constrain),
set_value_hooks=set_value_hooks,
initializer,
get_value_hooks=get_value_hooks,
create_value_hooks=create_value_hooks,
sharding=sharding,
mesh=mesh,
**metadata,
Expand Down

0 comments on commit 491211f

Please sign in to comment.