diff --git a/flax/core/meta.py b/flax/core/meta.py index 278e5b51a0..b6266462f9 100644 --- a/flax/core/meta.py +++ b/flax/core/meta.py @@ -333,26 +333,33 @@ def wrapper(*args, **kwargs): return wrapper -def get_partition_spec(tree: Any) -> Any: - """Extracts a PartitionSpec tree from a PyTree containing ``Partitioned`` values.""" +def _get_leaf_pspec(x: Any) -> jax.sharding.PartitionSpec | None: + if hasattr(x, 'get_partition_spec'): + return x.get_partition_spec() + # Unboxed arrays, which should be replicated across all devices + elif hasattr(x, 'shape'): + return jax.sharding.PartitionSpec() + else: + return None - def f(x): - if hasattr(x, 'get_partition_spec'): - return x.get_partition_spec() - # Unboxed arrays, which should be replicated across all devices - elif hasattr(x, 'shape'): - return jax.sharding.PartitionSpec() - else: - return None +def get_partition_spec(tree: Any) -> Any: + """Extracts a PartitionSpec tree from a PyTree containing ``Partitioned`` values.""" return jax.tree_util.tree_map( - f, tree, is_leaf=lambda x: isinstance(x, AxisMetadata) + _get_leaf_pspec, tree, is_leaf=lambda x: isinstance(x, AxisMetadata) ) def get_sharding(tree: Any, mesh: jax.sharding.Mesh) -> Any: """Extracts a jax.sharding tree from a PyTree containing ``Partitioned`` values and a mesh.""" - pspec_tree = get_partition_spec(tree) + def f(x: Any) -> jax.sharding.Sharding | None: + if hasattr(x, 'get_sharding'): + return x.get_sharding(mesh) + pspec = _get_leaf_pspec(x) + if pspec is None: + return None + return jax.sharding.NamedSharding(mesh, pspec) + return jax.tree_util.tree_map( - lambda x: jax.sharding.NamedSharding(mesh, x), pspec_tree + f, tree, is_leaf=lambda x: isinstance(x, AxisMetadata) )