2
2
3
3
#include < gtest/gtest.h>
4
4
5
+ #include < exception>
5
6
#include < memory>
6
7
#include < stdexcept>
7
8
#include < string>
24
25
25
26
namespace torch_xla {
26
27
namespace runtime {
28
+ namespace {
27
29
28
- absl::StatusOr<xla::XlaComputation> MakeComputation () {
29
- xla::Shape input_shape =
30
+ // Returns a computation to compute x + y where x and y are both F32[2,2]
31
+ // arrays.
32
+ absl::StatusOr<xla::XlaComputation> MakeAddComputation () {
33
+ const xla::Shape input_shape =
30
34
xla::ShapeUtil::MakeShape (xla::PrimitiveType::F32, {2 , 2 });
31
35
xla::XlaBuilder builder (" AddComputation" );
32
36
xla::XlaOp x = xla::Parameter (&builder, 0 , input_shape, " x" );
@@ -35,43 +39,54 @@ absl::StatusOr<xla::XlaComputation> MakeComputation() {
35
39
return builder.Build ();
36
40
}
37
41
42
+ // Returns a computation to compute the matrix multiplication of two matrices:
43
+ // x: F32[size, 1] mul y: F32[1, size] => z: F32[size, size]
44
+ absl::StatusOr<xla::XlaComputation> MakeMatMulComputation (int64_t size) {
45
+ const xla::Shape x_shape =
46
+ xla::ShapeUtil::MakeShape (xla::PrimitiveType::F32, {size, 1 });
47
+ const xla::Shape y_shape =
48
+ xla::ShapeUtil::MakeShape (xla::PrimitiveType::F32, {1 , size});
49
+ xla::XlaBuilder builder (" MatMulComputation" );
50
+ xla::XlaOp x = xla::Parameter (&builder, 0 , x_shape, " x" );
51
+ xla::XlaOp y = xla::Parameter (&builder, 1 , y_shape, " y" );
52
+ xla::XlaOp matmul = xla::Dot (x, y);
53
+ return builder.Build ();
54
+ }
55
+
38
56
// Verifies that a compilation failure surfaces as a C++ exception rather than
// crashing the process.
TEST(PjRtComputationClient, ThrowsExpectedExceptionWhenCompileFails) {
  // Get a CPU client.
  tsl::setenv("PJRT_DEVICE", "CPU", true);
  const auto client = std::make_unique<PjRtComputationClient>();
  const std::string device = client->GetDefaultDevice();

  // Compose a computation to multiply two matrices. The F32[size, size]
  // output (~1.6e19 elements) is far too large to materialize, so compiling
  // it must fail. Compute the size in int64_t: `2L * 1000000000` would
  // overflow (UB) on LLP64 platforms where `long` is 32 bits.
  const int64_t size = int64_t{2} * 1000000000;
  xla::Shape out_shape(xla::F32, {size, size},
                       /*dynamic_dimensions=*/{});

  // Building the compile instance is expected to succeed; if it throws, the
  // test should fail loudly instead of swallowing the error.
  std::vector<ComputationClient::CompileInstance> instances;
  instances.push_back(ComputationClient::CompileInstance(
      std::move(MakeMatMulComputation(size).value()), device,
      client->GetCompilationDevices(device, client->GetLocalDevices()),
      &out_shape));

  // Compiling the graph should fail, which should throw instead of crashing.
  // TODO(https://github.com/pytorch/xla/issues/9096): ensure that
  // the exception has type std::invalid_argument.
  EXPECT_ANY_THROW(client->Compile(std::move(instances)));
}
77
92
@@ -81,13 +96,13 @@ TEST(PjRtComputationClientTest, Init) {
81
96
auto client = std::make_unique<PjRtComputationClient>();
82
97
std::string device = client->GetDefaultDevice ();
83
98
84
- // Compose a computation.
85
- auto shape = xla::ShapeUtil::MakeShape (xla::F32, {2 , 2 });
99
+ // Compose a computation to add two 2x2 matrices .
100
+ auto out_shape = xla::ShapeUtil::MakeShape (xla::F32, {2 , 2 });
86
101
std::vector<ComputationClient::CompileInstance> instances;
87
102
instances.push_back (ComputationClient::CompileInstance (
88
- std::move (MakeComputation ().value ()), device,
103
+ std::move (MakeAddComputation ().value ()), device,
89
104
client->GetCompilationDevices (device, client->GetLocalDevices ()),
90
- &shape ));
105
+ &out_shape ));
91
106
92
107
// Prepare inputs.
93
108
xla::Literal literal_x =
@@ -119,5 +134,6 @@ TEST(PjRtComputationClientTest, Init) {
119
134
result_literals[0 ]));
120
135
}
121
136
137
+ } // namespace
122
138
} // namespace runtime
123
139
} // namespace torch_xla
0 commit comments