
Commit a16b53e

Part 1: Disambiguate custom sharding op for DeviceData IR nodes
1 parent 1f9dd8f commit a16b53e

File tree: 6 files changed, +120, -12 lines

test/spmd/test_xla_sharding.py

Lines changed: 42 additions & 0 deletions

@@ -1686,6 +1686,48 @@ def test_shard_as(self):
     self.assertIn(sharding_spec, x_sharding)
     self.assertEqual(x_sharding, y_sharding)
 
+  @unittest.skipIf(xr.global_runtime_device_count() == 1,
+                   "Multiple devices needed")
+  def test_annotate_custom_sharding(self):
+    xt = torch.randn(2, 4, 64, 64).to(xm.xla_device())
+    sharded_mesh_axis_0 = self.n_devices // 2
+    sharded_mesh_axis_1 = self.n_devices // sharded_mesh_axis_0
+
+    xs.mark_sharding(
+        xt, self._get_mesh((1, 1, sharded_mesh_axis_0, sharded_mesh_axis_1)),
+        (0, 1, 2, 3))
+    original_sharding_spec = torch_xla._XLAC._get_xla_sharding_spec(xt)
+
+    # Attempting to reshard the original tensor should result in a failure
+    with self.assertRaises(RuntimeError):
+      xs.mark_sharding(xt, self._get_mesh((1, 1, 1, self.n_devices)),
+                       (0, 1, 2, 3))
+
+    self.assertEqual(original_sharding_spec,
+                     torch_xla._XLAC._get_xla_sharding_spec(xt))
+
+    # Annotate the existing XLAShardedTensor with a custom sharding IR
+    xs.annotate_custom_sharding(xt, self._get_mesh((1, 1, 1, self.n_devices)),
+                                (0, 1, 2, 3))
+
+    custom_sharding_spec = torch_xla._XLAC._get_xla_sharding_spec(xt)
+
+    self.assertEqual(custom_sharding_spec,
+                     torch_xla._XLAC._get_xla_sharding_spec(xt))
+    self.assertNotEqual(custom_sharding_spec, original_sharding_spec)
+
+    hlo = torch_xla._XLAC._get_xla_tensors_hlo([xt])
+    self.assertIn(
+        f'%p0.1 = f32[2,4,64,64]{{3,2,1,0}} parameter(0), sharding={original_sharding_spec}',
+        hlo)
+    self.assertIn(
+        f'%custom-call.2 = f32[2,4,64,64]{{3,2,1,0}} custom-call(f32[2,4,64,64]{{3,2,1,0}} %p0.1), custom_call_target="Sharding", sharding={custom_sharding_spec}',
+        hlo)
+    xm.mark_step()
+    # Ensure that the resulting sharding spec is preserved
+    self.assertEqual(custom_sharding_spec,
+                     torch_xla._XLAC._get_xla_sharding_spec(xt))
+
 
 if __name__ == '__main__':
   test = unittest.main()
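
For context, the new test exercises the following user-facing flow. The sketch below is illustrative and not part of the commit: it assumes SPMD is enabled, an even device count of at least two, and builds its own Mesh instead of the test helper self._get_mesh.

import numpy as np
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()  # required: XlaAnnotateCustomSharding checks UseVirtualDevice()
n = xr.global_runtime_device_count()
xt = torch.randn(2, 4, 64, 64).to(xm.xla_device())

# Initial sharding, tied to the tensor's DeviceData node.
axis_0 = n // 2
axis_1 = n // axis_0
xs.mark_sharding(xt, xs.Mesh(np.arange(n), (1, 1, axis_0, axis_1)), (0, 1, 2, 3))

# Re-calling mark_sharding on the same tensor raises a RuntimeError, but a
# custom sharding IR node can still be layered on top of the existing sharding.
xs.annotate_custom_sharding(xt, xs.Mesh(np.arange(n), (1, 1, 1, n)), (0, 1, 2, 3))

# The HLO now contains a custom-call with custom_call_target="Sharding".
print(torch_xla._XLAC._get_xla_tensors_hlo([xt]))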

torch_xla/csrc/init_python_bindings.cpp

Lines changed: 5 additions & 0 deletions

@@ -2237,6 +2237,11 @@ void InitXlaModuleBindings(py::module m) {
         [](const at::Tensor& input, xla::OpSharding sharding) {
           ShardingUtil::XlaMarkSharding(input, sharding);
         });
+  m.def("_xla_annotate_custom_sharding",
+        [](const at::Tensor& input, xla::OpSharding sharding) {
+          XLATensorPtr xtensor = bridge::GetXlaTensor(input);
+          ShardingUtil::XlaAnnotateCustomSharding(xtensor, sharding);
+        });
   m.def("_mark_manual_sharding",
         [](const at::Tensor& input, xla::OpSharding sharding) {
           XLA_CHECK(IsNonDeviceDataIR(input))
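
The new binding mirrors _xla_mark_sharding: it takes an at::Tensor plus an xla::OpSharding built on the Python side. A minimal sketch of driving the raw binding directly (normally one goes through xs.annotate_custom_sharding, shown further below; the mesh shape and tensor here are illustrative):

import numpy as np
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()
n = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(n), (n, 1))
t = torch.zeros(2 * n, 16).to(xm.xla_device())

# Mesh.get_op_sharding builds the xla::OpSharding the lambda above expects.
op_sharding = mesh.get_op_sharding((0, 1))
torch_xla._XLAC._xla_annotate_custom_sharding(t, op_sharding)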

torch_xla/csrc/xla_sharding_util.cpp

Lines changed: 30 additions & 11 deletions

@@ -766,23 +766,24 @@ void ShardingUtil::XlaMarkSharding(const at::Tensor& input,
   XLA_CHECK(sharding.type() != xla::OpSharding::UNKNOWN)
       << "Can't explicilty annotate with UNKNOWN sharding type.";
   XLATensorPtr xtensor = bridge::GetXlaTensor(input);
-  XLATensor::ShardingSpecPtr new_sharding_spec =
-      std::make_shared<XLATensor::ShardingSpec>(
-          sharding, MakeShapeWithDeviceLayout(
-                        xtensor->shape(), static_cast<XlaDeviceType>(
-                                              xtensor->GetDevice().type())));
 
-  // For Non DeviceData IR values, we directly attach the sharding spec
-  // to the xtensor.
+  // For Non DeviceData IR values, we directly attach the sharding spec to the
+  // xtensor.
   const DeviceData* device_data_node = nullptr;
   if (xtensor->CurrentIrValue()) {
     device_data_node = DeviceData::Cast(xtensor->CurrentIrValue().node.get());
     if (!device_data_node) {
-      tensor_methods::custom_sharding_(xtensor, new_sharding_spec);
+      XlaAnnotateCustomSharding(xtensor, sharding);
       return;
     }
   }
 
+  XLATensor::ShardingSpecPtr new_sharding_spec =
+      std::make_shared<XLATensor::ShardingSpec>(
+          sharding, MakeShapeWithDeviceLayout(
+                        xtensor->shape(), static_cast<XlaDeviceType>(
+                                              xtensor->GetDevice().type())));
+
   // For data, we need to deal with the data transfers between
   // host and device.
   at::Tensor cpu_tensor;

@@ -816,11 +817,12 @@ void ShardingUtil::XlaMarkSharding(const at::Tensor& input,
     // tensor from the physical device to CPU. In that case, the value
     // must be present on the backend device.
     XLA_CHECK((xtensor->CurrentDataHandle() &&
-               xtensor->CurrentDataHandle()->HasValue()) ||
-              device_data_node != nullptr)
+               xtensor->CurrentDataHandle()->HasValue()))
         << "Cannot shard tensor. Data does not present on any device.";
     std::vector<XLATensorPtr> xla_tensors{xtensor};
-    cpu_tensor = XLAGraphExecutor::Get()->GetTensors(&xla_tensors)[0];
+    auto tensors = XLAGraphExecutor::Get()->GetTensors(&xla_tensors);
+    XLA_CHECK_EQ(tensors.size(), 1);
+    cpu_tensor = tensors[0];
   }
   auto xla_data = CreateTensorsData(
       std::vector<at::Tensor>{cpu_tensor},

@@ -833,6 +835,23 @@ void ShardingUtil::XlaMarkSharding(const at::Tensor& input,
   XLAGraphExecutor::Get()->RegisterTensor(xtensor->data());
 }
 
+void ShardingUtil::XlaAnnotateCustomSharding(const XLATensorPtr& input,
+                                             xla::OpSharding sharding) {
+  TORCH_LAZY_COUNTER("XlaAnnotateCustomSharding", 1);
+
+  XLA_CHECK(UseVirtualDevice())
+      << "Please enable SPMD via `torch_xla.runtime.use_spmd()`";
+  XLA_CHECK(sharding.type() != xla::OpSharding::UNKNOWN)
+      << "Can't explicilty annotate with UNKNOWN sharding type.";
+
+  XLATensor::ShardingSpecPtr sharding_spec =
+      std::make_shared<XLATensor::ShardingSpec>(
+          sharding, MakeShapeWithDeviceLayout(
+                        input->shape(),
+                        static_cast<XlaDeviceType>(input->GetDevice().type())));
+  tensor_methods::custom_sharding_(input, sharding_spec);
+}
+
 void ShardingUtil::SetAutoSharding() {
   // This stays on throughout the program.
   use_auto_sharding = true;
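
With this refactor, mark_sharding on a tensor backed by a non-DeviceData IR node funnels through XlaAnnotateCustomSharding, which only attaches the annotation and bumps a lazy counter. A hedged sketch of observing that from Python (the tensor, mesh, and use of the metrics module are illustrative, not part of the commit):

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.debug.metrics as met
import torch_xla.distributed.spmd as xs

xr.use_spmd()
n = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(n), (n, 1))

x = torch.randn(2 * n, 16).to(xm.xla_device())
y = x + 1                          # y is backed by an IR node, not DeviceData
xs.mark_sharding(y, mesh, (0, 1))  # routed to XlaAnnotateCustomSharding

# The TORCH_LAZY_COUNTER above is visible through the metrics API.
print(met.counter_value("XlaAnnotateCustomSharding"))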

torch_xla/csrc/xla_sharding_util.h

Lines changed: 6 additions & 0 deletions

@@ -123,6 +123,12 @@ class ShardingUtil {
   static void XlaMarkSharding(const at::Tensor& input,
                               xla::OpSharding sharding);
 
+  // Add a custom sharding node IR to an XLATensor. Note that unlike
+  // XlaMarkSharding, this will not explicitly set a sharding spec tied to the
+  // DeviceData node, nor transfer any sharded data to the device. This serves
+  // merely as an XLA custom sharding annotation IR.
+  static void XlaAnnotateCustomSharding(const XLATensorPtr& input,
+                                        xla::OpSharding sharding);
   //////////////////////////// Auto-Sharding ////////////////////////////
 
   // Construct a device mesh for auto-sharding pass. Returns a tuple of mesh

torch_xla/distributed/spmd/__init__.py

Lines changed: 3 additions & 1 deletion

@@ -4,7 +4,8 @@
     mark_sharding, mark_sharding_with_gradients, clear_sharding, get_1d_mesh,
     wrap_if_sharded, xla_patched_nn_linear_forward, set_global_mesh,
     get_global_mesh, _mark_manual_sharding, enable_manual_sharding,
-    disable_manual_sharding, apply_backward_optimization_barrier, shard_as)
+    disable_manual_sharding, apply_backward_optimization_barrier, shard_as,
+    annotate_custom_sharding)
 from .api import xla_distribute_tensor, xla_distribute_module, auto_policy
 
 __all__ = [

@@ -20,6 +21,7 @@
     "mark_sharding",
     "mark_sharding_with_gradients",
     "shard_as",
+    "annotate_custom_sharding",
     "clear_sharding",
     "get_1d_mesh",
     "wrap_if_sharded",

torch_xla/distributed/spmd/xla_sharding.py

Lines changed: 34 additions & 0 deletions

@@ -563,6 +563,40 @@ def disable_manual_sharding(t: Union[torch.Tensor, XLAShardedTensor],
   return wrap_as_sharded_tensor(t)
 
 
+def annotate_custom_sharding(t: Union[torch.Tensor,
+                                      XLAShardedTensor], mesh: Mesh,
+                             partition_spec: PartitionSpec) -> XLAShardedTensor:
+  """
+  Annotates an existing tensor with a custom sharding IR node without modifying its data layout.
+
+  Unlike `mark_sharding`, this function only adds a custom sharding annotation to the XLA IR
+  without explicitly setting a sharding spec tied to the DeviceData node or transferring any
+  sharded data to the device. This allows providing explicit XLA sharding annotations of tensors
+  that have already been sharded with `mark_sharding`.
+
+  Args:
+    t: The input tensor to be annotated with custom sharding.
+    mesh: The device mesh that specifies the logical device topology.
+    partition_spec: The partitioning specification for each dimension of the input tensor.
+
+  Returns:
+    XLAShardedTensor: The input tensor wrapped as a sharded tensor with the custom sharding annotation.
+
+  Example:
+    >>> # First shard the tensor with mark_sharding
+    >>> sharded_tensor = xs.mark_sharding(tensor, mesh1, (0, 1, 2, 3))
+    >>> # Later, annotate with a different sharding for the XLA SPMD partitioner
+    >>> custom_sharded = xs.annotate_custom_sharding(sharded_tensor, mesh2, (0, 1, 2, 3))
+  """
+  assert len(t.shape) == len(partition_spec), \
+    f"Partition spec length ({len(partition_spec)}) should be equal to the input rank ({len(t.shape)})."
+
+  op_sharding = mesh.get_op_sharding(partition_spec)
+  annotate_func = torch_xla._XLAC._xla_annotate_custom_sharding
+  annotate_func(unwrap_sharded_tensor(t), op_sharding)
+  return wrap_as_sharded_tensor(t)
+
+
 def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh,
                   partition_spec: PartitionSpec) -> XLAShardedTensor:
   """
