Implement XLAShardedTensor._spec and test #9488
base: master
Conversation
Force-pushed from bb4eb3b to 28c4f6f.
mesh = DeviceMesh("xla", | ||
torch.tensor(device_list).reshape(self.mesh_shape)) | ||
else: | ||
# default to 1D mesh |
Maybe this should be an error.
placements.append(
    Shard(tensor_dim) if tensor_dim is not None else Replicate())
else:
  placements = [Replicate()]
Same here. Should be an error.
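For what it's worth, a minimal sketch of the raise-instead-of-default shape being suggested (this does not reproduce the PR's exact placement mapping, only the control flow; build_placements is a hypothetical name):

from torch.distributed.tensor.placement_types import Shard, Replicate

def build_placements(partition_spec):
  # Hypothetical helper: refuse to guess when no sharding metadata was
  # recorded, rather than silently falling back to [Replicate()].
  if partition_spec is None:
    raise ValueError(
        "XLAShardedTensor has no partition_spec; cannot build placements")
  return [
      Shard(tensor_dim) if tensor_dim is not None else Replicate()
      for tensor_dim in partition_spec
  ]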
@@ -651,7 +651,9 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh,
  op_sharding = mesh.get_op_sharding(partition_spec)
  annotate_func = torch_xla._XLAC._xla_mark_sharding
  annotate_func(unwrap_sharded_tensor(t), op_sharding)
  return wrap_as_sharded_tensor(t)
  # Pass mesh and partition spec information for DTensor compatibility
  return wrap_as_sharded_tensor(
Let's do the same for the other APIs above, like annotate_custom_sharding and enable_manual_sharding?
# use existing mesh_shape
if hasattr(self, 'mesh_shape') and self.mesh_shape:
  import torch_xla.runtime as xr
nit: move the import outside the if
mesh = DeviceMesh("xla", | ||
torch.tensor(device_list).reshape(self.mesh_shape)) | ||
else: | ||
# default to 1D mesh |
Why do we have a default instead of throwing an error? Is it for auto wrapping?
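If the intent is to require recorded metadata, a sketch of the else branch could look like this (purely hypothetical, replacing the 1D default shown in the diff above):

else:
  # Hypothetical: refuse to invent a mesh when no mesh_shape was recorded
  # on this tensor, instead of defaulting to a flat 1D mesh.
  raise ValueError(
      "XLAShardedTensor has no recorded mesh_shape; cannot build a DeviceMesh")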
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.placement_types import Shard, Replicate

# use existing mesh_shape
It's better to extract the conversion into a function and put it near def convert_to_xla_mesh(dt_mesh: DeviceMesh) -> "Mesh" (xla/torch_xla/distributed/spmd/api.py, line 49 in e7dcc7b).
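A sketch of what such an extracted helper might look like, mirroring the inline code in this PR (the function name convert_to_device_mesh is hypothetical; convert_to_xla_mesh already covers the opposite direction):

import torch
import torch_xla.runtime as xr
from torch.distributed.device_mesh import DeviceMesh

def convert_to_device_mesh(mesh_shape):
  # Hypothetical inverse of convert_to_xla_mesh: build a DeviceMesh over all
  # global runtime devices, reshaped to the recorded XLA mesh shape.
  device_count = xr.global_runtime_device_count()
  device_ids = torch.tensor(list(range(device_count))).reshape(mesh_shape)
  return DeviceMesh("xla", device_ids)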
import torch_xla.runtime as xr
device_count = xr.global_runtime_device_count()
device_list = list(range(device_count))
mesh = DeviceMesh("xla",
We need to take care of mesh dim names, too
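For example, DeviceMesh accepts a mesh_dim_names keyword argument, so the recorded axis names could be forwarded along with the shape. A fragment of the constructor call quoted above, with placeholder axis names:

# Hypothetical: forward the recorded axis names instead of dropping them.
mesh = DeviceMesh(
    "xla",
    torch.tensor(device_list).reshape(self.mesh_shape),
    mesh_dim_names=("data", "model"))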
device_count = xr.global_runtime_device_count()
mesh = DeviceMesh("xla", list(range(device_count)))

# use existing partition_spec
Same as above
converted_spec = xla_tensor._spec

assert converted_spec.mesh.device_type == "xla"
assert converted_spec.mesh.size() == device_count
May need to assert on the mesh size of each dim, too
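For instance, the test could compare the full shape and each dimension (a sketch; mesh_shape here stands for whatever shape the test constructs):

# Hypothetical additional assertions: check per-dimension sizes, not just
# the total device count.
assert tuple(converted_spec.mesh.shape) == tuple(mesh_shape)
for dim, dim_size in enumerate(mesh_shape):
  assert converted_spec.mesh.size(dim) == dim_size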
Force-pushed from 28c4f6f to 78a5639.
Force-pushed from 78a5639 to 4dd06ad.
# Return cached spec if available
if hasattr(self, '_cached_spec'):
  return self._cached_spec
What if a call to wrap_as_sharded_tensor changes self.mesh_shape and/or self.partition_spec? Will you still get this cached value even though it's out of date?
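One hedged way to address this, sketched against the names in this PR (the invalidation itself is hypothetical): have wrap_as_sharded_tensor clear the cache whenever it rewrites the metadata on an existing XLAShardedTensor.

# Hypothetical fragment inside wrap_as_sharded_tensor: dropping the cached
# spec forces the next ._spec access to rebuild it from the new metadata.
if isinstance(t, XLAShardedTensor):
  if mesh_shape is not None:
    t.mesh_shape = mesh_shape
  if partition_spec is not None:
    t.partition_spec = partition_spec
  t._cached_spec = None  # invalidate any stale DTensorSpec
  return t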
return self._cached_spec

# use existing mesh_shape
if hasattr(self, 'mesh_shape') and self.mesh_shape:
Can self.mesh_shape and self.partition_spec be set to None at initialization? That removes the need to check hasattr. I don't think there's a semantic difference between the attribute not existing and its value being None, is there?
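A minimal sketch of that suggestion (ignoring the __new__/_make_wrapper_subclass plumbing of the real class; the point is only the None defaults):

# Hypothetical: default the metadata to None so callers can write
# `x.mesh_shape is None` instead of hasattr(x, 'mesh_shape').
def __init__(self, global_tensor, mesh_shape=None, partition_spec=None):
  self.global_tensor = global_tensor
  self.mesh_shape = mesh_shape
  self.partition_spec = partition_spec
  self._cached_spec = None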
def find_sharded_info(x):
  nonlocal mesh_shape, partition_spec
  if isinstance(x, XLAShardedTensor):
    if hasattr(x, 'mesh_shape') and x.mesh_shape:
      mesh_shape = x.mesh_shape
    if hasattr(x, 'partition_spec') and x.partition_spec:
      partition_spec = x.partition_spec

tree_map(find_sharded_info, args)
if kwargs:
  tree_map(find_sharded_info, kwargs)
I don't have enough experience with this codebase to understand the context. What are *args and **kwargs in practice, and why do we expect they would have sharding information that is relevant for elem?
Also, is this "elem is not an XLAShardedTensor but there exists sharding information we want to acquire" path tested?
assert second_access_time * 10 < first_access_time, \
    f"Cached access should be much faster: {first_access_time:.6f}s vs {second_access_time:.6f}s"
These sorts of tests that rely on the wall clock often lead to annoying flakes in my experience. I think it's sufficient to just test that self._cached_spec has a permanent value after the first call. If you really want to assert that a certain code path is called, you could do something with mocks, but that seems like overkill for this, which is simple to confirm by looking at the code ("if the attribute exists, return it" is the first line of the property getter).
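For example, a timing-free version of the check could look like this (a sketch; xt stands for whatever sharded tensor the test already builds):

# Hypothetical replacement for the timing assertion: the spec is built once
# and the same object is returned on every later access.
spec_first = xt._spec
spec_second = xt._spec
assert spec_first is spec_second
assert xt._cached_spec is spec_first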
return XLAShardedTensor(t)
return t
return XLAShardedTensor(
    t, mesh_shape=mesh_shape, partition_spec=partition_spec)
The fact that calling wrap_as_sharded_tensor on an XLAShardedTensor now does something other than trivially returning t is logically new. Is this resharding logic tested?
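A sketch of the kind of test being asked for (the keyword arguments come from this PR's diff; the mesh values and whether the second call updates metadata in place are assumptions the actual test would need to pin down):

# Hypothetical test: re-wrapping an existing XLAShardedTensor with new
# metadata should be reflected in its attributes, not silently ignored.
st = wrap_as_sharded_tensor(torch.randn(8, 8))
st = wrap_as_sharded_tensor(st, mesh_shape=(1, 8), partition_spec=(0, None))
assert st.mesh_shape == (1, 8)
assert st.partition_spec == (0, None)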
Convert XLA sharding information to DTensorSpec for DTensor interface compatibility.
"""
# Return cached spec if available
if hasattr(self, '_cached_spec'):
Can we assign self._cached_spec = None in the constructor, then check self._cached_spec is None here?
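A sketch of the property getter under that suggestion (assuming the constructor sets self._cached_spec = None as proposed; _build_dtensor_spec is a hypothetical name for the construction logic that follows in this diff):

@property
def _spec(self):
  # Hypothetical: a None sentinel set in the constructor replaces the
  # hasattr check; the spec is still built lazily on first access.
  if self._cached_spec is None:
    self._cached_spec = self._build_dtensor_spec()
  return self._cached_spec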
@@ -91,10 +95,15 @@ class XLAShardedTensor(torch.Tensor):
  # >> assert len(input.shape) == len(partition_spec)
  partition_spec: Tuple[int, None]

  __slots__ = ['global_tensor']
  __slots__ = ['global_tensor', 'mesh_shape', 'partition_spec', '_cached_spec']
Can we get rid of __slots__? Is it just for performance?
    if hasattr(x, 'partition_spec') and x.partition_spec:
      partition_spec = x.partition_spec

tree_map(find_sharded_info, args)
You can do tree_map_only(type, callable, args); then you can skip the isinstance check inside of the callable.
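For instance (tree_map_only lives in torch.utils._pytree; this mirrors the find_sharded_info closure quoted earlier, so it still assumes an enclosing function for the nonlocal declarations):

from torch.utils._pytree import tree_map_only

def find_sharded_info(x):
  # tree_map_only only invokes this for XLAShardedTensor leaves, so the
  # isinstance check inside the callable is no longer needed.
  nonlocal mesh_shape, partition_spec
  if x.mesh_shape:
    mesh_shape = x.mesh_shape
  if x.partition_spec:
    partition_spec = x.partition_spec

tree_map_only(XLAShardedTensor, find_sharded_info, args)
if kwargs:
  tree_map_only(XLAShardedTensor, find_sharded_info, kwargs)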
Implements and adds tests for XLAShardedTensor._spec, with regard to #9418.