diff --git a/test/cpp/test_xla_sharding.cpp b/test/cpp/test_xla_sharding.cpp
index 3d276f2dc26..b179c6e523c 100644
--- a/test/cpp/test_xla_sharding.cpp
+++ b/test/cpp/test_xla_sharding.cpp
@@ -29,7 +29,7 @@ bool XlaDataValuesEqual(torch::lazy::BackendDataPtr a,
                         torch::lazy::BackendDataPtr b,
                         at::ScalarType element_type) {
   std::vector<at::Tensor> tensors =
-      XlaDataToTensors({a, b}, {element_type, element_type});
+      GetValueOrThrow(XlaDataToTensors({a, b}, {element_type, element_type}));
   return TensorCompare(tensors[0], tensors[1]);
 }
 }  // namespace
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index 5b62d95efd5..8873fb434e0 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -2712,7 +2712,7 @@ void InitXlaModuleBindings(py::module m) {
         }
 
         std::vector<at::Tensor> cpu_shards =
-            XlaDataToTensors(WrapXlaData(handles), element_types);
+            GetValueOrThrow(XlaDataToTensors(WrapXlaData(handles), element_types));
         // Populate the resulting vector of shards and device strings
         std::vector<std::vector<std::pair<at::Tensor, std::string>>> result;
         int shards_per_tensor =
diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp
index 1a1a7737ccf..6459293a87f 100644
--- a/torch_xla/csrc/tensor.cpp
+++ b/torch_xla/csrc/tensor.cpp
@@ -40,6 +40,7 @@
 #include "torch_xla/csrc/runtime/pjrt_computation_client.h"
 #include "torch_xla/csrc/runtime/sys_util.h"
 #include "torch_xla/csrc/runtime/xla_util.h"
+#include "torch_xla/csrc/status.h"
 #include "torch_xla/csrc/tensor_util.h"
 #include "torch_xla/csrc/torch_util.h"
 #include "torch_xla/csrc/xla_graph_executor.h"
@@ -512,7 +513,7 @@ at::Tensor XLATensor::ToTensor(bool detached) {
     // The GetXlaData() call will trigger an ApplyPendingGraph() if an IR
     // XlaNode is available on the tensor.
     std::vector<at::Tensor> tensors =
-        XlaDataToTensors({GetXlaData()}, {dtype()});
+        GetValueOrThrow(XlaDataToTensors({GetXlaData()}, {dtype()}));
     tensor = std::move(tensors.front());
     if (!detached) {
       SetTensorData(tensor);
diff --git a/torch_xla/csrc/tensor_util.cpp b/torch_xla/csrc/tensor_util.cpp
index e2cd3a025f5..26c669b1e4f 100644
--- a/torch_xla/csrc/tensor_util.cpp
+++ b/torch_xla/csrc/tensor_util.cpp
@@ -896,7 +896,7 @@ xla::Literal GetTensorLiteral(const at::Tensor& tensor, const xla::Shape* shape,
   return literal;
 }
 
-std::vector<xla::Literal> ReleaseGilAndTransferData(
+absl::StatusOr<std::vector<xla::Literal>> ReleaseGilAndTransferData(
     absl::Span<const torch::lazy::BackendDataPtr> xla_data) {
   // HACK: This method may be called outside of python (mainly in C++ tests) or
   // when the GIL is already released, so we must check both cases here. If
@@ -909,9 +909,12 @@ std::vector<xla::Literal> ReleaseGilAndTransferData(
   if (release_gil && Py_IsInitialized() && PyGILState_Check()) {
     save = PyEval_SaveThread();
   }
-  std::vector<xla::Literal> literals =
-      GetValueOrThrow(runtime::GetComputationClientOrDie()->TransferFromDevice(
-          UnwrapXlaData(xla_data)));
+
+  XLA_ASSIGN_OR_RETURN(runtime::ComputationClient * client,
+                       runtime::GetComputationClient());
+  XLA_ASSIGN_OR_RETURN(std::vector<xla::Literal> literals,
+                       client->TransferFromDevice(UnwrapXlaData(xla_data)));
+
   if (save) {
     PyEval_RestoreThread(save);
   }
@@ -919,10 +922,11 @@ std::vector<xla::Literal> ReleaseGilAndTransferData(
   return literals;
 }
 
-std::vector<at::Tensor> XlaDataToTensors(
+absl::StatusOr<std::vector<at::Tensor>> XlaDataToTensors(
    absl::Span<const torch::lazy::BackendDataPtr> xla_data,
    absl::Span<const at::ScalarType> dest_element_type) {
-  std::vector<xla::Literal> literals = ReleaseGilAndTransferData(xla_data);
+  XLA_ASSIGN_OR_RETURN(std::vector<xla::Literal> literals,
+                       ReleaseGilAndTransferData(xla_data));
   std::vector<at::Tensor> tensors(literals.size());
   absl::BlockingCounter counter(literals.size());
   for (size_t i = 0; i < tensors.size(); ++i) {
diff --git a/torch_xla/csrc/tensor_util.h b/torch_xla/csrc/tensor_util.h
index 0804d3e9f78..a0f6dea480f 100644
--- a/torch_xla/csrc/tensor_util.h
+++ b/torch_xla/csrc/tensor_util.h
@@ -28,11 +28,11 @@ at::Tensor MakeTensorFromXlaLiteral(const xla::Literal& literal,
 // Execution and data transfer are async in PJRT, so TransferFromDevice may
 // block until `DataPtr`s are ready. Release the GIL so other threads can
 // proceed and unblock any transfers or collective computations.
-std::vector<xla::Literal> ReleaseGilAndTransferData(
+absl::StatusOr<std::vector<xla::Literal>> ReleaseGilAndTransferData(
     absl::Span<const torch::lazy::BackendDataPtr> xla_data);
 
 // TODO LTC @wonjoo - Migrate to upstream after Device -> BackendDevice
-std::vector<at::Tensor> XlaDataToTensors(
+absl::StatusOr<std::vector<at::Tensor>> XlaDataToTensors(
     absl::Span<const torch::lazy::BackendDataPtr> xla_data,
     absl::Span<const at::ScalarType> dest_element_type);
 
diff --git a/torch_xla/csrc/xla_backend_impl.cpp b/torch_xla/csrc/xla_backend_impl.cpp
index bf130e1fab7..df52770b11e 100644
--- a/torch_xla/csrc/xla_backend_impl.cpp
+++ b/torch_xla/csrc/xla_backend_impl.cpp
@@ -10,6 +10,8 @@
 #include "torch_xla/csrc/runtime/computation_client.h"
 #include "torch_xla/csrc/runtime/debug_macros.h"
 #include "torch_xla/csrc/runtime/runtime.h"
+#include "torch_xla/csrc/status.h"
+#include "torch_xla/csrc/tensor_util.h"
 
 namespace at {
 // This function is defined in the codegenerated RegisterDispatchKey.cpp file.
@@ -92,7 +94,7 @@ class XlaBackendImpl : public torch::lazy::BackendImplInterface {
       const torch::lazy::BackendDataPtr data,
       std::optional<at::ScalarType> logical_scalar_type) const override {
     // TODO(JackCaoG): handle the logical_scalar_type == nullptr case
-    return XlaDataToTensors({data}, {*logical_scalar_type})[0];
+    return GetValueOrThrow(XlaDataToTensors({data}, {*logical_scalar_type}))[0];
   }
 
   std::unique_ptr<torch::lazy::LoweringContext> CreateLoweringContext(
diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp
index 65eee78bc02..0931578047e 100644
--- a/torch_xla/csrc/xla_graph_executor.cpp
+++ b/torch_xla/csrc/xla_graph_executor.cpp
@@ -497,7 +497,8 @@ std::vector<at::Tensor> XLAGraphExecutor::GetTensors(
       async != nullptr ? async->tensors_data
                        : absl::Span<const torch::lazy::BackendDataPtr>());
 
-  std::vector<xla::Literal> literals = ReleaseGilAndTransferData(tensors_data);
+  std::vector<xla::Literal> literals =
+      GetValueOrThrow(ReleaseGilAndTransferData(tensors_data));
   return FetchTensors(tensors, literals,
                       async != nullptr ? &async->indices : nullptr);
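
Note (not part of the patch): the diff above changes ReleaseGilAndTransferData and XlaDataToTensors to return absl::StatusOr, and callers either propagate the status with XLA_ASSIGN_OR_RETURN or unwrap it at an API boundary with GetValueOrThrow. The sketch below is a minimal illustration of those two calling styles using plain absl::StatusOr only; the names TransferOne, PropagatingCaller, and ThrowingCaller are hypothetical stand-ins, not torch_xla APIs.

// Minimal sketch of StatusOr propagation vs. unwrap-and-throw, assuming only Abseil.
#include <cstddef>
#include <stdexcept>
#include <string>
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"

// Fallible producer, playing the role TransferFromDevice has in the patch.
absl::StatusOr<std::vector<int>> TransferOne(bool succeed) {
  if (!succeed) return absl::InternalError("device transfer failed");
  return std::vector<int>{1, 2, 3};
}

// Propagating caller: roughly what XLA_ASSIGN_OR_RETURN does, written out by
// hand -- on error, return the Status to the caller instead of crashing.
absl::StatusOr<size_t> PropagatingCaller(bool succeed) {
  absl::StatusOr<std::vector<int>> data = TransferOne(succeed);
  if (!data.ok()) return data.status();
  return data->size();
}

// Throwing caller: the role GetValueOrThrow plays at boundaries (C++ tests,
// Python bindings) that cannot themselves return an absl::Status.
int ThrowingCaller(bool succeed) {
  absl::StatusOr<std::vector<int>> data = TransferOne(succeed);
  if (!data.ok()) {
    throw std::runtime_error(std::string(data.status().message()));
  }
  return data->front();
}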