Skip to content

feat: add normalize_tile_assignment function needed for local SPMD #9436

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions test/cpp/test_xla_sharding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,34 @@ class XLAShardingTest : public AtenXlaTensorTestBase {
}
};

TEST_F(XLAShardingTest, NormalizeTileAssignment) {
  // An empty tile assignment normalizes to an empty result.
  std::vector<int64_t> no_devices = {};
  auto result = ShardingUtil::NormalizeTileAssignment(no_devices);
  EXPECT_TRUE(result.empty());

  // Positive device IDs are shifted so the smallest becomes zero.
  std::vector<int64_t> shuffled_devices = {3, 1, 4, 2};
  result = ShardingUtil::NormalizeTileAssignment(shuffled_devices);
  EXPECT_EQ(result, std::vector<int64_t>({2, 0, 3, 1}));

  // When every device ID is identical, all entries collapse to zero.
  std::vector<int64_t> uniform_devices = {5, 5, 5, 5};
  result = ShardingUtil::NormalizeTileAssignment(uniform_devices);
  EXPECT_EQ(result, std::vector<int64_t>({0, 0, 0, 0}));

  // Negative device IDs are rejected outright.
  std::vector<int64_t> all_negative = {-3, -1, -4, -2};
  EXPECT_THROW(ShardingUtil::NormalizeTileAssignment(all_negative),
               std::runtime_error);

  // A mix of positive and negative IDs is rejected as well.
  std::vector<int64_t> mixed_signs = {3, -1, 4, 2};
  EXPECT_THROW(ShardingUtil::NormalizeTileAssignment(mixed_signs),
               std::runtime_error);
}

TEST_F(XLAShardingTest, GetShardShape) {
auto tensor = at::ones({8, 7}, at::TensorOptions(at::kFloat));
xla::Shape tensor_shape =
Expand Down
30 changes: 30 additions & 0 deletions torch_xla/csrc/xla_sharding_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,36 @@ bool ShardingUtil::EqualOpShardings(const xla::OpSharding& a,
return xla::protobuf_util::HaveSameSerialization(a, b);
}

// Normalizes a tile assignment so device IDs start at zero: every entry is
// shifted down by the smallest device ID present. An empty assignment is
// returned unchanged (with a warning); negative IDs are a hard error.
std::vector<int64_t> ShardingUtil::NormalizeTileAssignment(
    const std::vector<int64_t>& tile_assignment) {
  // Nothing to normalize for an empty assignment; warn and hand it back.
  if (tile_assignment.empty()) {
    TF_LOG(WARNING) << "Invalid argument: tile_assignment is empty";
    return tile_assignment;
  }

  // The smallest device ID is the offset subtracted from every entry.
  const int64_t min_device =
      *std::min_element(tile_assignment.begin(), tile_assignment.end());

  // Device IDs must be non-negative; fail loudly otherwise.
  XLA_CHECK(min_device >= 0)
      << "min_value of tile_assignment cannot be negative";

  // Shift every device ID down by the minimum in a single pass.
  std::vector<int64_t> normalized_tile_assignment(tile_assignment.size());
  std::transform(
      tile_assignment.begin(), tile_assignment.end(),
      normalized_tile_assignment.begin(),
      [min_device](int64_t device) { return device - min_device; });

  return normalized_tile_assignment;
}

xla::OpSharding ShardingUtil::CreateOpSharding(
const py::list& tile_assignment, const py::list& group_assignment,
const py::list& replication_groups, ShardingType sharding_type) {
Expand Down
4 changes: 4 additions & 0 deletions torch_xla/csrc/xla_sharding_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ class ShardingUtil {
static bool EqualOpShardings(const xla::OpSharding& a,
const xla::OpSharding& b);

// Returns a copy of `tile_assignment` with every device ID shifted down by
// the minimum ID, so the smallest entry becomes zero. An empty input is
// returned as-is; any negative device ID triggers a runtime error.
static std::vector<int64_t> NormalizeTileAssignment(
    const std::vector<int64_t>& tile_assignment);

// Creates an xla::OpSharding. `tile_assignmnent` is required for TILED
// `sharding_type` and `replication_groups` for `PARTIAL`.
static xla::OpSharding CreateOpSharding(const py::list& tile_assignment,
Expand Down
Loading