
Error Handling: refactor ComputationClient::TransferFromDevice to propagate status. #9429

Open · wants to merge 6 commits into base: master
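A sketch of the change (signatures simplified; the authoritative versions are in the diffs below):

// Before: a failure during the transfer aborted the process via a CHECK.
virtual std::vector<xla::Literal> TransferFromDevice(
    absl::Span<const DataPtr> handles) = 0;

// After: the failure is propagated to the caller as an absl::Status.
virtual absl::StatusOr<std::vector<xla::Literal>> TransferFromDevice(
    absl::Span<const DataPtr> handles) = 0;

// Call sites that want to keep throw-on-error semantics unwrap the result
// with the GetValueOrThrow helper from torch_xla/csrc/status.h:
std::vector<xla::Literal> literals = GetValueOrThrow(
    runtime::GetComputationClientOrDie()->TransferFromDevice(handles));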
5 changes: 2 additions & 3 deletions test/cpp/cpp_test_util.cpp
@@ -303,9 +303,8 @@ std::vector<torch_xla::runtime::ComputationClient::DataPtr> Execute(
 std::vector<at::Tensor> Fetch(
     absl::Span<const torch_xla::runtime::ComputationClient::DataPtr>
         device_data) {
-  std::vector<xla::Literal> literals =
-      torch_xla::runtime::GetComputationClientOrDie()->TransferFromDevice(
-          device_data);
+  std::vector<xla::Literal> literals = GetValueOrThrow(
+      runtime::GetComputationClientOrDie()->TransferFromDevice(device_data));
   std::vector<at::Tensor> tensors;
   for (auto& literal : literals) {
     tensors.push_back(MakeTensorFromXlaLiteral(
5 changes: 2 additions & 3 deletions test/cpp/test_replication.cpp
@@ -79,9 +79,8 @@ void TestSingleReplication(
   counter.Wait();

   for (size_t i = 0; i < results.size(); ++i) {
-    std::vector<xla::Literal> literals =
-        torch_xla::runtime::GetComputationClientOrDie()->TransferFromDevice(
-            results[i]);
+    std::vector<xla::Literal> literals = GetValueOrThrow(
+        runtime::GetComputationClientOrDie()->TransferFromDevice(results[i]));
     ASSERT_EQ(literals.size(), 1);

     // The result must be the original tensor value, multiplied by the number of
14 changes: 14 additions & 0 deletions test/test_operations.py
@@ -88,6 +88,10 @@ def skipIfFunctionalizationDisabled(reason):
   return _skipIfFunctionalization(value=True, reason=reason)


+def onlyOnCPU(fn):
+  accelerator = os.environ.get("PJRT_DEVICE").lower()
+  return unittest.skipIf(accelerator != "cpu", "PJRT_DEVICE=CPU required")(fn)
+
 def onlyOnCUDA(fn):
   accelerator = os.environ.get("PJRT_DEVICE").lower()
   return unittest.skipIf(accelerator != "cuda", "PJRT_DEVICE=CUDA required")(fn)

@@ -2458,6 +2462,16 @@ def test_add_broadcast_error(self):
       torch.add(a, b)
       torch_xla.sync()

+  @onlyOnCPU
+  def test_construct_large_tensor_raises_error(self):
+    with self.assertRaisesRegex(RuntimeError,
+                                r"Out of memory allocating \d+ bytes"):
+      # When eager mode is enabled, the OOM is triggered here.
+      a = torch.rand(1024, 1024, 1024, 1024, 1024, device=torch_xla.device())
+      b = a.sum()
+      # Otherwise, the OOM is raised when we bring the data back from the device.
+      b.cpu()
+

 class MNISTComparator(nn.Module):
4 changes: 2 additions & 2 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -1229,9 +1229,9 @@ class PyLoweringContext {
         lowering_ctx.GetParametersData();

     // Fetch this parameter data
-    std::vector<xla::Literal> literals =
+    std::vector<xla::Literal> literals = GetValueOrThrow(
         runtime::GetComputationClientOrDie()->TransferFromDevice(
-            UnwrapXlaData(device_data));
+            UnwrapXlaData(device_data)));

     // Create a mapping from parameter id to the tensor data
     std::unordered_map<int64_t, at::Tensor> results;
1 change: 1 addition & 0 deletions torch_xla/csrc/runtime/BUILD
@@ -123,6 +123,7 @@ cc_library(
         ":tf_logging",
         ":xla_coordinator",
         "//torch_xla/csrc:status",
+        "@com_google_absl//absl/log:absl_check",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/synchronization",
         "@com_google_absl//absl/types:span",
2 changes: 1 addition & 1 deletion torch_xla/csrc/runtime/computation_client.h
@@ -318,7 +318,7 @@ class ComputationClient {
   // Note: `TransferFromDevice` call will block until the `DataPtrs` are ready
   // if they were created by `TransferToDevice` or `Execute*`. Calling this from
   // python while holding the GIL can cause deadlocks!
-  virtual std::vector<xla::Literal> TransferFromDevice(
+  virtual absl::StatusOr<std::vector<xla::Literal>> TransferFromDevice(
       absl::Span<const DataPtr> handles) = 0;

   virtual std::uintptr_t UnsafeBufferPointer(const DataPtr handle) = 0;
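Since the call blocks, Python-side callers drop the GIL around it. A minimal sketch of that pattern, mirroring ReleaseGilAndTransferData in the torch_xla/csrc/tensor_util.cpp diff further down:

// Release the GIL before the blocking device-to-host transfer, restore after.
PyThreadState* save = PyEval_SaveThread();
std::vector<xla::Literal> literals = GetValueOrThrow(
    runtime::GetComputationClientOrDie()->TransferFromDevice(handles));
PyEval_RestoreThread(save);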
10 changes: 5 additions & 5 deletions torch_xla/csrc/runtime/ifrt_computation_client.cpp
@@ -436,8 +436,8 @@ std::shared_ptr<xla::PjRtBuffer> IfrtComputationClient::GetPjRtBuffer(
   XLA_ERROR() << __FUNCTION__ << " not implemented";
 }

-std::vector<xla::Literal> IfrtComputationClient::TransferFromDevice(
-    absl::Span<const DataPtr> handles) {
+absl::StatusOr<std::vector<xla::Literal>>
+IfrtComputationClient::TransferFromDevice(absl::Span<const DataPtr> handles) {
   metrics::TimedSection timed(TransferFromDeviceMetric());
   tsl::profiler::TraceMe activity("IfrtComputationClient::TransferFromDevice",
                                   tsl::profiler::TraceMeLevel::kInfo);

@@ -455,9 +455,9 @@ std::vector<xla::Literal> IfrtComputationClient::TransferFromDevice(
     auto& literal = literals.emplace_back(
         xla::ShapeUtil::DeviceShapeToHostShape(ifrt_data->shape()));
     std::vector<int64_t> byte_strides(literal.shape().dimensions_size());
-    XLA_CHECK_OK(xla::ShapeUtil::ByteStrides(literal.shape(),
-                                             absl::MakeSpan(byte_strides)));
-    XLA_CHECK_OK(
+    XLA_RETURN_IF_ERROR(xla::ShapeUtil::ByteStrides(
+        literal.shape(), absl::MakeSpan(byte_strides)));
+    XLA_RETURN_IF_ERROR(
         replicated_array
             ->CopyToHostBuffer(literal.untyped_data(), byte_strides,
                                xla::ifrt::ArrayCopySemantics::kAlwaysCopy)
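The substitution above is the heart of the refactor: XLA_CHECK_OK aborts the process on a non-OK status, whereas XLA_RETURN_IF_ERROR hands the status back to the caller of the (now StatusOr-returning) function. Approximately, as a sketch rather than the macro's real definition:

// XLA_RETURN_IF_ERROR(expr) behaves roughly like:
{
  absl::Status s = (expr);
  if (!s.ok()) return s;  // becomes TransferFromDevice's return value
}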
2 changes: 1 addition & 1 deletion torch_xla/csrc/runtime/ifrt_computation_client.h
@@ -62,7 +62,7 @@ class IfrtComputationClient : public ComputationClient {
     XLA_ERROR() << __FUNCTION__ << " not implemented";
   }

-  std::vector<xla::Literal> TransferFromDevice(
+  absl::StatusOr<std::vector<xla::Literal>> TransferFromDevice(
       absl::Span<const DataPtr> handles) override;

   std::uintptr_t UnsafeBufferPointer(const DataPtr handle) override;
2 changes: 1 addition & 1 deletion torch_xla/csrc/runtime/ifrt_computation_client_test.cpp
@@ -70,7 +70,7 @@ TEST(PjRtComputationClientTest, Init) {

   // Copy the output from device back to host and assert correctness.
   ASSERT_EQ(results.size(), 1);
-  auto result_literals = client->TransferFromDevice(results);
+  auto result_literals = GetValueOrThrow(client->TransferFromDevice(results));
   ASSERT_THAT(result_literals, ::testing::SizeIs(1));
   EXPECT_TRUE(xla::LiteralTestUtil::Equal(
       xla::LiteralUtil::CreateR2<float>({{6.0f, 8.0f}, {10.0f, 12.0f}}),
31 changes: 12 additions & 19 deletions torch_xla/csrc/runtime/pjrt_computation_client.cpp
@@ -4,6 +4,7 @@
 #include <stdexcept>
 #include <vector>

+#include "absl/log/absl_check.h"
 #include "absl/strings/ascii.h"
 #include "absl/synchronization/blocking_counter.h"
 #include "absl/types/span.h"

@@ -508,8 +509,8 @@ std::shared_ptr<xla::PjRtBuffer> PjRtComputationClient::GetPjRtBuffer(
   }
 }

-std::vector<xla::Literal> PjRtComputationClient::TransferFromDevice(
-    absl::Span<const DataPtr> handles) {
+absl::StatusOr<std::vector<xla::Literal>>
+PjRtComputationClient::TransferFromDevice(absl::Span<const DataPtr> handles) {
   metrics::TimedSection timed(TransferFromDeviceMetric());
   tsl::profiler::TraceMe activity("PjRtComputationClient::TransferFromDevice",
                                   tsl::profiler::TraceMeLevel::kInfo);

@@ -522,21 +523,17 @@ std::vector<xla::Literal> PjRtComputationClient::TransferFromDevice(
     // Use XLA replication to reassemble the sharded data. If input handle
     // is not sharded, then it is a no-op.
     std::shared_ptr<PjRtData> pjrt_data = ReplicateShardedData(handle);
-    XLA_CHECK(pjrt_data) << "PjRt_data is null in " << __FUNCTION__;
-    XLA_CHECK(pjrt_data->buffer != nullptr)
+    ABSL_CHECK(pjrt_data) << "PjRt_data is null in " << __FUNCTION__;
+    ABSL_CHECK(pjrt_data->buffer != nullptr)
         << "PjRt buffer is null in " << __FUNCTION__;

-    xla::Literal& literal =
-        literals.emplace_back(host_output_shape(pjrt_data->buffer.get()));
+    xla::Literal& literal = literals.emplace_back(
+        xla::Literal(host_output_shape(pjrt_data->buffer.get())));
     futures.push_back(pjrt_data->buffer->ToLiteral(&literal));

     total_size += literal.size_bytes();
   }
-  for (auto& future : futures) {
-    absl::Status status = future.Await();
-    XLA_CHECK_OK(status) << "Failed to await future from buffer to literal in"
-                         << __FUNCTION__;
-  }
+  XLA_RETURN_IF_ERROR(xla::JoinFutures(futures).Await());
   InboundDataMetric()->AddSample(total_size);

   return literals;

@@ -773,10 +770,8 @@ PjRtComputationClient::ExecuteComputation(

   std::optional<xla::PjRtFuture<>> returned_future;
   std::vector<std::unique_ptr<xla::PjRtBuffer>> results =
-      pjrt_computation.executable
-          ->ExecuteSharded(buffers, pjrt_device, execute_options,
-                           returned_future)
-          .value();
+      GetValueOrThrow(pjrt_computation.executable->ExecuteSharded(
+          buffers, pjrt_device, execute_options, returned_future));

   returned_future->OnReady(std::move(
       [timed, op_tracker = std::move(op_tracker)](absl::Status unused) mutable {

@@ -878,10 +873,8 @@ PjRtComputationClient::ExecuteReplicated(
   tsl::profiler::TraceMe activity(
       "PjRtComputationClient::ExecuteReplicated_execute",
       tsl::profiler::TraceMeLevel::kInfo);
-  results = pjrt_computation.executable
-                ->Execute(std::move(argument_handles), execute_options,
-                          returned_futures)
-                .value();
+  results = GetValueOrThrow(pjrt_computation.executable->Execute(
+      std::move(argument_handles), execute_options, returned_futures));

   (*returned_futures)[0].OnReady(
       std::move([timed, op_tracker = std::move(op_tracker)](
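Two details worth flagging here. First, the old loop awaited each transfer future and crashed on the first non-OK status; the new code joins all per-buffer futures into one and propagates the first error instead:

// Wait for every device-to-host copy at once; fail fast on the first error.
XLA_RETURN_IF_ERROR(xla::JoinFutures(futures).Await());

Second, the null checks on pjrt_data move from XLA_CHECK to ABSL_CHECK (hence the new absl_check dependency in the BUILD diff above); a null handle here is presumably an internal invariant violation rather than a recoverable runtime failure, so it still aborts rather than returning a status.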
2 changes: 1 addition & 1 deletion torch_xla/csrc/runtime/pjrt_computation_client.h
@@ -65,7 +65,7 @@ class PjRtComputationClient : public ComputationClient {
       absl::Span<const DataPtr> handles,
       absl::Span<const xla::OpSharding> shardings) override;

-  std::vector<xla::Literal> TransferFromDevice(
+  absl::StatusOr<std::vector<xla::Literal>> TransferFromDevice(
       absl::Span<const DataPtr> handles) override;

   std::uintptr_t UnsafeBufferPointer(const DataPtr handle) override;
2 changes: 1 addition & 1 deletion torch_xla/csrc/runtime/pjrt_computation_client_test.cpp
@@ -120,7 +120,7 @@ TEST_F(PjRtComputationClientTest, Init) {

   // Copy the output from device back to host and assert correctness.
   ASSERT_EQ(results.size(), 1);
-  auto result_literals = client_->TransferFromDevice(results);
+  auto result_literals = GetValueOrThrow(client_->TransferFromDevice(results));
   ASSERT_THAT(result_literals, ::testing::SizeIs(1));
   EXPECT_TRUE(xla::LiteralTestUtil::Equal(
       xla::LiteralUtil::CreateR2<float>({{6.0f, 8.0f}, {10.0f, 12.0f}}),
5 changes: 3 additions & 2 deletions torch_xla/csrc/tensor_util.cpp
@@ -24,6 +24,7 @@
 #include "torch_xla/csrc/runtime/sys_util.h"
 #include "torch_xla/csrc/runtime/tf_logging.h"
 #include "torch_xla/csrc/runtime/util.h"
+#include "torch_xla/csrc/status.h"
 #include "torch_xla/csrc/thread_pool.h"
 #include "torch_xla/csrc/torch_util.h"
 #include "torch_xla/csrc/xla_backend_impl.h"

@@ -909,8 +910,8 @@ std::vector<xla::Literal> ReleaseGilAndTransferData(
     save = PyEval_SaveThread();
   }
   std::vector<xla::Literal> literals =
-      runtime::GetComputationClientOrDie()->TransferFromDevice(
-          UnwrapXlaData(xla_data));
+      GetValueOrThrow(runtime::GetComputationClientOrDie()->TransferFromDevice(
+          UnwrapXlaData(xla_data)));
   if (save) {
     PyEval_RestoreThread(save);
   }