Skip to content

Commit 1818f47

Browse files
committed
fix
1 parent 2d4b3b9 commit 1818f47

File tree

5 files changed

+50
-16
lines changed

5 files changed

+50
-16
lines changed

.bazelrc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ build --define=grpc_no_ares=true
2222

2323
build -c opt
2424

25+
build --copt=-fexceptions
26+
2527
build --config=short_logs
2628

2729
# Force GCC because clang/bazel has issues.

bazel/rules_def.bzl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def ptxla_cc_test(
2626
linkstatic = True,
2727
copts = copts + [
2828
"-isystemexternal/torch", # Required for system includes.
29-
"-fexceptions", # Required for testing crashes.
29+
# "-fexceptions", # Required for testing crashes.
3030
],
3131
deps = deps + [
3232
"@pybind11//:pybind11_embed", # libpython

torch_xla/csrc/runtime/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -502,6 +502,7 @@ ptxla_cc_test(
502502
"@xla//xla/tests:literal_test_util",
503503
"@xla//xla/tools:hlo_module_loader",
504504
],
505+
timeout = "short",
505506
)
506507

507508
# ptxla_cc_test(

torch_xla/csrc/runtime/pjrt_computation_client.cc

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -633,13 +633,25 @@ std::vector<ComputationClient::ComputationPtr> PjRtComputationClient::Compile(
633633
mlir::ModuleOp mlir_module =
634634
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
635635
ConvertHloToStableHlo(instance.computation.mutable_proto(), &mlir_module);
636-
maybe_executable = client_->CompileAndLoad(mlir_module, compile_options);
636+
try {
637+
maybe_executable =
638+
client_->CompileAndLoad(mlir_module, compile_options);
639+
} catch (const absl::BadStatusOrAccess& e) {
640+
LOG(ERROR) << e.what();
641+
throw std::invalid_argument(e.what());
642+
}
637643
StableHloCompileCounter()->AddValue(1);
638644
} else {
639-
maybe_executable =
640-
client_->CompileAndLoad(instance.computation, compile_options);
645+
try {
646+
maybe_executable =
647+
client_->CompileAndLoad(instance.computation, compile_options);
648+
} catch (const absl::BadStatusOrAccess& e) {
649+
LOG(ERROR) << e.what();
650+
throw std::invalid_argument(e.what());
651+
}
641652
}
642653
if (!maybe_executable.ok()) {
654+
LOG(ERROR) << maybe_executable.status().message();
643655
// This will automatically raise a Python ValueError exception.
644656
// See https://pybind11.readthedocs.io/en/stable/advanced/exceptions.html.
645657
throw std::invalid_argument(

torch_xla/csrc/runtime/pjrt_computation_client_test.cc

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -41,19 +41,38 @@ TEST(PjRtComputationClient, ThrowsExpectedExceptionWhenCompileFails) {
4141
const auto client = std::make_unique<PjRtComputationClient>();
4242
const std::string device = client->GetDefaultDevice();
4343

44-
// Compose a computation with an enormous shape.
45-
const auto shape =
46-
xla::ShapeUtil::MakeShape(xla::F32, {8000000000, 1000000000});
47-
std::vector<ComputationClient::CompileInstance> instances;
48-
instances.push_back(ComputationClient::CompileInstance(
49-
std::move(MakeComputation().value()), device,
50-
client->GetCompilationDevices(device, client->GetLocalDevices()),
51-
&shape));
44+
xla::Shape shape;
45+
try {
46+
// Compose a computation with an enormous shape.
47+
shape = xla::ShapeUtil::MakeShape(xla::F32, {8000000000, 5, 1000000000});
48+
} catch (const std::exception& e) {
49+
LOG(ERROR) << "ZW: " << e.what();
50+
}
51+
52+
shape = xla::Shape(xla::F32, {8000000000, 5, 1000000000},
53+
/*dynamic_dimensions=*/{});
5254

53-
// Compiling the graph should fail, which should throw instead of crashing.
54-
// TODO(https://github.com/pytorch/xla/issues/9096): ensure that
55-
// the exception has type std::invalid_argument.
56-
EXPECT_ANY_THROW(client->Compile(std::move(instances)));
55+
std::vector<ComputationClient::CompileInstance> instances;
56+
try {
57+
instances.push_back(ComputationClient::CompileInstance(
58+
std::move(MakeComputation().value()), device,
59+
client->GetCompilationDevices(device, client->GetLocalDevices()),
60+
&shape));
61+
} catch (const std::exception& e) {
62+
LOG(ERROR) << "ZW: " << e.what();
63+
}
64+
65+
try {
66+
// Compiling the graph should fail, which should throw instead of crashing.
67+
// TODO(https://github.com/pytorch/xla/issues/9096): ensure that
68+
// the exception has type std::invalid_argument.
69+
client->Compile(std::move(instances));
70+
} catch (const std::exception& e) {
71+
LOG(ERROR) << "ZW: " << e.what();
72+
} catch (...) {
73+
LOG(ERROR) << "Exception thrown!";
74+
}
75+
// EXPECT_ANY_THROW(client->Compile(std::move(instances)));
5776
}
5877

5978
TEST(PjRtComputationClientTest, Init) {

0 commit comments

Comments
 (0)