|
3 | 3 | #include <gtest/gtest.h>
|
4 | 4 |
|
5 | 5 | #include <memory>
|
| 6 | +#include <stdexcept> |
6 | 7 | #include <string>
|
7 | 8 | #include <vector>
|
8 | 9 |
|
@@ -34,6 +35,27 @@ absl::StatusOr<xla::XlaComputation> MakeComputation() {
|
34 | 35 | return builder.Build();
|
35 | 36 | }
|
36 | 37 |
|
| 38 | +TEST(PjRtComputationClient, ThrowsExpectedExceptionWhenCompileFails) { |
| 39 | + // Get a CPU client. |
| 40 | + tsl::setenv("PJRT_DEVICE", "CPU", true); |
| 41 | + const auto client = std::make_unique<PjRtComputationClient>(); |
| 42 | + const std::string device = client->GetDefaultDevice(); |
| 43 | + |
| 44 | + // Compose a computation with an enormous shape. |
| 45 | + const auto shape = |
| 46 | + xla::ShapeUtil::MakeShape(xla::F32, {8000000000, 1000000000}); |
| 47 | + std::vector<ComputationClient::CompileInstance> instances; |
| 48 | + instances.push_back(ComputationClient::CompileInstance( |
| 49 | + std::move(MakeComputation().value()), device, |
| 50 | + client->GetCompilationDevices(device, client->GetLocalDevices()), |
| 51 | + &shape)); |
| 52 | + |
| 53 | + // Compiling the graph should fail, which should throw instead of crashing. |
| 54 | + // TODO(https://github.com/pytorch/xla/issues/9096): ensure that |
| 55 | + // the exception has type std::invalid_argument. |
| 56 | + EXPECT_ANY_THROW(client->Compile(std::move(instances))); |
| 57 | +} |
| 58 | + |
37 | 59 | TEST(PjRtComputationClientTest, Init) {
|
38 | 60 | // Get a CPU client.
|
39 | 61 | tsl::setenv("PJRT_DEVICE", "CPU", true);
|
@@ -69,7 +91,7 @@ TEST(PjRtComputationClientTest, Init) {
|
69 | 91 | *computations[0], client->TransferToDevice(absl::MakeConstSpan(args)),
|
70 | 92 | device, options);
|
71 | 93 |
|
72 |
| - // Copy the output from device back to host and assert correctness.. |
| 94 | + // Copy the output from device back to host and assert correctness. |
73 | 95 | ASSERT_EQ(results.size(), 1);
|
74 | 96 | auto result_literals = client->TransferFromDevice(results);
|
75 | 97 | ASSERT_THAT(result_literals, ::testing::SizeIs(1));
|
|
0 commit comments