
Commit a370c5d

Propagate status on OOM crashes and exceptions.
1 parent 58592c4 commit a370c5d

13 files changed: +44 -39 lines

test/cpp/cpp_test_util.cpp

Lines changed: 2 additions & 3 deletions
@@ -303,9 +303,8 @@ std::vector<torch_xla::runtime::ComputationClient::DataPtr> Execute(
 std::vector<at::Tensor> Fetch(
     absl::Span<const torch_xla::runtime::ComputationClient::DataPtr>
         device_data) {
-  std::vector<xla::Literal> literals =
-      torch_xla::runtime::GetComputationClientOrDie()->TransferFromDevice(
-          device_data);
+  std::vector<xla::Literal> literals = GetValueOrThrow(
+      runtime::GetComputationClientOrDie()->TransferFromDevice(device_data));
   std::vector<at::Tensor> tensors;
   for (auto& literal : literals) {
     tensors.push_back(MakeTensorFromXlaLiteral(
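
The call sites throughout this commit lean on two helpers from torch_xla's status utilities (//torch_xla/csrc:status): GetValueOrThrow, which unwraps an absl::StatusOr at the C++/Python boundary, and XLA_RETURN_IF_ERROR, which propagates a non-OK status to the caller. Their definitions are not part of this diff; the following is only a rough, hypothetical sketch of the behavior they provide:

#include <stdexcept>
#include <string>
#include <utility>

#include "absl/status/status.h"
#include "absl/status/statusor.h"

// Hypothetical approximation, not the real torch_xla implementation:
// propagate a non-OK status out of a status-returning function.
#define XLA_RETURN_IF_ERROR(expr)      \
  do {                                 \
    absl::Status _status = (expr);     \
    if (!_status.ok()) return _status; \
  } while (0)

// Hypothetical approximation: unwrap a StatusOr, converting an error into a
// C++ exception that surfaces in Python as a RuntimeError (as exercised by
// the new test_construct_large_tensor_raises_error test below).
template <typename T>
T GetValueOrThrow(absl::StatusOr<T> status_or) {
  if (!status_or.ok()) {
    throw std::runtime_error(std::string(status_or.status().message()));
  }
  return *std::move(status_or);
}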

test/cpp/test_replication.cpp

Lines changed: 2 additions & 3 deletions
@@ -79,9 +79,8 @@ void TestSingleReplication(
   counter.Wait();
 
   for (size_t i = 0; i < results.size(); ++i) {
-    std::vector<xla::Literal> literals =
-        torch_xla::runtime::GetComputationClientOrDie()->TransferFromDevice(
-            results[i]);
+    std::vector<xla::Literal> literals = GetValueOrThrow(
+        runtime::GetComputationClientOrDie()->TransferFromDevice(results[i]));
     ASSERT_EQ(literals.size(), 1);
 
     // The result must be the original tensor value, multiplied by the number of

test/test_operations.py

Lines changed: 8 additions & 0 deletions
@@ -2458,6 +2458,14 @@ def test_add_broadcast_error(self):
       torch.add(a, b)
       torch_xla.sync()
 
+  def test_construct_large_tensor_raises_error(self):
+    a = torch.rand(1024, 1024, 1024, 1024, 1024, device=torch_xla.device())
+
+    # OOM is raised when we try to bring data from the device.
+    with self.assertRaisesRegex(RuntimeError, r"Out of memory allocating \d* bytes"):
+      b = a.sum()
+      b.cpu()
+
 
 class MNISTComparator(nn.Module):

torch_xla/csrc/init_python_bindings.cpp

Lines changed: 2 additions & 2 deletions
@@ -1229,9 +1229,9 @@ class PyLoweringContext {
         lowering_ctx.GetParametersData();
 
     // Fetch this parameter data
-    std::vector<xla::Literal> literals =
+    std::vector<xla::Literal> literals = GetValueOrThrow(
         runtime::GetComputationClientOrDie()->TransferFromDevice(
-            UnwrapXlaData(device_data));
+            UnwrapXlaData(device_data)));
 
     // Create a mapping from paramater id to the tensor data
     std::unordered_map<int64_t, at::Tensor> results;

torch_xla/csrc/runtime/BUILD

Lines changed: 1 addition & 0 deletions
@@ -123,6 +123,7 @@ cc_library(
         ":tf_logging",
         ":xla_coordinator",
         "//torch_xla/csrc:status",
+        "@com_google_absl//absl/log:absl_check",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/synchronization",
         "@com_google_absl//absl/types:span",

torch_xla/csrc/runtime/computation_client.h

Lines changed: 1 addition & 1 deletion
@@ -318,7 +318,7 @@ class ComputationClient {
   // Note: `TransferFromDevice` call will block until the `DataPtrs` are ready
   // if they were created by `TransferToDevice` or `Execute*`. Calling this from
   // python while holding the GIL can cause deadlocks!
-  virtual std::vector<xla::Literal> TransferFromDevice(
+  virtual absl::StatusOr<std::vector<xla::Literal>> TransferFromDevice(
       absl::Span<const DataPtr> handles) = 0;
 
   virtual std::uintptr_t UnsafeBufferPointer(const DataPtr handle) = 0;
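
Every backend implements this virtual method, so its callers now receive an absl::StatusOr instead of a bare vector. Below is a minimal caller-side sketch using a hypothetical function (CountLiterals, not part of the codebase); it shows the explicit form of the handling that the GetValueOrThrow / XLA_RETURN_IF_ERROR call sites in this commit wrap:

// Hypothetical caller illustrating the new contract.
absl::StatusOr<size_t> CountLiterals(
    torch_xla::runtime::ComputationClient* client,
    absl::Span<const torch_xla::runtime::ComputationClient::DataPtr> handles) {
  absl::StatusOr<std::vector<xla::Literal>> literals =
      client->TransferFromDevice(handles);
  if (!literals.ok()) {
    // An error such as "Out of memory allocating ... bytes" arrives here as
    // a status instead of terminating the process.
    return literals.status();
  }
  return literals->size();
}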

torch_xla/csrc/runtime/ifrt_computation_client.cpp

Lines changed: 5 additions & 5 deletions
@@ -436,8 +436,8 @@ std::shared_ptr<xla::PjRtBuffer> IfrtComputationClient::GetPjRtBuffer(
   XLA_ERROR() << __FUNCTION__ << " not implemented";
 }
 
-std::vector<xla::Literal> IfrtComputationClient::TransferFromDevice(
-    absl::Span<const DataPtr> handles) {
+absl::StatusOr<std::vector<xla::Literal>>
+IfrtComputationClient::TransferFromDevice(absl::Span<const DataPtr> handles) {
   metrics::TimedSection timed(TransferFromDeviceMetric());
   tsl::profiler::TraceMe activity("IfrtComputationClient::TransferFromDevice",
                                   tsl::profiler::TraceMeLevel::kInfo);
@@ -455,9 +455,9 @@ std::vector<xla::Literal> IfrtComputationClient::TransferFromDevice(
     auto& literal = literals.emplace_back(
         xla::ShapeUtil::DeviceShapeToHostShape(ifrt_data->shape()));
     std::vector<int64_t> byte_strides(literal.shape().dimensions_size());
-    XLA_CHECK_OK(xla::ShapeUtil::ByteStrides(literal.shape(),
-                                             absl::MakeSpan(byte_strides)));
-    XLA_CHECK_OK(
+    XLA_RETURN_IF_ERROR(xla::ShapeUtil::ByteStrides(
+        literal.shape(), absl::MakeSpan(byte_strides)));
+    XLA_RETURN_IF_ERROR(
         replicated_array
             ->CopyToHostBuffer(literal.untyped_data(), byte_strides,
                                xla::ifrt::ArrayCopySemantics::kAlwaysCopy)

torch_xla/csrc/runtime/ifrt_computation_client.h

Lines changed: 1 addition & 1 deletion
@@ -62,7 +62,7 @@ class IfrtComputationClient : public ComputationClient {
     XLA_ERROR() << __FUNCTION__ << " not implemented";
   }
 
-  std::vector<xla::Literal> TransferFromDevice(
+  absl::StatusOr<std::vector<xla::Literal>> TransferFromDevice(
       absl::Span<const DataPtr> handles) override;
 
   std::uintptr_t UnsafeBufferPointer(const DataPtr handle) override;

torch_xla/csrc/runtime/ifrt_computation_client_test.cpp

Lines changed: 1 addition & 1 deletion
@@ -70,7 +70,7 @@ TEST(PjRtComputationClientTest, Init) {
 
   // Copy the output from device back to host and assert correctness..
   ASSERT_EQ(results.size(), 1);
-  auto result_literals = client->TransferFromDevice(results);
+  auto result_literals = GetValueOrThrow(client->TransferFromDevice(results));
   ASSERT_THAT(result_literals, ::testing::SizeIs(1));
   EXPECT_TRUE(xla::LiteralTestUtil::Equal(
       xla::LiteralUtil::CreateR2<float>({{6.0f, 8.0f}, {10.0f, 12.0f}}),

torch_xla/csrc/runtime/pjrt_computation_client.cpp

Lines changed: 16 additions & 19 deletions
@@ -4,6 +4,7 @@
 #include <stdexcept>
 #include <vector>
 
+#include "absl/log/absl_check.h"
 #include "absl/strings/ascii.h"
 #include "absl/synchronization/blocking_counter.h"
 #include "absl/types/span.h"
@@ -508,8 +509,8 @@ std::shared_ptr<xla::PjRtBuffer> PjRtComputationClient::GetPjRtBuffer(
   }
 }
 
-std::vector<xla::Literal> PjRtComputationClient::TransferFromDevice(
-    absl::Span<const DataPtr> handles) {
+absl::StatusOr<std::vector<xla::Literal>>
+PjRtComputationClient::TransferFromDevice(absl::Span<const DataPtr> handles) {
   metrics::TimedSection timed(TransferFromDeviceMetric());
   tsl::profiler::TraceMe activity("PjRtComputationClient::TransferFromDevice",
                                   tsl::profiler::TraceMeLevel::kInfo);
@@ -522,21 +523,21 @@ std::vector<xla::Literal> PjRtComputationClient::TransferFromDevice(
     // Use XLA replication to reassemble the sharded data. If input handle
     // is not sharded, then it is a no-op.
     std::shared_ptr<PjRtData> pjrt_data = ReplicateShardedData(handle);
-    XLA_CHECK(pjrt_data) << "PjRt_data is null in " << __FUNCTION__;
-    XLA_CHECK(pjrt_data->buffer != nullptr)
+    ABSL_CHECK(pjrt_data) << "PjRt_data is null in " << __FUNCTION__;
+    ABSL_CHECK(pjrt_data->buffer != nullptr)
         << "PjRt buffer is null in " << __FUNCTION__;
 
-    xla::Literal& literal =
-        literals.emplace_back(host_output_shape(pjrt_data->buffer.get()));
+    // Constructing a literal too large will make the whole program crash.
+    // Instead, we pass allocate_arrays=False, which makes this kind of
+    // error possible to be handled in the `Await()` call below.
+    xla::Literal& literal = literals.emplace_back(
+        xla::Literal(host_output_shape(pjrt_data->buffer.get()),
+                     /* allocate_arrays= */ false));
     futures.push_back(pjrt_data->buffer->ToLiteral(&literal));
 
     total_size += literal.size_bytes();
   }
-  for (auto& future : futures) {
-    absl::Status status = future.Await();
-    XLA_CHECK_OK(status) << "Failed to await future from buffer to literal in"
-                         << __FUNCTION__;
-  }
+  XLA_RETURN_IF_ERROR(xla::JoinFutures(futures).Await());
   InboundDataMetric()->AddSample(total_size);
 
   return literals;
@@ -773,10 +774,8 @@ PjRtComputationClient::ExecuteComputation(
 
   std::optional<xla::PjRtFuture<>> returned_future;
   std::vector<std::unique_ptr<xla::PjRtBuffer>> results =
-      pjrt_computation.executable
-          ->ExecuteSharded(buffers, pjrt_device, execute_options,
-                           returned_future)
-          .value();
+      GetValueOrThrow(pjrt_computation.executable->ExecuteSharded(
+          buffers, pjrt_device, execute_options, returned_future));
 
   returned_future->OnReady(std::move(
       [timed, op_tracker = std::move(op_tracker)](absl::Status unused) mutable {
@@ -878,10 +877,8 @@ PjRtComputationClient::ExecuteReplicated(
     tsl::profiler::TraceMe activity(
         "PjRtComputationClient::ExecuteReplicated_execute",
         tsl::profiler::TraceMeLevel::kInfo);
-    results = pjrt_computation.executable
-                  ->Execute(std::move(argument_handles), execute_options,
-                            returned_futures)
-                  .value();
+    results = GetValueOrThrow(pjrt_computation.executable->Execute(
+        std::move(argument_handles), execute_options, returned_futures));
 
     (*returned_futures)[0].OnReady(
         std::move([timed, op_tracker = std::move(op_tracker)](
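
The PjRtComputationClient change above is where the OOM itself is tamed: the host literal is created with allocate_arrays=false, so the crash-prone allocation no longer happens eagerly, and any failure is reported by the joined ToLiteral futures as a status that TransferFromDevice can now return. As a condensed, hypothetical sketch of that per-buffer error path (a standalone function for illustration only; ToHostLiteral and host_shape are made-up names, not code from this commit):

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "xla/literal.h"
#include "xla/pjrt/pjrt_client.h"

// Hypothetical sketch of the deferred-allocation pattern used above.
absl::StatusOr<xla::Literal> ToHostLiteral(xla::PjRtBuffer* buffer,
                                           const xla::Shape& host_shape) {
  // Defer array allocation so an oversized shape cannot abort the process
  // at construction time.
  xla::Literal literal(host_shape, /*allocate_arrays=*/false);
  // ToLiteral fills the literal asynchronously; an out-of-memory failure is
  // reported through the returned future rather than crashing.
  xla::PjRtFuture<> done = buffer->ToLiteral(&literal);
  // Await() surfaces that failure as an absl::Status, which we propagate
  // (the commit does this with XLA_RETURN_IF_ERROR on the joined futures).
  absl::Status status = done.Await();
  if (!status.ok()) {
    return status;  // e.g. "Out of memory allocating ... bytes"
  }
  return literal;
}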
