
Commit 78a5639

Implement XLAShardedTensor._spec and test
1 parent 95ba754 commit 78a5639

6 files changed: +264 −12 lines

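The commit wires XLA sharding metadata into torch.distributed's DTensorSpec, so an XLAShardedTensor can be inspected through the DTensor interface. Below is a minimal usage sketch mirroring the new test file; it assumes an XLA runtime with at least one device and uses only calls that appear in the diff.

import torch
from torch.distributed.tensor import DeviceMesh, Shard, distribute_tensor
import torch_xla.runtime as xr

# Build a 1-D "xla" device mesh over all available devices.
world_size = xr.global_runtime_device_count()
mesh = DeviceMesh("xla", torch.arange(world_size))

# Shard a tensor along dim 0; the result is an XLAShardedTensor.
sharded = distribute_tensor(torch.randn(1024, 88), mesh, [Shard(0)])

# After this change, ._spec returns a DTensorSpec describing the sharding.
spec = sharded._spec
assert spec.mesh.device_type == "xla"
assert spec.placements == (Shard(0),)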

test/neuron/run_tests.sh

Lines changed: 10 additions & 1 deletion
@@ -56,6 +56,14 @@ function run_test {
   PJRT_DEVICE=NEURON NEURON_NUM_DEVICES=1 run_coverage "$@"
 }
 
+function run_test_multi_device {
+  if ! test_is_selected "$1"; then
+    return
+  fi
+  echo "Running in PjRt runtime: $@"
+  PJRT_DEVICE=NEURON run_coverage "$@"
+}
+
 function run_test_without_functionalization {
   if ! test_is_selected "$1"; then
     return
@@ -246,7 +254,8 @@ function run_xla_op_tests3 {
   run_test "$_TEST_DIR/spmd/test_xla_spmd_python_api_interaction.py"
   #run_test "$_TEST_DIR/spmd/test_dtensor_integration.py"
   #run_test "$_TEST_DIR/spmd/test_dtensor_integration2.py"
-  run_test "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py"
+  run_test_multi_device "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py"
+  run_test_multi_device "$_TEST_DIR/spmd/test_xla_dtensor_spec_conv.py"
   run_test "$_TEST_DIR/spmd/test_xla_auto_sharding.py"
   #run_test "$_TEST_DIR/spmd/test_spmd_parameter_wrapping.py"
   run_test "$_TEST_DIR/spmd/test_train_spmd_linear_model.py"

test/run_tests.sh

Lines changed: 1 addition & 0 deletions
@@ -254,6 +254,7 @@ function run_xla_op_tests3 {
   run_test "$_TEST_DIR/spmd/test_dtensor_integration2.py"
   run_test_multi_devices_without_func "$_TEST_DIR/spmd/test_dtensor_integration3.py"
   run_test_multi_devices "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py"
+  run_test_multi_devices "$_TEST_DIR/spmd/test_xla_dtensor_spec_conversion.py"
   run_test "$_TEST_DIR/spmd/test_xla_auto_sharding.py"
   run_test "$_TEST_DIR/spmd/test_spmd_parameter_wrapping.py"
   run_test "$_TEST_DIR/spmd/test_mp_input_sharding.py"
test/spmd/test_xla_dtensor_spec_conversion.py

Lines changed: 147 additions & 0 deletions
@@ -0,0 +1,147 @@
import os
import sys

import torch
from torch.distributed.tensor import DeviceMesh, Shard, distribute_tensor

import torch_xla
import torch_xla.runtime as xr

import unittest
import test_xla_sharding_base


class XLADTensorSpecConversionTest(test_xla_sharding_base.XlaShardingTest):

  @classmethod
  def setUpClass(cls):
    super().setUpClass()

  def test_sample_test_case(self):
    world_size = xr.global_runtime_device_count()
    mesh = DeviceMesh("xla", torch.arange(world_size))
    big_tensor = torch.randn(100000, 88)
    my_dtensor = distribute_tensor(big_tensor, mesh, [Shard(0)])

    assert my_dtensor._spec.mesh.device_type == mesh.device_type
    assert my_dtensor._spec.placements == (Shard(0),)

  def test_xla_to_dtensor_spec_conversion(self):
    device_count = xr.global_runtime_device_count()
    mesh = DeviceMesh("xla", list(range(device_count)))

    # Test different sharding patterns
    from torch.distributed.tensor.placement_types import Replicate
    test_cases = [
        (torch.randn(100, 50), [Shard(0)]),
        (torch.randn(100, 50), [Shard(1)]),
        (torch.randn(100, 50, 25), [Shard(0)]),
        (torch.randn(100, 50), [Replicate()]),
    ]

    for tensor, placements in test_cases:
      xla_tensor = distribute_tensor(tensor, mesh, placements)
      spec = xla_tensor._spec

      assert spec is not None
      assert spec.mesh.device_type == "xla"
      assert spec.tensor_meta.shape == tensor.shape
      assert spec.tensor_meta.dtype == tensor.dtype
      assert len(spec.placements) >= 1
      assert spec.placements == tuple(placements)

  def test_mesh_conversion(self):
    device_count = xr.global_runtime_device_count()
    original_mesh = DeviceMesh("xla", list(range(device_count)))
    tensor = torch.randn(50, 50)
    xla_tensor = distribute_tensor(tensor, original_mesh, [Shard(0)])

    converted_spec = xla_tensor._spec

    assert converted_spec.mesh.device_type == "xla"
    assert converted_spec.mesh.size() == device_count

  def test_spec_caching(self):
    """Test that _spec property caches results for better performance"""
    import time
    device_count = xr.global_runtime_device_count()
    mesh = DeviceMesh("xla", list(range(device_count)))
    tensor = torch.randn(1000,
                         1000)  # Large tensor to make spec creation noticeable
    xla_tensor = distribute_tensor(tensor, mesh, [Shard(0)])

    # first access should create and cache the spec
    start_time = time.time()
    spec1 = xla_tensor._spec
    first_access_time = time.time() - start_time

    # should be much faster due to caching
    start_time = time.time()
    spec2 = xla_tensor._spec
    second_access_time = time.time() - start_time

    assert spec1 is spec2
    print(
        f"First access: {first_access_time:.6f}s, Second access: {second_access_time:.6f}s"
    )
    assert second_access_time * 10 < first_access_time, \
        f"Cached access should be much faster: {first_access_time:.6f}s vs {second_access_time:.6f}s"

  def _create_test_tensor_and_mesh(self, tensor_shape, mesh_shape, placements):
    """Helper to create tensor and mesh for testing"""
    device_count = xr.global_runtime_device_count()
    if device_count < max(mesh_shape):
      self.skipTest(
          f"Need at least {max(mesh_shape)} devices, got {device_count}")

    mesh = DeviceMesh("xla", torch.arange(device_count).reshape(mesh_shape))
    tensor = torch.randn(*tensor_shape)
    return distribute_tensor(tensor, mesh, placements), mesh

  def test_multi_dim_sharding_spec(self):
    """Test _spec for multi-dimensional sharding"""
    device_count = xr.global_runtime_device_count()
    if device_count < 4:
      self.skipTest("Need at least 4 devices for 2D mesh")

    mesh_shape = (2, device_count // 2)
    xla_tensor, mesh = self._create_test_tensor_and_mesh(
        (100, 50), mesh_shape, [Shard(0), Shard(1)])
    spec = xla_tensor._spec

    assert len(spec.placements) == 2
    assert spec.mesh.ndim == 2

  def test_tensor_operations_preserve_spec(self):
    """Test that tensor operations preserve sharding metadata"""
    xla_tensor, mesh = self._create_test_tensor_and_mesh((100, 50), (-1,),
                                                         [Shard(0)])

    result_add = xla_tensor + 1
    result_mul = xla_tensor * 2
    result_relu = torch.relu(xla_tensor)

    for result in [result_add, result_mul, result_relu]:
      assert hasattr(result, '_spec')
      assert result._spec.mesh.device_type == "xla"

  def test_mixed_placement_spec(self):
    """Test _spec for tensors with mixed shard/replicate placements"""
    from torch.distributed.tensor.placement_types import Replicate
    device_count = xr.global_runtime_device_count()
    if device_count < 4:
      self.skipTest("Need at least 4 devices for 2D mesh")

    mesh_shape = (2, device_count // 2)
    xla_tensor, mesh = self._create_test_tensor_and_mesh(
        (100, 50), mesh_shape, [Shard(0), Replicate()])
    spec = xla_tensor._spec

    assert len(spec.placements) == 2
    assert isinstance(spec.placements[0], Shard)
    assert isinstance(spec.placements[1], Replicate)


if __name__ == '__main__':
  test = unittest.main()
  sys.exit(0 if test.result.wasSuccessful() else 1)

test/tpu/run_tests.sh

Lines changed: 1 addition & 0 deletions
@@ -61,6 +61,7 @@ run_test "$_TEST_DIR/spmd/test_xla_spmd_python_api_interaction.py"
 run_test "$_TEST_DIR/spmd/test_xla_auto_sharding.py"
 run_test "$_TEST_DIR/spmd/test_fsdp_v2.py"
 run_test "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py"
+run_test "$_TEST_DIR/spmd/test_xla_dtensor_spec_conversion.py"
 run_test "$_TEST_DIR/test_gradient_accumulation.py"
 XLA_EXPERIMENTAL=nonzero:masked_select:nms run_test "$_TEST_DIR/ds/test_dynamic_shape_models.py" -v
 run_test "$_TEST_DIR/test_autocast.py"

torch_xla/distributed/spmd/xla_sharded_tensor.py

Lines changed: 84 additions & 3 deletions
@@ -6,6 +6,10 @@
 from typing import List, Tuple, Iterator, Union
 import contextlib
 import collections
+import torch_xla.runtime as xr
+from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
+from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.tensor.placement_types import Shard, Replicate
 
 
 @dataclass
@@ -91,10 +95,15 @@ class XLAShardedTensor(torch.Tensor):
   # >> assert len(input.shape) == len(partition_spec)
   partition_spec: Tuple[int, None]
 
-  __slots__ = ['global_tensor']
+  __slots__ = ['global_tensor', 'mesh_shape', 'partition_spec', '_cached_spec']
 
   @staticmethod
-  def __new__(cls, elem: torch.Tensor, *args, **kwargs):
+  def __new__(cls,
+              elem: torch.Tensor,
+              mesh_shape=None,
+              partition_spec=None,
+              *args,
+              **kwargs):
     # TODO(yeounoh) wrapper can take different arguments
     r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
         cls,
@@ -106,6 +115,11 @@ def __new__(cls, elem: torch.Tensor, *args, **kwargs):
         device=elem.device,
         requires_grad=kwargs.get("requires_grad", False))
     r.global_tensor = elem.detach() if r.requires_grad else elem
+    # Store mesh and partition information for DTensor compatibility
+    if mesh_shape is not None:
+      r.mesh_shape = mesh_shape
+    if partition_spec is not None:
+      r.partition_spec = partition_spec
     return r
 
   # Shards on the devices are materialized/available after the lazy
@@ -159,7 +173,27 @@ def unwrap(elem):
       return elem.global_tensor if isinstance(elem, XLAShardedTensor) else elem
 
     def wrap(elem):
-      return XLAShardedTensor(elem) if isinstance(elem, torch.Tensor) else elem
+      if isinstance(elem,
+                    torch.Tensor) and not isinstance(elem, XLAShardedTensor):
+        # Try to get mesh/partition info from any XLAShardedTensor in args
+        mesh_shape = None
+        partition_spec = None
+
+        def find_sharded_info(x):
+          nonlocal mesh_shape, partition_spec
+          if isinstance(x, XLAShardedTensor):
+            if hasattr(x, 'mesh_shape') and x.mesh_shape:
+              mesh_shape = x.mesh_shape
+            if hasattr(x, 'partition_spec') and x.partition_spec:
+              partition_spec = x.partition_spec
+
+        tree_map(find_sharded_info, args)
+        if kwargs:
+          tree_map(find_sharded_info, kwargs)
+
+        return XLAShardedTensor(
+            elem, mesh_shape=mesh_shape, partition_spec=partition_spec)
+      return elem
 
     # no_dispatch is only needed if you use enable_python_mode.
     # It prevents infinite recursion.
@@ -169,6 +203,53 @@ def wrap(elem):
           func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
     return rs
 
+  @property
+  def _spec(self):
+    """
+    Convert XLA sharding information to DTensorSpec for DTensor interface compatibility.
+    """
+    # Return cached spec if available
+    if hasattr(self, '_cached_spec'):
+      return self._cached_spec
+
+    # use existing mesh_shape
+    if hasattr(self, 'mesh_shape') and self.mesh_shape:
+      import torch_xla.runtime as xr
+      device_count = xr.global_runtime_device_count()
+      device_list = list(range(device_count))
+      mesh = DeviceMesh("xla",
+                        torch.tensor(device_list).reshape(self.mesh_shape))
+    else:
+      raise ValueError("mesh_shape must be specified to create DTensorSpec")
+
+    # use existing partition_spec
+    if hasattr(self, 'partition_spec') and self.partition_spec:
+      placements = []
+      for mesh_dim in range(
+          len(self.mesh_shape
+             ) if hasattr(self, 'mesh_shape') and self.mesh_shape else 1):
+        # find tensor dimension sharded on this mesh dimension
+        tensor_dim = None
+        for t_dim, m_dim in enumerate(self.partition_spec):
+          if m_dim == mesh_dim:
+            tensor_dim = t_dim
+            break
+        placements.append(
+            Shard(tensor_dim) if tensor_dim is not None else Replicate())
+    else:
+      raise ValueError("partition_spec must be specified to create DTensorSpec")
+
+    # tensor metadata
+    tensor_meta = TensorMeta(
+        shape=self.global_tensor.shape,
+        stride=self.global_tensor.stride(),
+        dtype=self.global_tensor.dtype)
+
+    # Create and cache the spec
+    self._cached_spec = DTensorSpec(
+        mesh=mesh, placements=tuple(placements), tensor_meta=tensor_meta)
+    return self._cached_spec
+
   @classmethod
   def __torch_function__(cls, func, types, args=(), kwargs=None):
     return super().__torch_function__(func, types, args, kwargs)
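For reference, the loop added above converts a partition_spec (one entry per tensor dimension, naming the mesh dimension it is sharded on, or None) into one placement per mesh dimension. The snippet below restates that mapping as a standalone function so it can be read in isolation; the helper name partition_spec_to_placements is hypothetical and not part of the commit.

from torch.distributed.tensor.placement_types import Replicate, Shard

def partition_spec_to_placements(partition_spec, mesh_shape):
  # For each mesh dimension, find the tensor dimension sharded on it;
  # if no tensor dimension maps to it, that mesh dimension is replicated.
  placements = []
  for mesh_dim in range(len(mesh_shape)):
    tensor_dim = next(
        (t for t, m in enumerate(partition_spec) if m == mesh_dim), None)
    placements.append(
        Shard(tensor_dim) if tensor_dim is not None else Replicate())
  return tuple(placements)

# Example: on a (4, 2) mesh with partition_spec (0, None), tensor dim 0 is
# sharded over mesh dim 0 and mesh dim 1 carries replicas.
assert partition_spec_to_placements((0, None), (4, 2)) == (Shard(0), Replicate())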

torch_xla/distributed/spmd/xla_sharding.py

Lines changed: 21 additions & 8 deletions
@@ -543,7 +543,8 @@ def enable_manual_sharding(t: Union[torch.Tensor, XLAShardedTensor],
   mesh = get_global_mesh() if mesh is None else mesh
   t = mark_sharding(unwrap_sharded_tensor(t), mesh, partition_spec)
   t = torch_xla._XLAC._spmd_full_to_shard_shape(unwrap_sharded_tensor(t))
-  return wrap_as_sharded_tensor(t)
+  return wrap_as_sharded_tensor(
+      t, mesh_shape=mesh.mesh_shape, partition_spec=partition_spec)
 
 
 def disable_manual_sharding(t: Union[torch.Tensor, XLAShardedTensor],
@@ -560,7 +561,8 @@ def disable_manual_sharding(t: Union[torch.Tensor, XLAShardedTensor],
   t = torch_xla._XLAC._spmd_shard_to_full_shape(
       unwrap_sharded_tensor(t), mesh.get_op_sharding(partition_spec),
       full_shape, t.dtype)
-  return wrap_as_sharded_tensor(t)
+  return wrap_as_sharded_tensor(
+      t, mesh_shape=mesh.mesh_shape, partition_spec=partition_spec)
 
 
 def annotate_custom_sharding(t: Union[torch.Tensor,
@@ -594,7 +596,8 @@ def annotate_custom_sharding(t: Union[torch.Tensor,
   op_sharding = mesh.get_op_sharding(partition_spec)
   annotate_func = torch_xla._XLAC._xla_annotate_custom_sharding
   annotate_func(unwrap_sharded_tensor(t), op_sharding)
-  return wrap_as_sharded_tensor(t)
+  return wrap_as_sharded_tensor(
+      t, mesh_shape=mesh.mesh_shape, partition_spec=partition_spec)
 
 
 def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh,
@@ -651,7 +654,9 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh,
   op_sharding = mesh.get_op_sharding(partition_spec)
   annotate_func = torch_xla._XLAC._xla_mark_sharding
   annotate_func(unwrap_sharded_tensor(t), op_sharding)
-  return wrap_as_sharded_tensor(t)
+  # Pass mesh and partition spec information for DTensor compatibility
+  return wrap_as_sharded_tensor(
+      t, mesh_shape=mesh.mesh_shape, partition_spec=partition_spec)
 
 
 def mark_sharding_with_gradients(
@@ -755,11 +760,19 @@ def clear_sharding(t: Union[torch.Tensor, XLAShardedTensor]) -> torch.Tensor:
   return t
 
 
-def wrap_as_sharded_tensor(
-    t: Union[torch.Tensor, XLAShardedTensor]) -> XLAShardedTensor:
+def wrap_as_sharded_tensor(t: Union[torch.Tensor, XLAShardedTensor],
+                           mesh_shape=None,
+                           partition_spec=None) -> XLAShardedTensor:
+  # pass along mesh and partition spec information
   if not isinstance(t, XLAShardedTensor):
-    return XLAShardedTensor(t)
-  return t
+    return XLAShardedTensor(
+        t, mesh_shape=mesh_shape, partition_spec=partition_spec)
+  else:
+    if mesh_shape is not None:
+      t.mesh_shape = mesh_shape
+    if partition_spec is not None:
+      t.partition_spec = partition_spec
+    return t
 
 
 def unwrap_sharded_tensor(
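Taken together, these xla_sharding.py changes thread mesh_shape and partition_spec through every wrap_as_sharded_tensor call, so a tensor annotated via mark_sharding carries enough metadata for the new _spec property. A hedged end-to-end sketch follows; the mesh axis names, tensor shapes, and the use of .to('xla') are illustrative assumptions, not part of the commit.

import numpy as np
import torch
import torch_xla
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()
num_devices = xr.global_runtime_device_count()
# Illustrative 2-D mesh: shard over the first axis, replicate over the second.
mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1), ('data', 'model'))

t = torch.randn(16, 8).to('xla')
# mark_sharding now returns an XLAShardedTensor carrying mesh_shape and
# partition_spec, so ._spec can assemble a DTensorSpec from them.
xt = xs.mark_sharding(t, mesh, (0, None))
print(xt._spec.placements)  # expected: (Shard(0), Replicate())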
