
Implement XLAShardedTensor._spec and test #9488

Open

aws-cph wants to merge 1 commit into master from aws-cph_dtensor_spec
Conversation

@aws-cph (Contributor) commented Jul 17, 2025

Implements XLAShardedTensor._spec and adds tests for it, regarding #9418.

mesh = DeviceMesh("xla",
torch.tensor(device_list).reshape(self.mesh_shape))
else:
# default to 1D mesh
Collaborator

Maybe this should be an error.

placements.append(
    Shard(tensor_dim) if tensor_dim is not None else Replicate())
else:
  placements = [Replicate()]
Collaborator

Same here. Should be an error.
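A minimal sketch of that suggestion, assuming the property builds placements from self.partition_spec as in the diff above:

# Sketch only: raise instead of silently falling back to full replication
# when no partition_spec was recorded on the tensor.
if not getattr(self, 'partition_spec', None):
  raise ValueError(
      "XLAShardedTensor has no partition_spec; cannot build DTensor placements")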

@@ -651,7 +651,9 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh,
op_sharding = mesh.get_op_sharding(partition_spec)
annotate_func = torch_xla._XLAC._xla_mark_sharding
annotate_func(unwrap_sharded_tensor(t), op_sharding)
return wrap_as_sharded_tensor(t)
# Pass mesh and partition spec information for DTensor compatibility
return wrap_as_sharded_tensor(
Collaborator

Let's do the same for the other APIs above, like annotate_custom_sharding and enable_manual_sharding?
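For illustration, the analogous change in those APIs might look like the sketch below; the surrounding code and variable names are assumed, not copied from the PR:

# Sketch: have annotate_custom_sharding / enable_manual_sharding also forward
# the sharding metadata instead of returning a bare wrapper.
return wrap_as_sharded_tensor(
    t, mesh_shape=mesh.mesh_shape, partition_spec=partition_spec)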


# use existing mesh_shape
if hasattr(self, 'mesh_shape') and self.mesh_shape:
import torch_xla.runtime as xr


nit: move the import outside the if

mesh = DeviceMesh("xla",
torch.tensor(device_list).reshape(self.mesh_shape))
else:
# default to 1D mesh
@fhaolinaws fhaolinaws Jul 18, 2025

Why do we have a default instead of throwing an error? Is it for auto wrapping?

from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.placement_types import Shard, Replicate

# use existing mesh_shape


It's better to extract the conversion into a function and put it around here:

def convert_to_xla_mesh(dt_mesh: DeviceMesh) -> "Mesh":
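For illustration, one possible shape of such a helper; this is a sketch only, assuming torch_xla's spmd Mesh constructor Mesh(device_ids, mesh_shape, axis_names) and the DeviceMesh attributes .mesh and .mesh_dim_names:

from torch.distributed.device_mesh import DeviceMesh
from torch_xla.distributed.spmd import Mesh


def convert_to_xla_mesh(dt_mesh: DeviceMesh) -> Mesh:
  # Reuse the DeviceMesh's device ids, shape, and dim names for the XLA Mesh.
  return Mesh(
      device_ids=dt_mesh.mesh.flatten().tolist(),
      mesh_shape=tuple(dt_mesh.mesh.shape),
      axis_names=dt_mesh.mesh_dim_names)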

import torch_xla.runtime as xr
device_count = xr.global_runtime_device_count()
device_list = list(range(device_count))
mesh = DeviceMesh("xla",


We need to take care of mesh dim names, too
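A sketch of forwarding the names when the DeviceMesh is built; mesh_dim_names is a real DeviceMesh kwarg, while self.mesh_dim_names is a hypothetical attribute holding the stored names:

# Sketch: pass the stored axis names through (self.mesh_dim_names is assumed).
mesh = DeviceMesh(
    "xla",
    torch.tensor(device_list).reshape(self.mesh_shape),
    mesh_dim_names=self.mesh_dim_names)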

device_count = xr.global_runtime_device_count()
mesh = DeviceMesh("xla", list(range(device_count)))

# use existing partition_spec


Same as above

converted_spec = xla_tensor._spec

assert converted_spec.mesh.device_type == "xla"
assert converted_spec.mesh.size() == device_count


May need to assert on the mesh size of each dim, too.
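A sketch of such an assertion, assuming the test keeps the expected shape in expected_mesh_shape:

# Sketch: check the size of every mesh dimension, not just the total.
assert tuple(converted_spec.mesh.shape) == tuple(expected_mesh_shape)
for dim, size in enumerate(expected_mesh_shape):
  assert converted_spec.mesh.size(dim) == size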

@aws-cph force-pushed the aws-cph_dtensor_spec branch from 28c4f6f to 78a5639 on July 18, 2025 at 22:05
@aws-cph force-pushed the aws-cph_dtensor_spec branch from 78a5639 to 4dd06ad on July 18, 2025 at 22:08
Comment on lines +211 to +213
# Return cached spec if available
if hasattr(self, '_cached_spec'):
  return self._cached_spec
Collaborator

What if a call to wrap_as_sharded_tensor changes self.mesh_shape and/or self.partition_spec? Will you still get this cached value even though it's out of date?
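One way to avoid a stale cache, sketched under the assumption that wrap_as_sharded_tensor assigns the attributes on an existing tensor t:

# Sketch: invalidate the cached spec whenever the sharding attributes change.
t.mesh_shape = mesh_shape
t.partition_spec = partition_spec
if hasattr(t, '_cached_spec'):
  del t._cached_spec  # the next access to t._spec rebuilds it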

return self._cached_spec

# use existing mesh_shape
if hasattr(self, 'mesh_shape') and self.mesh_shape:
Collaborator

Can self.mesh_shape and self.partition_spec be set to None at initialization? That removes the need to check hasattr. I don't think there's a semantic difference between the attribute not existing and its value being None, is there?
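A sketch of that simplification; the constructor variable name r and the defaults are assumed, not taken from the PR:

# Sketch: initialize the attributes in the constructor so callers can test the
# value directly instead of using hasattr.
r.mesh_shape = mesh_shape          # None when no sharding info was provided
r.partition_spec = partition_spec  # likewise defaults to None

# ...later, in the _spec property:
if self.mesh_shape is not None:
  ...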

Comment on lines +182 to +192
def find_sharded_info(x):
  nonlocal mesh_shape, partition_spec
  if isinstance(x, XLAShardedTensor):
    if hasattr(x, 'mesh_shape') and x.mesh_shape:
      mesh_shape = x.mesh_shape
    if hasattr(x, 'partition_spec') and x.partition_spec:
      partition_spec = x.partition_spec

tree_map(find_sharded_info, args)
if kwargs:
  tree_map(find_sharded_info, kwargs)
Collaborator

I don't have enough experience with this codebase to understand the context. What are *args and **kwargs in practice and why do we expect they would have sharding information that is relevant for elem?

Collaborator

Also, is this "elem is not an XLAShardedTensor but there exists sharding information we want to acquire" path tested?

Comment on lines +89 to +91
assert second_access_time * 10 < first_access_time, \
    f"Cached access should be much faster: {first_access_time:.6f}s vs {second_access_time:.6f}s"

Collaborator
@bfolie bfolie Jul 21, 2025

These sorts of tests that rely on the wall clock often lead to annoying flakes in my experience. I think it's sufficient to just test that self._cached_spec has a permanent value after the first call. If you really want to assert that a certain code path is called you could do something with mocks, but that seems like overkill for this, which is simple to confirm by looking at the code ("if the attribute exists, return it" is the first line of the property getter)
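A sketch of the suggested assertion, using xt for the sharded tensor built earlier in the test:

# Sketch: assert the cache is populated and stable rather than timing it.
spec_first = xt._spec
assert xt._cached_spec is spec_first
assert xt._spec is spec_first  # repeated access returns the same cached object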

return XLAShardedTensor(t)
return t
return XLAShardedTensor(
t, mesh_shape=mesh_shape, partition_spec=partition_spec)
Collaborator

The fact that calling wrap_as_sharded_tensor on an XLAShardedTensor now does something other than trivially returning t is logically new. Is this resharding logic tested?
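A sketch of what such a test could look like; t and num_devices are assumed to exist in the test:

# Sketch: re-wrapping an already-wrapped tensor with explicit sharding info
# should attach that info rather than trivially returning the input.
xt = wrap_as_sharded_tensor(t)
xt = wrap_as_sharded_tensor(
    xt, mesh_shape=(num_devices,), partition_spec=(0, None))
assert xt.mesh_shape == (num_devices,)
assert xt.partition_spec == (0, None)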

Convert XLA sharding information to DTensorSpec for DTensor interface compatibility.
"""
# Return cached spec if available
if hasattr(self, '_cached_spec'):
Collaborator

Can we assign self._cached_spec = None in the constructor, then check self._cached_spec is None here?
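A sketch of that version of the getter, assuming the constructor sets self._cached_spec = None:

# Sketch: constructor does r._cached_spec = None; the getter then checks None.
if self._cached_spec is not None:
  return self._cached_spec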

@@ -91,10 +95,15 @@ class XLAShardedTensor(torch.Tensor):
# >> assert len(input.shape) == len(partition_spec)
partition_spec: Tuple[int, None]

__slots__ = ['global_tensor']
__slots__ = ['global_tensor', 'mesh_shape', 'partition_spec', '_cached_spec']
Collaborator

Can we get rid of __slots__? Is it just for performance?

if hasattr(x, 'partition_spec') and x.partition_spec:
  partition_spec = x.partition_spec

tree_map(find_sharded_info, args)
Collaborator

You can do tree_map_only(type, callable, args); then you can skip the isinstance check inside the callable.
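A sketch of the rewrite with tree_map_only, mirroring the diff above; mesh_shape and partition_spec are the enclosing function's locals:

from torch.utils._pytree import tree_map_only

def find_sharded_info(x):
  # Only called for XLAShardedTensor leaves, so no isinstance check is needed.
  nonlocal mesh_shape, partition_spec
  if x.mesh_shape:
    mesh_shape = x.mesh_shape
  if x.partition_spec:
    partition_spec = x.partition_spec

tree_map_only(XLAShardedTensor, find_sharded_info, args)
if kwargs:
  tree_map_only(XLAShardedTensor, find_sharded_info, kwargs)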
