Skip to content

Commit

Permalink
Create sharding via Partitioned.get_sharding()
Browse files Browse the repository at this point in the history
This change modifies the global get_sharding() function to call into Partitioned.get_sharding(). This allows subclasses of Partitioned to override the way sharding is created.

PiperOrigin-RevId: 704781933
  • Loading branch information
hhb authored and Flax Authors committed Dec 11, 2024
1 parent 554b690 commit da6e4b7
Showing 1 changed file with 20 additions and 13 deletions.
33 changes: 20 additions & 13 deletions flax/core/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,26 +333,33 @@ def wrapper(*args, **kwargs):
return wrapper


def get_partition_spec(tree: Any) -> Any:
"""Extracts a PartitionSpec tree from a PyTree containing ``Partitioned`` values."""
def _get_leaf_pspec(x: Any) -> jax.sharding.PartitionSpec | None:
if hasattr(x, 'get_partition_spec'):
return x.get_partition_spec()
# Unboxed arrays, which should be replicated across all devices
elif hasattr(x, 'shape'):
return jax.sharding.PartitionSpec()
else:
return None

def f(x):
if hasattr(x, 'get_partition_spec'):
return x.get_partition_spec()
# Unboxed arrays, which should be replicated across all devices
elif hasattr(x, 'shape'):
return jax.sharding.PartitionSpec()
else:
return None

def get_partition_spec(tree: Any) -> Any:
"""Extracts a PartitionSpec tree from a PyTree containing ``Partitioned`` values."""
return jax.tree_util.tree_map(
f, tree, is_leaf=lambda x: isinstance(x, AxisMetadata)
_get_leaf_pspec, tree, is_leaf=lambda x: isinstance(x, AxisMetadata)
)


def get_sharding(tree: Any, mesh: jax.sharding.Mesh) -> Any:
"""Extracts a jax.sharding tree from a PyTree containing ``Partitioned`` values and a mesh."""
pspec_tree = get_partition_spec(tree)
def f(x: Any) -> jax.sharding.Sharding | None:
if hasattr(x, 'get_sharding'):
return x.get_sharding(mesh)
pspec = _get_leaf_pspec(x)
if pspec is None:
return None
return jax.sharding.NamedSharding(mesh, pspec)

return jax.tree_util.tree_map(
lambda x: jax.sharding.NamedSharding(mesh, x), pspec_tree
f, tree, is_leaf=lambda x: isinstance(x, AxisMetadata)
)

0 comments on commit da6e4b7

Please sign in to comment.