
Error Handling: refactor ComputationClient::TransferFromDevice to propagate status. #9429


Open: wants to merge 5 commits into master from ysiraichi/status-for-oom-errors

Conversation

@ysiraichi (Collaborator) commented Jul 1, 2025

This PR makes two main changes in order to standardize and improve error handling:

  • ComputationClient::TransferFromDevice returns a StatusOr<T> instance
  • Wrap xla::PjRtLoadedExecutable::Execute(Sharded) with GetValueOrThrow
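The status-propagation pattern behind both changes can be sketched in Python (hypothetical names; the actual implementation is C++ and uses XLA's `StatusOr` together with a `GetValueOrThrow` helper):

```python
# Hypothetical sketch, NOT the actual C++ API: a StatusOr-style result
# carries either a value or an error status, so callers decide how to
# react instead of the process crashing on failure.

class StatusOr:
    """Holds either a value or an error message, never both."""

    def __init__(self, value=None, error=None):
        self._value = value
        self._error = error

    def ok(self):
        return self._error is None

    def value_or_throw(self):
        # Mirrors GetValueOrThrow: turn an error status into an
        # exception, which Python bindings surface as RuntimeError.
        if self._error is not None:
            raise RuntimeError(self._error)
        return self._value


def transfer_from_device(fits_in_memory):
    # Stand-in for ComputationClient::TransferFromDevice: instead of
    # aborting on OOM, it now returns a StatusOr the caller can inspect.
    if not fits_in_memory:
        return StatusOr(error="Error preparing computation: Out of memory")
    return StatusOr(value=[1, 2, 3])
```

Callers that want the old throwing behavior simply wrap the result in `value_or_throw()`, which is what the binding layer does before the error reaches Python.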

These changes mainly affect the errors reported when an OOM occurs; the second change specifically targets eager mode. As an example, the following is the result of running the script below (without eager mode):

import torch
import torch_xla

device = torch_xla.device()
a = torch.rand(1024, 1024, 1024, 1024, 1024, device=device)
b = a.sum()
print(b)

Before this PR:

F0000 00:00:1751368998.014824    2835 pjrt_computation_client.cpp:525] Non-OK-status: status
Status: INTERNAL: Error preparing computation: Out of memory allocating 4503599761588224 bytes.
*** Begin stack trace ***
        tsl::CurrentStackTrace[abi:cxx11]()
        torch_xla::runtime::PjRtComputationClient::TransferFromDevice(absl::lts_20230802::Span<std::shared_ptr<torch_xla::runtime::ComputationClient::Data> const>)
        ...
        _start
*** End stack trace ***
Failed to await future from buffer to literal inTransferFromDevice
*** Check failure stack trace: ***
    @     0x7ddc923438f9  absl::lts_20230802::log_internal::LogMessage::PrepareToDie()
    ...
Aborted (core dumped)

After this PR:

Traceback (most recent call last):
  File "examples/mem.py", line 11, in <module>
    print(b)
  File "torch/_tensor.py", line 590, in __repr__
    return torch._tensor_str._str(self, tensor_contents=tensor_contents)
  File "torch/_tensor_str.py", line 726, in _str
    return _str_intern(self, tensor_contents=tensor_contents)
  File "torch/_tensor_str.py", line 462, in _str_intern
    self = self.to("cpu")
RuntimeError: Error preparing computation: Out of memory allocating 4503599627370496 bytes.

And with XLA_SHOW_CPP_ERROR_CONTEXT=1 set:

RuntimeError: Error preparing computation: Out of memory allocating 4503599627370496 bytes. (at torch_xla/csrc/runtime/pjrt_computation_client.cpp:524)
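Because the failure now surfaces as an ordinary RuntimeError instead of a process abort, user code can catch it and recover. A minimal sketch (`materialize` is a hypothetical stand-in for the `tensor.to("cpu")` call above, mimicking the new error message):

```python
# Sketch of recovering from a device OOM now that it raises instead of
# core-dumping. `materialize` is a stand-in, not a torch_xla API.

def materialize(tensor_too_large):
    # Stand-in for `tensor.to("cpu")`, which with this PR raises a
    # RuntimeError on OOM rather than aborting the process.
    if tensor_too_large:
        raise RuntimeError(
            "Error preparing computation: Out of memory allocating "
            "4503599627370496 bytes.")
    return "cpu tensor"


def materialize_or_none(tensor_too_large):
    try:
        return materialize(tensor_too_large)
    except RuntimeError as e:
        if "Out of memory" in str(e):
            return None  # caller can retry with a smaller tensor
        raise  # unrelated errors still propagate
```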


@ysiraichi ysiraichi force-pushed the ysiraichi/status-for-oom-errors branch 3 times, most recently from cef8c1e to 247fdf5 Compare July 1, 2025 16:29
@ysiraichi ysiraichi force-pushed the ysiraichi/status-for-oom-errors branch 2 times, most recently from b390a61 to 821c384 Compare July 1, 2025 18:15
@ysiraichi ysiraichi changed the base branch from ysiraichi/status-qol-functions to master July 1, 2025 18:16
@ysiraichi ysiraichi marked this pull request as ready for review July 1, 2025 18:20
@zhanyong-wan (Collaborator) left a comment:
Very nice!

Can you add a python test to ensure that OOM does result in a python exception with the expected error message as opposed to crashing?


@ysiraichi ysiraichi force-pushed the ysiraichi/status-for-oom-errors branch from 048459c to a370c5d Compare July 17, 2025 23:13
@ysiraichi ysiraichi marked this pull request as ready for review July 19, 2025 18:06
@ysiraichi ysiraichi requested a review from zhanyong-wan July 19, 2025 18:06
@zhanyong-wan (Collaborator) left a comment:
Great!

@@ -2458,6 +2458,15 @@ def test_add_broadcast_error(self):
       torch.add(a, b)
       torch_xla.sync()

+  def test_construct_large_tensor_raises_error(self):
+    with self.assertRaisesRegex(RuntimeError,
+                                r"Out of memory allocating \d* bytes"):
Review comment on this line (Collaborator):

`*` => `+` (i.e., use `\d+` so the regex requires at least one digit in the byte count)

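Applying the reviewer's suggestion, the pattern becomes `\d+`, which rejects a message with an empty byte count. The full unittest needs a real XLA device, so the sketch below only exercises the regex itself, against the message shown in the PR description:

```python
# Check the suggested OOM regex (`\d+` instead of `\d*`) against the
# error message format shown in this PR. Pure-regex sketch; the actual
# test runs under torch_xla on a real device.
import re

OOM_PATTERN = r"Out of memory allocating \d+ bytes"

sample = ("Error preparing computation: Out of memory allocating "
          "4503599627370496 bytes.")

# `\d+` matches the real message...
print(bool(re.search(OOM_PATTERN, sample)))  # True

# ...but, unlike `\d*`, refuses a message with no digits at all:
no_digits = "Out of memory allocating  bytes"
print(bool(re.search(OOM_PATTERN, no_digits)))                              # False
print(bool(re.search(r"Out of memory allocating \d* bytes", no_digits)))    # True
```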