
Commit 2d4b3b9

Throw a Python exception if compilation fails.

1 parent d1eb12c, commit 2d4b3b9
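
For context, the commit relies on pybind11's default exception translation: a C++ std::invalid_argument that escapes a bound function surfaces in Python as a ValueError. The snippet below is a standalone sketch of that mechanism, not code from this repository; the module and function names are made up for illustration.

// Standalone sketch of pybind11 exception translation (hypothetical names,
// not part of this commit). A std::invalid_argument thrown in C++ reaches
// Python as a ValueError carrying the same message.
#include <stdexcept>

#include <pybind11/pybind11.h>

// Stand-in for a compile step that fails, as PjRtComputationClient::Compile
// now does when PJRT reports an error status.
void fail_to_compile() {
  throw std::invalid_argument("compilation failed: shape is too large");
}

PYBIND11_MODULE(example_module, m) {
  // From Python, example_module.fail_to_compile() raises
  // ValueError: compilation failed: shape is too large
  m.def("fail_to_compile", &fail_to_compile);
}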

File tree: 2 files changed, +35 -7 lines changed

torch_xla/csrc/runtime/pjrt_computation_client.cc

Lines changed: 12 additions & 6 deletions

@@ -2,6 +2,7 @@

 #include <algorithm>
 #include <future>
+#include <stdexcept>
 #include <unordered_set>
 #include <vector>

@@ -625,21 +626,26 @@ std::vector<ComputationClient::ComputationPtr> PjRtComputationClient::Compile(
           device_assignment);
     }

-    std::unique_ptr<xla::PjRtLoadedExecutable> executable;
+    absl::StatusOr<std::unique_ptr<xla::PjRtLoadedExecutable>> maybe_executable;
     if (runtime::sys_util::GetEnvBool("XLA_STABLEHLO_COMPILE", false)) {
       // Convert HLO to StableHLO for PjRt client compilation.
       mlir::MLIRContext context;
       mlir::ModuleOp mlir_module =
           mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
       ConvertHloToStableHlo(instance.computation.mutable_proto(), &mlir_module);
-      executable =
-          client_->CompileAndLoad(mlir_module, compile_options).value();
+      maybe_executable = client_->CompileAndLoad(mlir_module, compile_options);
       StableHloCompileCounter()->AddValue(1);
     } else {
-      executable =
-          client_->CompileAndLoad(instance.computation, compile_options)
-              .value();
+      maybe_executable =
+          client_->CompileAndLoad(instance.computation, compile_options);
     }
+    if (!maybe_executable.ok()) {
+      // This will automatically raise a Python ValueError exception.
+      // See https://pybind11.readthedocs.io/en/stable/advanced/exceptions.html.
+      throw std::invalid_argument(
+          std::string(maybe_executable.status().message()));
+    }
+    auto executable = std::move(maybe_executable).value();

     auto memory_stats_status_or = executable->GetCompiledMemoryStats();
     if (memory_stats_status_or.ok()) {
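
The second hunk replaces an unchecked .value() call with an absl::StatusOr that is checked before use. Below is a minimal, self-contained illustration of that pattern with made-up names and an int payload standing in for the loaded executable; it is not code from the change itself.

// Minimal sketch of the StatusOr check-then-throw pattern applied above
// (hypothetical names; int stands in for xla::PjRtLoadedExecutable).
#include <stdexcept>
#include <string>
#include <utility>

#include "absl/status/status.h"
#include "absl/status/statusor.h"

absl::StatusOr<int> MaybeCompile(bool fail) {
  if (fail) return absl::InvalidArgumentError("requested shape is too large");
  return 42;
}

int CompileOrThrow(bool fail) {
  absl::StatusOr<int> maybe_result = MaybeCompile(fail);
  if (!maybe_result.ok()) {
    // Throwing std::invalid_argument replaces unconditionally calling
    // .value() on an error status, which previously crashed the process.
    throw std::invalid_argument(std::string(maybe_result.status().message()));
  }
  // Safe to unwrap only after the ok() check; move to avoid a copy.
  return std::move(maybe_result).value();
}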

torch_xla/csrc/runtime/pjrt_computation_client_test.cc

Lines changed: 23 additions & 1 deletion

@@ -3,6 +3,7 @@
 #include <gtest/gtest.h>

 #include <memory>
+#include <stdexcept>
 #include <string>
 #include <vector>

@@ -34,6 +35,27 @@ absl::StatusOr<xla::XlaComputation> MakeComputation() {
   return builder.Build();
 }

+TEST(PjRtComputationClient, ThrowsExpectedExceptionWhenCompileFails) {
+  // Get a CPU client.
+  tsl::setenv("PJRT_DEVICE", "CPU", true);
+  const auto client = std::make_unique<PjRtComputationClient>();
+  const std::string device = client->GetDefaultDevice();
+
+  // Compose a computation with an enormous shape.
+  const auto shape =
+      xla::ShapeUtil::MakeShape(xla::F32, {8000000000, 1000000000});
+  std::vector<ComputationClient::CompileInstance> instances;
+  instances.push_back(ComputationClient::CompileInstance(
+      std::move(MakeComputation().value()), device,
+      client->GetCompilationDevices(device, client->GetLocalDevices()),
+      &shape));
+
+  // Compiling the graph should fail, which should throw instead of crashing.
+  // TODO(https://github.com/pytorch/xla/issues/9096): ensure that
+  // the exception has type std::invalid_argument.
+  EXPECT_ANY_THROW(client->Compile(std::move(instances)));
+}
+
 TEST(PjRtComputationClientTest, Init) {
   // Get a CPU client.
   tsl::setenv("PJRT_DEVICE", "CPU", true);

@@ -69,7 +91,7 @@ TEST(PjRtComputationClientTest, Init) {
       *computations[0], client->TransferToDevice(absl::MakeConstSpan(args)),
       device, options);

-  // Copy the output from device back to host and assert correctness..
+  // Copy the output from device back to host and assert correctness.
   ASSERT_EQ(results.size(), 1);
   auto result_literals = client->TransferFromDevice(results);
   ASSERT_THAT(result_literals, ::testing::SizeIs(1));
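
As the TODO in the new test notes, the assertion currently accepts any exception type. Once https://github.com/pytorch/xla/issues/9096 guarantees the type across code paths, the check could plausibly be tightened as sketched below; this is a possible follow-up, not part of this commit.

// Possible follow-up (not in this commit): assert the exception type once it
// is guaranteed to be std::invalid_argument.
EXPECT_THROW(client->Compile(std::move(instances)), std::invalid_argument);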
