Skip to content

Error Handling: propagate status for ReleaseGilAndTransferData and XlaDataToTensors. #9431

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 8 commits into
base: ysiraichi/status-for-oom-errors
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .github/scripts/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ function run_torch_xla_cpp_tests() {
#"test_xla_backend_intf"
"test_xla_sharding"
"test_runtime"
"test_status"
"test_status_dont_show_cpp_error_context"
"test_status_show_cpp_error_context")
for name in "${test_names[@]}"; do
Expand Down
1 change: 0 additions & 1 deletion BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ test_suite(
"//test/cpp:test_tensor",
"//test/cpp:test_xla_sharding",
"//test/cpp:test_runtime",
"//test/cpp:test_status",
"//test/cpp:test_status_dont_show_cpp_error_context",
"//test/cpp:test_status_show_cpp_error_context",
"//torch_xla/csrc/runtime:pjrt_computation_client_test",
Expand Down
10 changes: 1 addition & 9 deletions test/cpp/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ ptxla_cc_library(
"//torch_xla/csrc/runtime:runtime",
"//torch_xla/csrc/runtime:debug_macros",
"//torch_xla/csrc/runtime:sys_util",
"//torch_xla/csrc:status",
"//torch_xla/csrc:tensor",
"@com_google_absl//absl/types:span",
"@com_google_googletest//:gtest",
Expand Down Expand Up @@ -159,15 +160,6 @@ ptxla_cc_test(
],
)

ptxla_cc_test(
name = "test_status",
srcs = ["test_status.cpp"],
deps = [
"//torch_xla/csrc:status",
"@com_google_googletest//:gtest_main",
],
)

ptxla_cc_test(
name = "test_status_dont_show_cpp_error_context",
srcs = ["test_status_dont_show_cpp_error_context.cpp"],
Expand Down
6 changes: 3 additions & 3 deletions test/cpp/cpp_test_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/runtime/runtime.h"
#include "torch_xla/csrc/runtime/sys_util.h"
#include "torch_xla/csrc/status.h"
#include "torch_xla/csrc/tensor_impl.h"
#include "torch_xla/csrc/tensor_util.h"
#include "torch_xla/csrc/torch_util.h"
Expand Down Expand Up @@ -301,9 +302,8 @@ std::vector<torch_xla::runtime::ComputationClient::DataPtr> Execute(
std::vector<at::Tensor> Fetch(
absl::Span<const torch_xla::runtime::ComputationClient::DataPtr>
device_data) {
std::vector<xla::Literal> literals =
torch_xla::runtime::GetComputationClientOrDie()->TransferFromDevice(
device_data);
std::vector<xla::Literal> literals = GetValueOrThrow(
runtime::GetComputationClientOrDie()->TransferFromDevice(device_data));
std::vector<at::Tensor> tensors;
for (auto& literal : literals) {
tensors.push_back(MakeTensorFromXlaLiteral(
Expand Down
1 change: 0 additions & 1 deletion test/cpp/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,6 @@ if [[ "$RUN_CPP_TESTS" == "cpp_tests" ]]; then
#"test_xla_backend_intf"
"test_xla_sharding"
"test_runtime"
"test_status"
"test_status_dont_show_cpp_error_context"
"test_status_show_cpp_error_context")
fi
Expand Down
6 changes: 3 additions & 3 deletions test/cpp/test_replication.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "torch_xla/csrc/helpers.h"
#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/runtime/runtime.h"
#include "torch_xla/csrc/status.h"
#include "torch_xla/csrc/tensor_util.h"
#include "torch_xla/csrc/thread_pool.h"
#include "torch_xla/csrc/torch_util.h"
Expand Down Expand Up @@ -78,9 +79,8 @@ void TestSingleReplication(
counter.Wait();

for (size_t i = 0; i < results.size(); ++i) {
std::vector<xla::Literal> literals =
torch_xla::runtime::GetComputationClientOrDie()->TransferFromDevice(
results[i]);
std::vector<xla::Literal> literals = GetValueOrThrow(
runtime::GetComputationClientOrDie()->TransferFromDevice(results[i]));
ASSERT_EQ(literals.size(), 1);

// The result must be the original tensor value, multiplied by the number of
Expand Down
60 changes: 0 additions & 60 deletions test/cpp/test_status.cpp

This file was deleted.

166 changes: 143 additions & 23 deletions test/cpp/test_status_dont_show_cpp_error_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,62 +5,181 @@
#include "torch_xla/csrc/runtime/env_vars.h"
#include "torch_xla/csrc/status.h"

// Reminder
// ========
//
// This file is a companion to test_status_show_cpp_error_context.cpp.
// This file specifically tests behavior when XLA_SHOW_CPP_ERROR_CONTEXT is
// set to "false".
//
// If you add or delete a test in this file, please make the corresponding
// change in test_status_show_cpp_error_context.cpp as well, adapting for
// XLA_SHOW_CPP_ERROR_CONTEXT being "true" in that file.

namespace torch_xla {
namespace {

using absl::Status;
using absl::StatusCode;
using absl::StatusOr;
using absl::StrCat;

// Replacement message handed to the helpers and macros that accept a new
// error message.
constexpr char new_message[] = "New test error message";
// Message carried by the error statuses constructed in these tests.
constexpr char message[] = "Test error message";
// Placeholder file name passed as the source-location file argument.
constexpr char test_file[] = "test_file.cpp";
// Placeholder line number passed as the source-location line argument.
constexpr int32_t line = 42;

// MaybeThrow must not raise when handed an OK status.
TEST(StatusWithoutErrorContextTest, MaybeThrowWithOkStatus) {
  Status success = absl::OkStatus();
  EXPECT_NO_THROW(MaybeThrow(success));
}

// MaybeThrow must raise std::runtime_error when handed a non-OK status.
TEST(StatusWithoutErrorContextTest, MaybeThrowWithErrorStatus) {
  Status failure = absl::InvalidArgumentError(message);
  EXPECT_THROW(MaybeThrow(failure), std::runtime_error);
}

// GetValueOrThrow must hand back the wrapped value when the StatusOr holds
// one.
TEST(StatusWithoutErrorContextTest, GetValueOrThrowWithOkStatusOr) {
  int expected = 42;
  StatusOr<int> holder = expected;
  int unwrapped = GetValueOrThrow(std::move(holder));
  EXPECT_EQ(unwrapped, expected);
}

// GetValueOrThrow must raise std::runtime_error when the StatusOr holds an
// error instead of a value.
TEST(StatusWithoutErrorContextTest, GetValueOrThrowWithErrorStatusOr) {
  StatusOr<int> failed = absl::InvalidArgumentError(message);
  EXPECT_THROW(GetValueOrThrow(std::move(failed)), std::runtime_error);
}

TEST(StatusWithoutErrorContextTest, MaybeWithLocationRetunsSameStatus) {
absl::Status error_status = absl::InvalidArgumentError("Test error message");
absl::Status result = MaybeWithLocation(error_status, "test_file.cpp", 42);
Status error_status = absl::InvalidArgumentError(message);
Status result = MaybeWithLocation(error_status, test_file, line);
EXPECT_EQ(result, error_status);
}

TEST(StatusWithoutErrorContextTest, MaybeWithNewMessageEmptyNewMessage) {
absl::Status error_status = absl::InvalidArgumentError("Original error");
absl::Status result = MaybeWithNewMessage(error_status, "test_file.cpp", 42);
Status error_status = absl::InvalidArgumentError(message);
Status result = MaybeWithNewMessage(error_status, test_file, line);
EXPECT_EQ(result, error_status);
}

TEST(StatusWithoutErrorContextTest, MaybeWithNewMessageNonEmptyNewMessage) {
constexpr char new_err_string[] = "New error message";
absl::Status error_status = absl::InvalidArgumentError("Original error");
absl::Status result =
MaybeWithNewMessage(error_status, "test_file.cpp", 42, new_err_string);
Status error_status = absl::InvalidArgumentError(message);
Status result =
MaybeWithNewMessage(error_status, test_file, line, new_message);

ASSERT_FALSE(result.ok());
ASSERT_NE(result, error_status);
ASSERT_FALSE(result.ok());
EXPECT_EQ(result.code(), error_status.code());
EXPECT_EQ(result.message(), new_err_string);
EXPECT_EQ(result.message(), std::string_view(new_message));
}

TEST(StatusWithoutErrorContextTest, MacroReturnIfErrorWithError) {
constexpr char err_string[] = "Test error";
TEST(StatusWithoutErrorContextTest, MacroReturnIfError) {
int value = 42;

auto test_function = [=]() -> StatusOr<int> {
Status ok_status = absl::OkStatus();
XLA_RETURN_IF_ERROR(ok_status);
return value;
};

auto test_function = [=]() -> absl::Status {
absl::Status error_status = absl::InvalidArgumentError(err_string);
StatusOr<int> result = test_function();
ASSERT_TRUE(result.ok());
EXPECT_EQ(result.value(), value);
}

TEST(StatusWithoutErrorContextTest, MacroReturnIfErrorWithError) {
auto test_function = [=]() -> Status {
Status error_status = absl::InvalidArgumentError(message);
XLA_RETURN_IF_ERROR(error_status);
return absl::OkStatus();
};

absl::Status result = test_function();
Status result = test_function();
ASSERT_FALSE(result.ok());
EXPECT_EQ(result.code(), absl::StatusCode::kInvalidArgument);
EXPECT_EQ(result.message(), err_string);
EXPECT_EQ(result.code(), StatusCode::kInvalidArgument);
EXPECT_EQ(result.message(), std::string_view(message));
}

// Checks that an error created with XLA_ERROR_WITH_LOCATION and then
// propagated through two nested XLA_RETURN_IF_ERROR call sites reaches the
// outermost caller with its status code and message unchanged.
//
// The suite name is StatusWithoutErrorContextTest, like every other test in
// this file: this file covers XLA_SHOW_CPP_ERROR_CONTEXT="false", whereas the
// "With" suite name belongs to the companion
// test_status_show_cpp_error_context.cpp.
TEST(StatusWithoutErrorContextTest, MacroReturnIfErrorWithNestedError) {
  // Innermost level: the error originates here.
  auto inner_test_function = []() -> Status {
    return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(message));
  };

  // Middle level: forwards the inner error to its caller.
  auto test_function = [&]() -> Status {
    XLA_RETURN_IF_ERROR(inner_test_function());
    return absl::OkStatus();
  };

  // Outermost level: forwards the error once more.
  auto outer_test_function = [&]() -> Status {
    XLA_RETURN_IF_ERROR(test_function());
    return absl::OkStatus();
  };

  Status result = outer_test_function();
  ASSERT_FALSE(result.ok());
  EXPECT_EQ(result.code(), StatusCode::kInvalidArgument);
  EXPECT_EQ(result.message(), std::string_view(message));
}

// XLA_RETURN_IF_ERROR with a replacement message must return early with the
// original status code and the replacement message.
TEST(StatusWithoutErrorContextTest, MacroReturnIfErrorWithErrorWithNewMessage) {
  auto failing_fn = [=]() -> Status {
    Status failure = absl::InvalidArgumentError(message);
    XLA_RETURN_IF_ERROR(failure, new_message);
    return absl::OkStatus();
  };

  Status outcome = failing_fn();
  ASSERT_FALSE(outcome.ok());
  EXPECT_EQ(outcome.code(), StatusCode::kInvalidArgument);
  EXPECT_EQ(outcome.message(), std::string_view(new_message));
}

// XLA_ASSIGN_OR_RETURN must bind the wrapped value and let execution continue
// when the StatusOr holds a value.
TEST(StatusWithoutErrorContextTest, MacroAssignOrReturn) {
  int seed = 42;
  int doubled_seed = seed * 2;

  auto doubling_fn = [=]() -> StatusOr<int> {
    StatusOr<int> wrapped = seed;
    XLA_ASSIGN_OR_RETURN(int unwrapped, wrapped);
    return unwrapped * 2;
  };

  StatusOr<int> outcome = doubling_fn();
  ASSERT_TRUE(outcome.ok());
  EXPECT_EQ(outcome.value(), doubled_seed);
}

TEST(StatusWithoutErrorContextTest, MacroAssignOrReturnWithError) {
auto test_function = []() -> absl::StatusOr<int> {
absl::StatusOr<int> status_or = absl::InvalidArgumentError("Test error");
auto test_function = []() -> StatusOr<int> {
StatusOr<int> status_or = absl::InvalidArgumentError(message);
XLA_ASSIGN_OR_RETURN(int value, status_or);
return value * 2;
};

absl::StatusOr<int> result = test_function();
StatusOr<int> result = test_function();
ASSERT_FALSE(result.ok());
EXPECT_EQ(result.status().code(), StatusCode::kInvalidArgument);
EXPECT_EQ(result.status().message(), std::string_view(message));
}

TEST(StatusWithoutErrorContextTest,
MacroAssignOrReturnWithErrorWithNewMessage) {
auto test_function = []() -> StatusOr<int> {
StatusOr<int> status_or = absl::InvalidArgumentError(message);
XLA_ASSIGN_OR_RETURN(int value, status_or, new_message);
return value * 2;
};

StatusOr<int> result = test_function();
ASSERT_FALSE(result.ok());
EXPECT_EQ(result.status().code(), absl::StatusCode::kInvalidArgument);
EXPECT_EQ(result.status().code(), StatusCode::kInvalidArgument);
EXPECT_EQ(result.status().message(), std::string_view(new_message));
}

TEST(StatusWithoutErrorContextTest, MacroErrorWithLocation) {
absl::Status error_status = absl::InvalidArgumentError("Test error");
absl::Status result = XLA_ERROR_WITH_LOCATION(error_status);
Status error_status = absl::InvalidArgumentError(message);
Status result = XLA_ERROR_WITH_LOCATION(error_status);
EXPECT_EQ(result, error_status);
}

Expand All @@ -69,6 +188,7 @@ void SetUp() {
/* replace= */ 1);
}

} // namespace
} // namespace torch_xla

int main(int argc, char **argv) {
Expand Down
Loading
Loading