diff --git a/test/cpp/cpp_test_util.cpp b/test/cpp/cpp_test_util.cpp index 03efa420719..afe573101eb 100644 --- a/test/cpp/cpp_test_util.cpp +++ b/test/cpp/cpp_test_util.cpp @@ -303,9 +303,8 @@ std::vector Execute( std::vector Fetch( absl::Span device_data) { - std::vector literals = - torch_xla::runtime::GetComputationClientOrDie()->TransferFromDevice( - device_data); + std::vector literals = GetValueOrThrow( + runtime::GetComputationClientOrDie()->TransferFromDevice(device_data)); std::vector tensors; for (auto& literal : literals) { tensors.push_back(MakeTensorFromXlaLiteral( diff --git a/test/cpp/test_replication.cpp b/test/cpp/test_replication.cpp index 88175c2fdbb..b565dc44cd0 100644 --- a/test/cpp/test_replication.cpp +++ b/test/cpp/test_replication.cpp @@ -79,9 +79,8 @@ void TestSingleReplication( counter.Wait(); for (size_t i = 0; i < results.size(); ++i) { - std::vector literals = - torch_xla::runtime::GetComputationClientOrDie()->TransferFromDevice( - results[i]); + std::vector literals = GetValueOrThrow( + runtime::GetComputationClientOrDie()->TransferFromDevice(results[i])); ASSERT_EQ(literals.size(), 1); // The result must be the original tensor value, multiplied by the number of diff --git a/test/test_operations.py b/test/test_operations.py index f037ad4b8cb..68aa0b6c2c8 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -88,6 +88,11 @@ def skipIfFunctionalizationDisabled(reason): return _skipIfFunctionalization(value=True, reason=reason) +def onlyOnCPU(fn): + accelerator = os.environ.get("PJRT_DEVICE").lower() + return unittest.skipIf(accelerator != "cpu", "PJRT_DEVICE=CPU required")(fn) + + def onlyOnCUDA(fn): accelerator = os.environ.get("PJRT_DEVICE").lower() return unittest.skipIf(accelerator != "cuda", "PJRT_DEVICE=CUDA required")(fn) @@ -2458,6 +2463,16 @@ def test_add_broadcast_error(self): torch.add(a, b) torch_xla.sync() + @onlyOnCPU + def test_construct_large_tensor_raises_error(self): + with 
self.assertRaisesRegex(RuntimeError, + r"Out of memory allocating \d+ bytes"): + # When eager-mode is enabled, OOM is triggered here. + a = torch.rand(1024, 1024, 1024, 1024, 1024, device=torch_xla.device()) + b = a.sum() + # OOM is raised when we try to bring data from the device. + b.cpu() + class MNISTComparator(nn.Module): diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index ce55969d693..5b62d95efd5 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -1229,9 +1229,9 @@ class PyLoweringContext { lowering_ctx.GetParametersData(); // Fetch this parameter data - std::vector literals = + std::vector literals = GetValueOrThrow( runtime::GetComputationClientOrDie()->TransferFromDevice( - UnwrapXlaData(device_data)); + UnwrapXlaData(device_data))); // Create a mapping from paramater id to the tensor data std::unordered_map results; diff --git a/torch_xla/csrc/runtime/BUILD b/torch_xla/csrc/runtime/BUILD index d9832971890..c4760783f4d 100644 --- a/torch_xla/csrc/runtime/BUILD +++ b/torch_xla/csrc/runtime/BUILD @@ -123,6 +123,7 @@ cc_library( ":tf_logging", ":xla_coordinator", "//torch_xla/csrc:status", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", diff --git a/torch_xla/csrc/runtime/computation_client.h b/torch_xla/csrc/runtime/computation_client.h index c7603c8932a..c2f9389a4a0 100644 --- a/torch_xla/csrc/runtime/computation_client.h +++ b/torch_xla/csrc/runtime/computation_client.h @@ -318,7 +318,7 @@ class ComputationClient { // Note: `TransferFromDevice` call will block until the `DataPtrs` are ready // if they were created by `TransferToDevice` or `Execute*`. Calling this from // python while holding the GIL can cause deadlocks! 
- virtual std::vector TransferFromDevice( + virtual absl::StatusOr> TransferFromDevice( absl::Span handles) = 0; virtual std::uintptr_t UnsafeBufferPointer(const DataPtr handle) = 0; diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.cpp b/torch_xla/csrc/runtime/ifrt_computation_client.cpp index a463f79a226..f5a6af1b267 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.cpp +++ b/torch_xla/csrc/runtime/ifrt_computation_client.cpp @@ -436,8 +436,8 @@ std::shared_ptr IfrtComputationClient::GetPjRtBuffer( XLA_ERROR() << __FUNCTION__ << " not implemented"; } -std::vector IfrtComputationClient::TransferFromDevice( - absl::Span handles) { +absl::StatusOr> +IfrtComputationClient::TransferFromDevice(absl::Span handles) { metrics::TimedSection timed(TransferFromDeviceMetric()); tsl::profiler::TraceMe activity("IfrtComputationClient::TransferFromDevice", tsl::profiler::TraceMeLevel::kInfo); @@ -455,9 +455,9 @@ std::vector IfrtComputationClient::TransferFromDevice( auto& literal = literals.emplace_back( xla::ShapeUtil::DeviceShapeToHostShape(ifrt_data->shape())); std::vector byte_strides(literal.shape().dimensions_size()); - XLA_CHECK_OK(xla::ShapeUtil::ByteStrides(literal.shape(), - absl::MakeSpan(byte_strides))); - XLA_CHECK_OK( + XLA_RETURN_IF_ERROR(xla::ShapeUtil::ByteStrides( + literal.shape(), absl::MakeSpan(byte_strides))); + XLA_RETURN_IF_ERROR( replicated_array ->CopyToHostBuffer(literal.untyped_data(), byte_strides, xla::ifrt::ArrayCopySemantics::kAlwaysCopy) diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.h b/torch_xla/csrc/runtime/ifrt_computation_client.h index 9c21d7a8d7f..46b6343dc10 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.h +++ b/torch_xla/csrc/runtime/ifrt_computation_client.h @@ -62,7 +62,7 @@ class IfrtComputationClient : public ComputationClient { XLA_ERROR() << __FUNCTION__ << " not implemented"; } - std::vector TransferFromDevice( + absl::StatusOr> TransferFromDevice( absl::Span handles) override; 
std::uintptr_t UnsafeBufferPointer(const DataPtr handle) override; diff --git a/torch_xla/csrc/runtime/ifrt_computation_client_test.cpp b/torch_xla/csrc/runtime/ifrt_computation_client_test.cpp index 7a4741fc1bc..eb39f9b2e23 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client_test.cpp +++ b/torch_xla/csrc/runtime/ifrt_computation_client_test.cpp @@ -70,7 +70,7 @@ TEST(PjRtComputationClientTest, Init) { // Copy the output from device back to host and assert correctness.. ASSERT_EQ(results.size(), 1); - auto result_literals = client->TransferFromDevice(results); + auto result_literals = GetValueOrThrow(client->TransferFromDevice(results)); ASSERT_THAT(result_literals, ::testing::SizeIs(1)); EXPECT_TRUE(xla::LiteralTestUtil::Equal( xla::LiteralUtil::CreateR2({{6.0f, 8.0f}, {10.0f, 12.0f}}), diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cpp b/torch_xla/csrc/runtime/pjrt_computation_client.cpp index 8239da35846..dd4950d87f5 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cpp +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cpp @@ -4,6 +4,7 @@ #include #include +#include "absl/log/absl_check.h" #include "absl/strings/ascii.h" #include "absl/synchronization/blocking_counter.h" #include "absl/types/span.h" @@ -508,8 +509,8 @@ std::shared_ptr PjRtComputationClient::GetPjRtBuffer( } } -std::vector PjRtComputationClient::TransferFromDevice( - absl::Span handles) { +absl::StatusOr> +PjRtComputationClient::TransferFromDevice(absl::Span handles) { metrics::TimedSection timed(TransferFromDeviceMetric()); tsl::profiler::TraceMe activity("PjRtComputationClient::TransferFromDevice", tsl::profiler::TraceMeLevel::kInfo); @@ -522,21 +523,17 @@ std::vector PjRtComputationClient::TransferFromDevice( // Use XLA replication to reassemble the sharded data. If input handle // is not sharded, then it is a no-op. 
std::shared_ptr pjrt_data = ReplicateShardedData(handle); - XLA_CHECK(pjrt_data) << "PjRt_data is null in " << __FUNCTION__; - XLA_CHECK(pjrt_data->buffer != nullptr) + ABSL_CHECK(pjrt_data) << "PjRt_data is null in " << __FUNCTION__; + ABSL_CHECK(pjrt_data->buffer != nullptr) << "PjRt buffer is null in " << __FUNCTION__; - xla::Literal& literal = - literals.emplace_back(host_output_shape(pjrt_data->buffer.get())); + xla::Literal& literal = literals.emplace_back( + xla::Literal(host_output_shape(pjrt_data->buffer.get()))); futures.push_back(pjrt_data->buffer->ToLiteral(&literal)); total_size += literal.size_bytes(); } - for (auto& future : futures) { - absl::Status status = future.Await(); - XLA_CHECK_OK(status) << "Failed to await future from buffer to literal in" - << __FUNCTION__; - } + XLA_RETURN_IF_ERROR(xla::JoinFutures(futures).Await()); InboundDataMetric()->AddSample(total_size); return literals; @@ -773,10 +770,8 @@ PjRtComputationClient::ExecuteComputation( std::optional> returned_future; std::vector> results = - pjrt_computation.executable - ->ExecuteSharded(buffers, pjrt_device, execute_options, - returned_future) - .value(); + GetValueOrThrow(pjrt_computation.executable->ExecuteSharded( + buffers, pjrt_device, execute_options, returned_future)); returned_future->OnReady(std::move( [timed, op_tracker = std::move(op_tracker)](absl::Status unused) mutable { @@ -878,10 +873,8 @@ PjRtComputationClient::ExecuteReplicated( tsl::profiler::TraceMe activity( "PjRtComputationClient::ExecuteReplicated_execute", tsl::profiler::TraceMeLevel::kInfo); - results = pjrt_computation.executable - ->Execute(std::move(argument_handles), execute_options, - returned_futures) - .value(); + results = GetValueOrThrow(pjrt_computation.executable->Execute( + std::move(argument_handles), execute_options, returned_futures)); (*returned_futures)[0].OnReady( std::move([timed, op_tracker = std::move(op_tracker)]( diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.h 
b/torch_xla/csrc/runtime/pjrt_computation_client.h index b7c61e2ec74..3a6b4478f72 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.h +++ b/torch_xla/csrc/runtime/pjrt_computation_client.h @@ -65,7 +65,7 @@ class PjRtComputationClient : public ComputationClient { absl::Span handles, absl::Span shardings) override; - std::vector TransferFromDevice( + absl::StatusOr> TransferFromDevice( absl::Span handles) override; std::uintptr_t UnsafeBufferPointer(const DataPtr handle) override; diff --git a/torch_xla/csrc/runtime/pjrt_computation_client_test.cpp b/torch_xla/csrc/runtime/pjrt_computation_client_test.cpp index 3398e61a278..0fe2b2a70fc 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client_test.cpp +++ b/torch_xla/csrc/runtime/pjrt_computation_client_test.cpp @@ -120,7 +120,7 @@ TEST_F(PjRtComputationClientTest, Init) { // Copy the output from device back to host and assert correctness. ASSERT_EQ(results.size(), 1); - auto result_literals = client_->TransferFromDevice(results); + auto result_literals = GetValueOrThrow(client_->TransferFromDevice(results)); ASSERT_THAT(result_literals, ::testing::SizeIs(1)); EXPECT_TRUE(xla::LiteralTestUtil::Equal( xla::LiteralUtil::CreateR2({{6.0f, 8.0f}, {10.0f, 12.0f}}), diff --git a/torch_xla/csrc/tensor_util.cpp b/torch_xla/csrc/tensor_util.cpp index 0a7f184cda7..e2cd3a025f5 100644 --- a/torch_xla/csrc/tensor_util.cpp +++ b/torch_xla/csrc/tensor_util.cpp @@ -24,6 +24,7 @@ #include "torch_xla/csrc/runtime/sys_util.h" #include "torch_xla/csrc/runtime/tf_logging.h" #include "torch_xla/csrc/runtime/util.h" +#include "torch_xla/csrc/status.h" #include "torch_xla/csrc/thread_pool.h" #include "torch_xla/csrc/torch_util.h" #include "torch_xla/csrc/xla_backend_impl.h" @@ -909,8 +910,8 @@ std::vector ReleaseGilAndTransferData( save = PyEval_SaveThread(); } std::vector literals = - runtime::GetComputationClientOrDie()->TransferFromDevice( - UnwrapXlaData(xla_data)); + 
GetValueOrThrow(runtime::GetComputationClientOrDie()->TransferFromDevice( + UnwrapXlaData(xla_data))); if (save) { PyEval_RestoreThread(save); }