Introduce annotate_custom_sharding binding #9203

Open: wants to merge 1 commit into master
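As a quick orientation, here is a minimal usage sketch of the new `xs.annotate_custom_sharding` API, mirroring the added test below. It is not part of the diff; the mesh shapes are illustrative and assume an even device count greater than one, and the mesh construction uses the existing `xs.Mesh` API.

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr

xr.use_spmd()  # the binding requires the SPMD virtual device to be enabled

n = xr.global_runtime_device_count()
t = torch.randn(2, 4, 64, 64).to(xm.xla_device())

# Shard the tensor data across the last two mesh axes, as in the test.
mesh1 = xs.Mesh(np.arange(n), (1, 1, n // 2, n // (n // 2)))
xs.mark_sharding(t, mesh1, (0, 1, 2, 3))

# Calling mark_sharding again on an already-sharded tensor raises a RuntimeError.
# annotate_custom_sharding instead only inserts a custom-call "Sharding" node
# into the IR and leaves the sharded device data untouched.
mesh2 = xs.Mesh(np.arange(n), (1, 1, 1, n))
xs.annotate_custom_sharding(t, mesh2, (0, 1, 2, 3))

The added test asserts exactly this behavior: the second mark_sharding fails, while the custom-call annotation appears in the HLO with the new sharding spec.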
42 changes: 42 additions & 0 deletions test/spmd/test_xla_sharding.py
@@ -1686,6 +1686,48 @@ def test_shard_as(self):
self.assertIn(sharding_spec, x_sharding)
self.assertEqual(x_sharding, y_sharding)

@unittest.skipIf(xr.global_runtime_device_count() == 1,
"Multiple devices needed")
def test_annotate_custom_sharding(self):
xt = torch.randn(2, 4, 64, 64).to(xm.xla_device())
sharded_mesh_axis_0 = self.n_devices // 2
sharded_mesh_axis_1 = self.n_devices // sharded_mesh_axis_0

xs.mark_sharding(
xt, self._get_mesh((1, 1, sharded_mesh_axis_0, sharded_mesh_axis_1)),
(0, 1, 2, 3))
original_sharding_spec = torch_xla._XLAC._get_xla_sharding_spec(xt)

# Attempting to reshard the original tensor should result in a failure
with self.assertRaises(RuntimeError):
xs.mark_sharding(xt, self._get_mesh((1, 1, 1, self.n_devices)),
(0, 1, 2, 3))

self.assertEqual(original_sharding_spec,
torch_xla._XLAC._get_xla_sharding_spec(xt))

# Annotate the existing XLAShardedTensor with a custom sharding IR
xs.annotate_custom_sharding(xt, self._get_mesh((1, 1, 1, self.n_devices)),
(0, 1, 2, 3))

custom_sharding_spec = torch_xla._XLAC._get_xla_sharding_spec(xt)

self.assertEqual(custom_sharding_spec,
torch_xla._XLAC._get_xla_sharding_spec(xt))
self.assertNotEqual(custom_sharding_spec, original_sharding_spec)

hlo = torch_xla._XLAC._get_xla_tensors_hlo([xt])
self.assertIn(
f'%p0.1 = f32[2,4,64,64]{{3,2,1,0}} parameter(0), sharding={original_sharding_spec}',
hlo)
self.assertIn(
f'%custom-call.2 = f32[2,4,64,64]{{3,2,1,0}} custom-call(f32[2,4,64,64]{{3,2,1,0}} %p0.1), custom_call_target="Sharding", sharding={custom_sharding_spec}',
hlo)
xm.mark_step()
# Ensure that the resulting sharding spec is preserved
self.assertEqual(custom_sharding_spec,
torch_xla._XLAC._get_xla_sharding_spec(xt))


if __name__ == '__main__':
test = unittest.main()
5 changes: 5 additions & 0 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -2237,6 +2237,11 @@ void InitXlaModuleBindings(py::module m) {
[](const at::Tensor& input, xla::OpSharding sharding) {
ShardingUtil::XlaMarkSharding(input, sharding);
});
m.def("_xla_annotate_custom_sharding",
[](const at::Tensor& input, xla::OpSharding sharding) {
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
ShardingUtil::XlaAnnotateCustomSharding(xtensor, sharding);
});
m.def("_mark_manual_sharding",
[](const at::Tensor& input, xla::OpSharding sharding) {
XLA_CHECK(IsNonDeviceDataIR(input))
38 changes: 29 additions & 9 deletions torch_xla/csrc/xla_sharding_util.cpp
@@ -766,23 +766,24 @@ void ShardingUtil::XlaMarkSharding(const at::Tensor& input,
XLA_CHECK(sharding.type() != xla::OpSharding::UNKNOWN)
<< "Can't explicilty annotate with UNKNOWN sharding type.";
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
XLATensor::ShardingSpecPtr new_sharding_spec =
std::make_shared<XLATensor::ShardingSpec>(
sharding, MakeShapeWithDeviceLayout(
xtensor->shape(), static_cast<XlaDeviceType>(
xtensor->GetDevice().type())));

// For Non DeviceData IR values, we directly attach the sharding spec
// to the xtensor.
// For Non DeviceData IR values, we directly attach the sharding spec to the
// xtensor.
const DeviceData* device_data_node = nullptr;
if (xtensor->CurrentIrValue()) {
device_data_node = DeviceData::Cast(xtensor->CurrentIrValue().node.get());
if (!device_data_node) {
tensor_methods::custom_sharding_(xtensor, new_sharding_spec);
XlaAnnotateCustomSharding(xtensor, sharding);
return;
}
}

XLATensor::ShardingSpecPtr new_sharding_spec =
std::make_shared<XLATensor::ShardingSpec>(
sharding, MakeShapeWithDeviceLayout(
xtensor->shape(), static_cast<XlaDeviceType>(
xtensor->GetDevice().type())));

// For data, we need to deal with the data transfers between
// host and device.
at::Tensor cpu_tensor;
@@ -820,7 +821,9 @@ void ShardingUtil::XlaMarkSharding(const at::Tensor& input,
device_data_node != nullptr)
<< "Cannot shard tensor. Data does not present on any device.";
std::vector<XLATensorPtr> xla_tensors{xtensor};
cpu_tensor = XLAGraphExecutor::Get()->GetTensors(&xla_tensors)[0];
auto tensors = XLAGraphExecutor::Get()->GetTensors(&xla_tensors);
XLA_CHECK_EQ(tensors.size(), 1);
cpu_tensor = tensors[0];
}
auto xla_data = CreateTensorsData(
std::vector<at::Tensor>{cpu_tensor},
@@ -833,6 +836,23 @@ void ShardingUtil::XlaMarkSharding(const at::Tensor& input,
XLAGraphExecutor::Get()->RegisterTensor(xtensor->data());
}

void ShardingUtil::XlaAnnotateCustomSharding(const XLATensorPtr& input,
xla::OpSharding sharding) {
TORCH_LAZY_COUNTER("XlaAnnotateCustomSharding", 1);

XLA_CHECK(UseVirtualDevice())
<< "Please enable SPMD via `torch_xla.runtime.use_spmd()`";
XLA_CHECK(sharding.type() != xla::OpSharding::UNKNOWN)
<< "Can't explicilty annotate with UNKNOWN sharding type.";

XLATensor::ShardingSpecPtr sharding_spec =
std::make_shared<XLATensor::ShardingSpec>(
sharding, MakeShapeWithDeviceLayout(
input->shape(),
static_cast<XlaDeviceType>(input->GetDevice().type())));
tensor_methods::custom_sharding_(input, sharding_spec);
}

void ShardingUtil::SetAutoSharding() {
// This stays on throughout the program.
use_auto_sharding = true;
6 changes: 6 additions & 0 deletions torch_xla/csrc/xla_sharding_util.h
@@ -123,6 +123,12 @@ class ShardingUtil {
static void XlaMarkSharding(const at::Tensor& input,
xla::OpSharding sharding);

// Add a custom sharding node IR to an XLATensor. Note that unlike
// XlaMarkSharding, this will not explicitly set a sharding spec tied to the
// DeviceData node, nor transfer any sharded data to the device. This serves
// merely as an XLA custom sharding annotation IR.
static void XlaAnnotateCustomSharding(const XLATensorPtr& input,
xla::OpSharding sharding);
//////////////////////////// Auto-Sharding ////////////////////////////

// Construct a device mesh for auto-sharding pass. Returns a tuple of mesh
4 changes: 3 additions & 1 deletion torch_xla/distributed/spmd/__init__.py
@@ -4,7 +4,8 @@
mark_sharding, mark_sharding_with_gradients, clear_sharding, get_1d_mesh,
wrap_if_sharded, xla_patched_nn_linear_forward, set_global_mesh,
get_global_mesh, _mark_manual_sharding, enable_manual_sharding,
disable_manual_sharding, apply_backward_optimization_barrier, shard_as)
disable_manual_sharding, apply_backward_optimization_barrier, shard_as,
annotate_custom_sharding)
from .api import xla_distribute_tensor, xla_distribute_module, auto_policy

__all__ = [
Expand All @@ -20,6 +21,7 @@
"mark_sharding",
"mark_sharding_with_gradients",
"shard_as",
"annotate_custom_sharding",
"clear_sharding",
"get_1d_mesh",
"wrap_if_sharded",
34 changes: 34 additions & 0 deletions torch_xla/distributed/spmd/xla_sharding.py
@@ -563,6 +563,40 @@ def disable_manual_sharding(t: Union[torch.Tensor, XLAShardedTensor],
return wrap_as_sharded_tensor(t)


def annotate_custom_sharding(t: Union[torch.Tensor,
XLAShardedTensor], mesh: Mesh,
partition_spec: PartitionSpec) -> XLAShardedTensor:
"""
Annotates an existing tensor with a custom sharding IR node without modifying its data layout.

Unlike `mark_sharding`, this function only adds a custom sharding annotation to the XLA IR
without explicitly setting a sharding spec tied to the DeviceData node or transferring any
sharded data to the device. This allows providing explicit XLA sharding annotations of tensors
that have already been sharded with `mark_sharding`.

Args:
t: The input tensor to be annotated with custom sharding.
mesh: The device mesh that specifies the logical device topology.
partition_spec: The partitioning specification for each dimension of the input tensor.

Returns:
XLAShardedTensor: The input tensor wrapped as a sharded tensor with the custom sharding annotation.

Example:
>>> # First shard the tensor with mark_sharding
>>> sharded_tensor = xs.mark_sharding(tensor, mesh1, (0, 1, 2, 3))
>>> # Later, annotate with a different sharding for the XLA SPMD partitioner
>>> custom_sharded = xs.annotate_custom_sharding(sharded_tensor, mesh2, (0, 1, 2, 3))
"""
assert len(t.shape) == len(partition_spec), \
f"Partition spec length ({len(partition_spec)}) should be equal to the input rank ({len(t.shape)})."

op_sharding = mesh.get_op_sharding(partition_spec)
annotate_func = torch_xla._XLAC._xla_annotate_custom_sharding
annotate_func(unwrap_sharded_tensor(t), op_sharding)
return wrap_as_sharded_tensor(t)


def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh,
partition_spec: PartitionSpec) -> XLAShardedTensor:
"""