diff --git a/test/cpp/test_xla_sharding.cpp b/test/cpp/test_xla_sharding.cpp
index 16523cb81ee4..e8412affd35c 100644
--- a/test/cpp/test_xla_sharding.cpp
+++ b/test/cpp/test_xla_sharding.cpp
@@ -41,6 +41,34 @@ class XLAShardingTest : public AtenXlaTensorTestBase {
   }
 };
 
+TEST_F(XLAShardingTest, NormalizeTileAssignment) {
+  // Test with an empty tile assignment
+  std::vector<int64_t> empty_tile_assignment = {};
+  auto normalized =
+      ShardingUtil::NormalizeTileAssignment(empty_tile_assignment);
+  EXPECT_TRUE(normalized.empty());
+
+  // Test with positive values
+  std::vector<int64_t> positive_tile_assignment = {3, 1, 4, 2};
+  normalized = ShardingUtil::NormalizeTileAssignment(positive_tile_assignment);
+  EXPECT_EQ(normalized, std::vector<int64_t>({2, 0, 3, 1}));
+
+  // Test with all identical values
+  std::vector<int64_t> identical_tile_assignment = {5, 5, 5, 5};
+  normalized = ShardingUtil::NormalizeTileAssignment(identical_tile_assignment);
+  EXPECT_EQ(normalized, std::vector<int64_t>({0, 0, 0, 0}));
+
+  // Test with negative values
+  std::vector<int64_t> negative_tile_assignment = {-3, -1, -4, -2};
+  EXPECT_THROW(ShardingUtil::NormalizeTileAssignment(negative_tile_assignment),
+               std::runtime_error);
+
+  // Test with mixed positive and negative values
+  std::vector<int64_t> mixed_tile_assignment = {3, -1, 4, 2};
+  EXPECT_THROW(ShardingUtil::NormalizeTileAssignment(mixed_tile_assignment),
+               std::runtime_error);
+}
+
 TEST_F(XLAShardingTest, GetShardShape) {
   auto tensor = at::ones({8, 7}, at::TensorOptions(at::kFloat));
   xla::Shape tensor_shape =
diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp
index 3ddec53d7004..36de2864cacc 100644
--- a/torch_xla/csrc/xla_sharding_util.cpp
+++ b/torch_xla/csrc/xla_sharding_util.cpp
@@ -218,6 +218,36 @@ bool ShardingUtil::EqualOpShardings(const xla::OpSharding& a,
   return xla::protobuf_util::HaveSameSerialization(a, b);
 }
 
+// function to normalize tile_assignment
+std::vector<int64_t> ShardingUtil::NormalizeTileAssignment(
+    const std::vector<int64_t>& tile_assignment) {
+  // Check if the tile_assignment is empty
+  if (tile_assignment.empty()) {
+    TF_LOG(WARNING) << "Invalid argument: tile_assignment is empty";
+    return tile_assignment;
+  }
+
+  // Find the minimum value in the tile_assignment
+  int64_t min_value =
+      *std::min_element(tile_assignment.begin(), tile_assignment.end());
+
+  // check if min_value of tile_assignment is positive
+  XLA_CHECK(min_value >= 0)
+      << "min_value of tile_assignment cannot be negative";
+
+  // Create a vector to store the normalized tile_assignment
+  std::vector<int64_t> normalized_tile_assignment;
+  normalized_tile_assignment.reserve(
+      tile_assignment.size());  // Reserve space to avoid reallocations
+
+  // Normalize each device ID by subtracting the minimum value
+  for (const auto& device : tile_assignment) {
+    normalized_tile_assignment.push_back(device - min_value);
+  }
+
+  return normalized_tile_assignment;
+}
+
 xla::OpSharding ShardingUtil::CreateOpSharding(
     const py::list& tile_assignment, const py::list& group_assignment,
     const py::list& replication_groups, ShardingType sharding_type) {
diff --git a/torch_xla/csrc/xla_sharding_util.h b/torch_xla/csrc/xla_sharding_util.h
index 8b8b98653b2f..a6a624e5720b 100644
--- a/torch_xla/csrc/xla_sharding_util.h
+++ b/torch_xla/csrc/xla_sharding_util.h
@@ -45,6 +45,10 @@ class ShardingUtil {
   static bool EqualOpShardings(const xla::OpSharding& a,
                               const xla::OpSharding& b);
 
+  // Returns tile_assignment after normalizing
+  static std::vector<int64_t> NormalizeTileAssignment(
+      const std::vector<int64_t>& tile_assignment);
+
   // Creates an xla::OpSharding. `tile_assignmnent` is required for TILED
   // `sharding_type` and `replication_groups` for `PARTIAL`.
   static xla::OpSharding CreateOpSharding(const py::list& tile_assignment,