Skip to content

Commit 9770c7a

Browse files
committed
fix
1 parent 73bda3b commit 9770c7a

File tree

1 file changed

+40
-24
lines changed

1 file changed

+40
-24
lines changed

torch_xla/csrc/runtime/pjrt_computation_client_test.cpp

Lines changed: 40 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include <gtest/gtest.h>
44

5+
#include <exception>
56
#include <memory>
67
#include <stdexcept>
78
#include <string>
@@ -24,9 +25,12 @@
2425

2526
namespace torch_xla {
2627
namespace runtime {
28+
namespace {
2729

28-
absl::StatusOr<xla::XlaComputation> MakeComputation() {
29-
xla::Shape input_shape =
30+
// Returns a computation to compute x + y where x and y are both F32[2,2]
31+
// arrays.
32+
absl::StatusOr<xla::XlaComputation> MakeAddComputation() {
33+
const xla::Shape input_shape =
3034
xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {2, 2});
3135
xla::XlaBuilder builder("AddComputation");
3236
xla::XlaOp x = xla::Parameter(&builder, 0, input_shape, "x");
@@ -35,43 +39,54 @@ absl::StatusOr<xla::XlaComputation> MakeComputation() {
3539
return builder.Build();
3640
}
3741

42+
// Returns a computation to compute the matrix multiplication of two matrices:
43+
// x: F32[size, 1] mul y: F32[1, size] => z: F32[size, size]
44+
absl::StatusOr<xla::XlaComputation> MakeMatMulComputation(int64_t size) {
45+
const xla::Shape x_shape =
46+
xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {size, 1});
47+
const xla::Shape y_shape =
48+
xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {1, size});
49+
xla::XlaBuilder builder("MatMulComputation");
50+
xla::XlaOp x = xla::Parameter(&builder, 0, x_shape, "x");
51+
xla::XlaOp y = xla::Parameter(&builder, 1, y_shape, "y");
52+
xla::XlaOp matmul = xla::Dot(x, y);
53+
return builder.Build();
54+
}
55+
3856
TEST(PjRtComputationClient, ThrowsExpectedExceptionWhenCompileFails) {
  // Get a CPU client.
  tsl::setenv("PJRT_DEVICE", "CPU", true);
  const auto client = std::make_unique<PjRtComputationClient>();
  const std::string device = client->GetDefaultDevice();

  // Compose a computation multiplying two matrices whose F32[size, size]
  // output is far too large to allocate, so compilation must fail.
  const int64_t size = int64_t{2} * 1000 * 1000 * 1000;
  xla::Shape out_shape(xla::F32, {size, size},
                       /*dynamic_dimensions=*/{});
  std::vector<ComputationClient::CompileInstance> instances;
  instances.push_back(ComputationClient::CompileInstance(
      std::move(MakeMatMulComputation(size).value()), device,
      client->GetCompilationDevices(device, client->GetLocalDevices()),
      &out_shape));

  // Compiling the graph should fail, which should throw instead of crashing.
  // Previously this test only logged the exception (debug scaffolding with
  // catch(...) swallowed every failure), so it could never fail; assert the
  // throw instead.
  // TODO(https://github.com/pytorch/xla/issues/9096): ensure that
  // the exception has type std::invalid_argument.
  EXPECT_ANY_THROW(client->Compile(std::move(instances)));
}
7792

@@ -81,13 +96,13 @@ TEST(PjRtComputationClientTest, Init) {
8196
auto client = std::make_unique<PjRtComputationClient>();
8297
std::string device = client->GetDefaultDevice();
8398

84-
// Compose a computation.
85-
auto shape = xla::ShapeUtil::MakeShape(xla::F32, {2, 2});
99+
// Compose a computation to add two 2x2 matrices.
100+
auto out_shape = xla::ShapeUtil::MakeShape(xla::F32, {2, 2});
86101
std::vector<ComputationClient::CompileInstance> instances;
87102
instances.push_back(ComputationClient::CompileInstance(
88-
std::move(MakeComputation().value()), device,
103+
std::move(MakeAddComputation().value()), device,
89104
client->GetCompilationDevices(device, client->GetLocalDevices()),
90-
&shape));
105+
&out_shape));
91106

92107
// Prepare inputs.
93108
xla::Literal literal_x =
@@ -119,5 +134,6 @@ TEST(PjRtComputationClientTest, Init) {
119134
result_literals[0]));
120135
}
121136

137+
} // namespace
122138
} // namespace runtime
123139
} // namespace torch_xla

0 commit comments

Comments
 (0)