
Commit 1e09ef1

Replace GetValueOrThrow with status propagation in ReleaseGilAndTransferData
Modify the `ReleaseGilAndTransferData` function to use proper status propagation instead of `GetValueOrThrow` with `GetComputationClientOrDie`. This improves error handling by allowing status types to be propagated up the call stack rather than immediately throwing exceptions.

Changes:

- Update the function signature to return `absl::StatusOr<std::vector<xla::Literal>>`
- Replace `GetComputationClientOrDie()` with `GetComputationClient()`
- Use `XLA_ASSIGN_OR_RETURN` macros for both client acquisition and `TransferFromDevice`
- Update callers in tensor_util.cpp and xla_graph_executor.cpp to handle `StatusOr<T>`

This follows the status propagation patterns used elsewhere in the codebase and aligns with the examples in pjrt_registry.cpp.
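For readers unfamiliar with the idiom: status propagation with `absl::StatusOr<T>` means fallible layers return a status rather than throwing, and each caller either forwards the error or unwraps the value at the boundary. A minimal sketch of the shape of the change follows; the names `LoadValues` and `SumValues` are hypothetical illustrations, not torch_xla APIs, and the real code uses the `XLA_ASSIGN_OR_RETURN` macro rather than explicit `ok()` checks.

#include <string>
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"

// Hypothetical leaf operation that can fail.
absl::StatusOr<std::vector<int>> LoadValues(const std::string& path) {
  if (path.empty()) {
    return absl::InvalidArgumentError("empty path");
  }
  return std::vector<int>{1, 2, 3};
}

// Intermediate layer: forward the error instead of throwing, mirroring the
// new ReleaseGilAndTransferData signature.
absl::StatusOr<int> SumValues(const std::string& path) {
  absl::StatusOr<std::vector<int>> values = LoadValues(path);
  if (!values.ok()) {
    return values.status();  // propagate the failure up the call stack
  }
  int sum = 0;
  for (int v : *values) sum += v;
  return sum;
}

Top-level callers that must produce a plain value (as `XlaDataToTensors` and `XLAGraphExecutor::GetTensors` do here) unwrap the result once, at the boundary.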
Parent: 048459c

3 files changed, 12 insertions(+), 7 deletions(-)


torch_xla/csrc/tensor_util.cpp

Lines changed: 9 additions & 5 deletions
@@ -896,7 +896,7 @@ xla::Literal GetTensorLiteral(const at::Tensor& tensor, const xla::Shape* shape,
   return literal;
 }
 
-std::vector<xla::Literal> ReleaseGilAndTransferData(
+absl::StatusOr<std::vector<xla::Literal>> ReleaseGilAndTransferData(
     absl::Span<const torch::lazy::BackendDataPtr> xla_data) {
   // HACK: This method may be called outside of python (mainly in C++ tests) or
   // when the GIL is already released, so we must check both cases here. If
@@ -909,9 +909,12 @@ std::vector<xla::Literal> ReleaseGilAndTransferData(
   if (release_gil && Py_IsInitialized() && PyGILState_Check()) {
     save = PyEval_SaveThread();
   }
-  std::vector<xla::Literal> literals =
-      GetValueOrThrow(runtime::GetComputationClientOrDie()->TransferFromDevice(
-          UnwrapXlaData(xla_data)));
+
+  XLA_ASSIGN_OR_RETURN(runtime::ComputationClient * client,
+                       runtime::GetComputationClient());
+  XLA_ASSIGN_OR_RETURN(std::vector<xla::Literal> literals,
+                       client->TransferFromDevice(UnwrapXlaData(xla_data)));
+
   if (save) {
     PyEval_RestoreThread(save);
   }
@@ -922,7 +925,8 @@ std::vector<xla::Literal> ReleaseGilAndTransferData(
 std::vector<at::Tensor> XlaDataToTensors(
     absl::Span<const torch::lazy::BackendDataPtr> xla_data,
     absl::Span<const at::ScalarType> dest_element_type) {
-  std::vector<xla::Literal> literals = ReleaseGilAndTransferData(xla_data);
+  std::vector<xla::Literal> literals =
+      GetValueOrThrow(ReleaseGilAndTransferData(xla_data));
   std::vector<at::Tensor> tensors(literals.size());
   absl::BlockingCounter counter(literals.size());
   for (size_t i = 0; i < tensors.size(); ++i) {
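Note that the GIL save/restore around the transfer is unchanged by this commit and follows CPython's standard idiom. A standalone sketch of that idiom, with a hypothetical `DoBlockingWork` standing in for the device transfer:

#include <Python.h>

// Hypothetical blocking call (e.g. a device-to-host transfer).
void DoBlockingWork();

void RunWithoutGil() {
  PyThreadState* save = nullptr;
  // Release the GIL only if Python is initialized and this thread holds it,
  // matching the release_gil check in ReleaseGilAndTransferData.
  if (Py_IsInitialized() && PyGILState_Check()) {
    save = PyEval_SaveThread();
  }
  DoBlockingWork();
  if (save != nullptr) {
    PyEval_RestoreThread(save);  // reacquire the GIL
  }
}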

torch_xla/csrc/tensor_util.h

Lines changed: 1 addition & 1 deletion
@@ -28,7 +28,7 @@ at::Tensor MakeTensorFromXlaLiteral(const xla::Literal& literal,
 // Execution and data transfer are async in PJRT, so TransferFromDevice may
 // block until `DataPtr`s are ready. Release the GIL so other threads can
 // proceed and unblock any transfers or collective computations.
-std::vector<xla::Literal> ReleaseGilAndTransferData(
+absl::StatusOr<std::vector<xla::Literal>> ReleaseGilAndTransferData(
     absl::Span<const torch::lazy::BackendDataPtr> xla_data);
 
 // TODO LTC @wonjoo - Migrate to upstream after Device -> BackendDevice

torch_xla/csrc/xla_graph_executor.cpp

Lines changed: 2 additions & 1 deletion
@@ -496,7 +496,8 @@ std::vector<at::Tensor> XLAGraphExecutor::GetTensors(
       async != nullptr ? async->tensors_data
                        : absl::Span<const torch::lazy::BackendDataPtr>());
 
-  std::vector<xla::Literal> literals = ReleaseGilAndTransferData(tensors_data);
+  std::vector<xla::Literal> literals =
+      GetValueOrThrow(ReleaseGilAndTransferData(tensors_data));
 
   return FetchTensors(tensors, literals,
                       async != nullptr ? &async->indices : nullptr);
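At these outer callers the status is unwrapped with `GetValueOrThrow`, so exceptions stay confined to the API boundary while the inner layers propagate `absl::Status`. Its implementation is not part of this diff; a plausible shape, purely for illustration:

#include <stdexcept>
#include <string>
#include <utility>

#include "absl/status/statusor.h"

// Illustrative only: the real torch_xla GetValueOrThrow may differ, e.g. by
// surfacing the error through the Python bindings instead of a C++ exception.
template <typename T>
T GetValueOrThrowSketch(absl::StatusOr<T> value_or) {
  if (!value_or.ok()) {
    throw std::runtime_error(std::string(value_or.status().message()));
  }
  return std::move(value_or).value();
}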
