diff --git a/test/cpp/cpp_test_util.cpp b/test/cpp/cpp_test_util.cpp
index afe573101eb..6731f5800fd 100644
--- a/test/cpp/cpp_test_util.cpp
+++ b/test/cpp/cpp_test_util.cpp
@@ -295,9 +295,11 @@ std::vector<torch_xla::runtime::ComputationClient::DataPtr> Execute(
       std::move(instances));

   torch_xla::runtime::ComputationClient::ExecuteComputationOptions options;
-  return torch_xla::runtime::GetComputationClientOrDie()->ExecuteComputation(
-      *computations.front(), UnwrapXlaData(lowering_ctx.GetParametersData()),
-      device.toString(), options);
+  return GetValueOrThrow(
+      torch_xla::runtime::GetComputationClientOrDie()->ExecuteComputation(
+          *computations.front(),
+          UnwrapXlaData(lowering_ctx.GetParametersData()), device.toString(),
+          options));
 }

 std::vector<at::Tensor> Fetch(
diff --git a/test/cpp/test_replication.cpp b/test/cpp/test_replication.cpp
index b565dc44cd0..386f9db3a9a 100644
--- a/test/cpp/test_replication.cpp
+++ b/test/cpp/test_replication.cpp
@@ -65,13 +65,13 @@ void TestSingleReplication(
   torch_xla::runtime::ComputationClient::ExecuteComputationOptions exec_options;
   for (size_t i = 0; i < device_strings.size(); ++i) {
     auto executor = [&, i]() {
-      results[i] =
+      results[i] = GetValueOrThrow(
           torch_xla::runtime::GetComputationClientOrDie()->ExecuteComputation(
               *compiled_computations[i],
               {std::dynamic_pointer_cast<
                   torch_xla::runtime::ComputationClient::Data>(
                   tensors_data[i])},
-              device_strings[i], exec_options);
+              device_strings[i], exec_options));
       counter.DecrementCount();
     };
     torch_xla::thread::Schedule(std::move(executor));
diff --git a/torch_xla/csrc/runtime/computation_client.h b/torch_xla/csrc/runtime/computation_client.h
index c2f9389a4a0..c5b550fb684 100644
--- a/torch_xla/csrc/runtime/computation_client.h
+++ b/torch_xla/csrc/runtime/computation_client.h
@@ -16,6 +16,7 @@
 #include <vector>

 #include "absl/container/flat_hash_map.h"
+#include "absl/status/statusor.h"
 #include "absl/types/optional.h"
 #include "absl/types/span.h"
 #include "torch_xla/csrc/device.h"
@@ -346,7 +347,7 @@ class ComputationClient {
   // The passed device must match the common device of the arguments Data.
   // If options.explode_tuple is true, the output tuple will be decomposed into
   // its single elements.
-  virtual std::vector<DataPtr> ExecuteComputation(
+  virtual absl::StatusOr<std::vector<DataPtr>> ExecuteComputation(
       const Computation& computation, absl::Span<const DataPtr> arguments,
       const std::string& device,
       const ExecuteComputationOptions& options =
@@ -357,7 +358,7 @@ class ComputationClient {
   // as `devices`. If options.explode_tuple is true, the output tuples will be
   // decomposed into their single elements. Returns a vector of outputs, each
   // of which is sharded in the same order as `devices`.
-  virtual std::vector<DataPtr> ExecuteReplicated(
+  virtual absl::StatusOr<std::vector<DataPtr>> ExecuteReplicated(
       const Computation& computation, absl::Span<const DataPtr> arguments,
       absl::Span<const std::string> devices,
       const ExecuteReplicatedOptions& options) = 0;
diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.cpp b/torch_xla/csrc/runtime/ifrt_computation_client.cpp
index f5a6af1b267..5538cb4a5e2 100644
--- a/torch_xla/csrc/runtime/ifrt_computation_client.cpp
+++ b/torch_xla/csrc/runtime/ifrt_computation_client.cpp
@@ -4,6 +4,7 @@
 #include <unordered_set>
 #include <vector>

+#include "absl/log/absl_check.h"
 #include "absl/strings/ascii.h"
 #include "absl/synchronization/blocking_counter.h"
 #include "absl/types/span.h"
@@ -416,8 +417,8 @@ tsl::RCReference<xla::ifrt::Array> IfrtComputationClient::ReplicateShardedData(
   torch_xla::runtime::ComputationClient::ExecuteReplicatedOptions
       execute_options;

-  auto sharded_results = ExecuteReplicated(*computations.front(), {{handle}},
-                                           GetLocalDevices(), execute_options);
+  auto sharded_results = GetValueOrThrow(ExecuteReplicated(
+      *computations.front(), {{handle}}, GetLocalDevices(), execute_options));
   auto replicated_output =
       std::dynamic_pointer_cast<IfrtData>(sharded_results[0])
           ->buffer->FullyReplicatedShard(
@@ -537,16 +538,16 @@ std::vector<ComputationClient::ComputationPtr> IfrtComputationClient::Compile(
   return computations;
 }

-std::vector<ComputationClient::DataPtr>
+absl::StatusOr<std::vector<ComputationClient::DataPtr>>
 IfrtComputationClient::ExecuteComputation(
     const ComputationClient::Computation& computation,
     absl::Span<const ComputationClient::DataPtr> arguments,
     const std::string& device, const ExecuteComputationOptions& options) {
   // TODO: Implement sharded exec in IFRT
-  XLA_ERROR() << __FUNCTION__ << " not implemented";
+  return absl::UnimplementedError("ExecuteComputation not implemented");
 }

-std::vector<ComputationClient::DataPtr>
+absl::StatusOr<std::vector<ComputationClient::DataPtr>>
 IfrtComputationClient::ExecuteReplicated(
     const ComputationClient::Computation& computation,
     const absl::Span<const ComputationClient::DataPtr> arguments,
@@ -591,11 +592,10 @@ IfrtComputationClient::ExecuteReplicated(
   TF_VLOG(5) << "ExecuteReplicated acquiring IFRT device lock for "
             << spmd_device_str << " Done";

-  xla::ifrt::LoadedExecutable::ExecuteResult result =
-      ifrt_computation.executable
-          ->Execute(absl::MakeSpan(argument_handles), execute_options,
-                    std::nullopt)
-          .value();
+  XLA_ASSIGN_OR_RETURN(
+      xla::ifrt::LoadedExecutable::ExecuteResult result,
+      ifrt_computation.executable->Execute(absl::MakeSpan(argument_handles),
+                                           execute_options, std::nullopt));

   result.status.OnReady(std::move([timed, op_tracker = std::move(op_tracker)](
                                       absl::Status status) mutable {
@@ -612,7 +612,7 @@ IfrtComputationClient::ExecuteReplicated(
           ? *ifrt_computation.output_shardings_
           : std::vector<xla::OpSharding>(
                 outputs.size(), xla::HloSharding::Replicate().ToProto());
-  XLA_CHECK_EQ(output_shardings.size(), outputs.size());
+  ABSL_CHECK_EQ(output_shardings.size(), outputs.size());

   std::vector<ComputationClient::DataPtr> data_handles(outputs.size());
   {
diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.h b/torch_xla/csrc/runtime/ifrt_computation_client.h
index 46b6343dc10..ab24d1ae357 100644
--- a/torch_xla/csrc/runtime/ifrt_computation_client.h
+++ b/torch_xla/csrc/runtime/ifrt_computation_client.h
@@ -78,12 +78,12 @@ class IfrtComputationClient : public ComputationClient {
   std::vector<ComputationPtr> Compile(
       std::vector<CompileInstance> instances) override;

-  std::vector<DataPtr> ExecuteComputation(
+  absl::StatusOr<std::vector<DataPtr>> ExecuteComputation(
       const Computation& computation, absl::Span<const DataPtr> arguments,
       const std::string& device,
       const ExecuteComputationOptions& options) override;

-  std::vector<DataPtr> ExecuteReplicated(
+  absl::StatusOr<std::vector<DataPtr>> ExecuteReplicated(
       const Computation& computation,
       const absl::Span<const DataPtr> arguments,
       absl::Span<const std::string> devices,
       const ExecuteReplicatedOptions& options) override;
diff --git a/torch_xla/csrc/runtime/ifrt_computation_client_test.cpp b/torch_xla/csrc/runtime/ifrt_computation_client_test.cpp
index eb39f9b2e23..d48b4337d21 100644
--- a/torch_xla/csrc/runtime/ifrt_computation_client_test.cpp
+++ b/torch_xla/csrc/runtime/ifrt_computation_client_test.cpp
@@ -64,9 +64,10 @@ TEST(PjRtComputationClientTest, Init) {
       std::make_shared<LiteralSource>(std::move(literal_y), device)};

   // Execute the graph.
-  std::vector<ComputationClient::DataPtr> results = client->ExecuteReplicated(
-      *computations[0], client->TransferToDevice(absl::MakeConstSpan(args)),
-      {device}, options);
+  std::vector<ComputationClient::DataPtr> results =
+      GetValueOrThrow(client->ExecuteReplicated(
+          *computations[0], client->TransferToDevice(absl::MakeConstSpan(args)),
+          {device}, options));

   // Copy the output from device back to host and assert correctness..
   ASSERT_EQ(results.size(), 1);
diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cpp b/torch_xla/csrc/runtime/pjrt_computation_client.cpp
index dd4950d87f5..d57dbf9be6c 100644
--- a/torch_xla/csrc/runtime/pjrt_computation_client.cpp
+++ b/torch_xla/csrc/runtime/pjrt_computation_client.cpp
@@ -387,8 +387,8 @@ PjRtComputationClient::ReplicateShardedData(
   torch_xla::runtime::ComputationClient::ExecuteReplicatedOptions
       execute_options;
   auto sharded_results =
-      ExecuteReplicated(*computations.front(), {sharded_data},
-                        GetLocalDevices(), execute_options);
+      GetValueOrThrow(ExecuteReplicated(*computations.front(), {sharded_data},
+                                        GetLocalDevices(), execute_options));
   XLA_CHECK(sharded_results.size() > 0)
       << "empty ExecuteReplicated results returned.";
   XLA_CHECK(sharded_results.size() == 1)
@@ -474,8 +474,8 @@ std::vector<ComputationClient::DataPtr> PjRtComputationClient::ReshardData(
   torch_xla::runtime::ComputationClient::ExecuteReplicatedOptions
       execute_options;

-  auto resharded_results = ExecuteReplicated(
-      *computation, handles, GetLocalDevices(), execute_options);
+  auto resharded_results = GetValueOrThrow(ExecuteReplicated(
+      *computation, handles, GetLocalDevices(), execute_options));
   return resharded_results;
 }

@@ -722,7 +722,7 @@ torch::lazy::hash_t PjRtComputationClient::HashCompilationEnv() {
   return comp_env_hash_;
 }

-std::vector<ComputationClient::DataPtr>
+absl::StatusOr<std::vector<ComputationClient::DataPtr>>
 PjRtComputationClient::ExecuteComputation(
     const ComputationClient::Computation& computation,
     absl::Span<const ComputationClient::DataPtr> arguments,
@@ -742,14 +742,14 @@ PjRtComputationClient::ExecuteComputation(
       dynamic_cast<const PjRtComputation&>(computation);

   xla::PjRtDevice* pjrt_device = StringToPjRtDevice(device);
-  XLA_CHECK(pjrt_device->IsAddressable()) << pjrt_device->DebugString();
+  ABSL_CHECK(pjrt_device->IsAddressable()) << pjrt_device->DebugString();

   std::vector<xla::PjRtBuffer*> buffers;
   buffers.reserve(arguments.size());
   for (auto& argument : arguments) {
     const PjRtData* pjrt_data = dynamic_cast<const PjRtData*>(argument.get());

-    XLA_CHECK(pjrt_device == pjrt_data->buffer->device())
+    ABSL_CHECK(pjrt_device == pjrt_data->buffer->device())
         << "The device currently being used : " << pjrt_device->DebugString()
         << " is different from the device where the buffer resides: "
         << pjrt_data->buffer->device()->DebugString();
@@ -769,8 +769,9 @@ PjRtComputationClient::ExecuteComputation(
             << " Done";

   std::optional<xla::PjRtFuture<>> returned_future;
-  std::vector<std::unique_ptr<xla::PjRtBuffer>> results =
-      GetValueOrThrow(pjrt_computation.executable->ExecuteSharded(
-          buffers, pjrt_device, execute_options, returned_future));
+  XLA_ASSIGN_OR_RETURN(
+      std::vector<std::unique_ptr<xla::PjRtBuffer>> results,
+      pjrt_computation.executable->ExecuteSharded(
+          buffers, pjrt_device, execute_options, returned_future));

   returned_future->OnReady(std::move(
@@ -795,7 +796,7 @@ PjRtComputationClient::ExecuteComputation(
   return datas;
 }

-std::vector<ComputationClient::DataPtr>
+absl::StatusOr<std::vector<ComputationClient::DataPtr>>
 PjRtComputationClient::ExecuteReplicated(
     const ComputationClient::Computation& computation,
     absl::Span<const ComputationClient::DataPtr> arguments,
@@ -829,15 +830,15 @@ PjRtComputationClient::ExecuteReplicated(
     for (int32_t i = start; i < end; ++i) {
       auto pjrt_data =
           std::dynamic_pointer_cast<PjRtShardedData>(arguments[i]);
-      XLA_CHECK_EQ(pjrt_data->shards.size(), devices.size())
+      ABSL_CHECK_EQ(pjrt_data->shards.size(), devices.size())
          << "Expected one shard per device";

      for (int32_t d = 0; d < devices.size(); d++) {
        std::shared_ptr<PjRtData> shard = pjrt_data->shards[d];

        xla::PjRtDevice* pjrt_device = StringToPjRtDevice(devices[d]);
-        XLA_CHECK_EQ(shard->buffer->device(), pjrt_device);
-        XLA_CHECK(pjrt_device->IsAddressable())
+        ABSL_CHECK_EQ(shard->buffer->device(), pjrt_device);
+        ABSL_CHECK(pjrt_device->IsAddressable())
            << pjrt_device->DebugString();

        argument_handles[d][i] = shard->buffer.get();
@@ -873,8 +874,9 @@ PjRtComputationClient::ExecuteReplicated(
     tsl::profiler::TraceMe activity(
         "PjRtComputationClient::ExecuteReplicated_execute",
         tsl::profiler::TraceMeLevel::kInfo);
-    results = GetValueOrThrow(pjrt_computation.executable->Execute(
-        std::move(argument_handles), execute_options, returned_futures));
+    XLA_ASSIGN_OR_RETURN(results, pjrt_computation.executable->Execute(
+                                      std::move(argument_handles),
+                                      execute_options, returned_futures));

     (*returned_futures)[0].OnReady(
         std::move([timed, op_tracker = std::move(op_tracker)](
@@ -897,7 +899,7 @@ PjRtComputationClient::ExecuteReplicated(
   const std::vector<xla::Shape>& output_shapes =
       result_shape.IsTuple() ? result_shape.tuple_shapes()
                              : std::vector<xla::Shape>({result_shape});
-  XLA_CHECK_EQ(output_shapes.size(), num_outputs);
+  ABSL_CHECK_EQ(output_shapes.size(), num_outputs);

   const std::vector<xla::OpSharding>& output_shardings =
       pjrt_computation.output_shardings_.has_value() && num_outputs > 0
@@ -906,7 +908,7 @@ PjRtComputationClient::ExecuteReplicated(
           // Without an explicit sharding annotation, the output is implicitly
           // replicated, and we mark explicitly replicated here.
           std::vector<xla::OpSharding>(num_outputs);
-  XLA_CHECK_EQ(output_shardings.size(), num_outputs);
+  ABSL_CHECK_EQ(output_shardings.size(), num_outputs);

   absl::BlockingCounter counter(num_outputs);
diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.h b/torch_xla/csrc/runtime/pjrt_computation_client.h
index 3a6b4478f72..9a93d2864f4 100644
--- a/torch_xla/csrc/runtime/pjrt_computation_client.h
+++ b/torch_xla/csrc/runtime/pjrt_computation_client.h
@@ -85,12 +85,12 @@ class PjRtComputationClient : public ComputationClient {

   ComputationPtr DeserializeComputation(const std::string& serialized) override;

-  std::vector<DataPtr> ExecuteComputation(
+  absl::StatusOr<std::vector<DataPtr>> ExecuteComputation(
       const Computation& computation, absl::Span<const DataPtr> arguments,
       const std::string& device,
       const ExecuteComputationOptions& options) override;

-  std::vector<DataPtr> ExecuteReplicated(
+  absl::StatusOr<std::vector<DataPtr>> ExecuteReplicated(
       const Computation& computation, absl::Span<const DataPtr> arguments,
       absl::Span<const std::string> devices,
       const ExecuteReplicatedOptions& options) override;
diff --git a/torch_xla/csrc/runtime/pjrt_computation_client_test.cpp b/torch_xla/csrc/runtime/pjrt_computation_client_test.cpp
index 0fe2b2a70fc..64496312ae4 100644
--- a/torch_xla/csrc/runtime/pjrt_computation_client_test.cpp
+++ b/torch_xla/csrc/runtime/pjrt_computation_client_test.cpp
@@ -114,9 +114,11 @@ TEST_F(PjRtComputationClientTest, Init) {
       std::make_shared<LiteralSource>(std::move(literal_y), device_)};

   // Execute the graph.
-  std::vector<ComputationClient::DataPtr> results = client_->ExecuteComputation(
-      *computations[0], client_->TransferToDevice(absl::MakeConstSpan(args)),
-      device_, options);
+  std::vector<ComputationClient::DataPtr> results =
+      GetValueOrThrow(client_->ExecuteComputation(
+          *computations[0],
+          client_->TransferToDevice(absl::MakeConstSpan(args)), device_,
+          options));

   // Copy the output from device back to host and assert correctness.
   ASSERT_EQ(results.size(), 1);
diff --git a/torch_xla/csrc/xla_backend_impl.cpp b/torch_xla/csrc/xla_backend_impl.cpp
index df52770b11e..78f8548ff17 100644
--- a/torch_xla/csrc/xla_backend_impl.cpp
+++ b/torch_xla/csrc/xla_backend_impl.cpp
@@ -163,11 +163,11 @@ class XlaBackendImpl : public torch::lazy::BackendImplInterface {
       torch::lazy::ComputationPtr computation,
       c10::ArrayRef<torch::lazy::BackendDataPtr> arguments,
       const torch::lazy::BackendDevice& device) const override {
-    std::vector<runtime::ComputationClient::DataPtr> results =
+    std::vector<runtime::ComputationClient::DataPtr> results = GetValueOrThrow(
         runtime::GetComputationClientOrDie()->ExecuteComputation(
             *std::dynamic_pointer_cast<runtime::ComputationClient::Computation>(
                 computation),
-            UnwrapXlaData(arguments), device.toString());
+            UnwrapXlaData(arguments), device.toString()));
     return WrapXlaData(results);
   }

diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp
index 0931578047e..8ea25adcf03 100644
--- a/torch_xla/csrc/xla_graph_executor.cpp
+++ b/torch_xla/csrc/xla_graph_executor.cpp
@@ -845,10 +845,11 @@ XLAGraphExecutor::ExecuteComputationWithBarrier(
       // tensor results. Both sharded and unsharded results should be
       // "Assign"ed to the corresponding data placeholders.
       std::vector<runtime::ComputationClient::DataPtr> outputs =
-          runtime::GetComputationClientOrDie()->ExecuteReplicated(
-              *async->cached_computation->computation,
-              UnwrapXlaData(async->parameters_data), devices,
-              execute_options);
+          GetValueOrThrow(
+              runtime::GetComputationClientOrDie()->ExecuteReplicated(
+                  *async->cached_computation->computation,
+                  UnwrapXlaData(async->parameters_data), devices,
+                  execute_options));
       results = WrapXlaData(outputs);
       TF_VLOG(3) << "Executing Dynamo IR sharded graph hash "
                  << torch::lazy::HashToString(hash) << " on devices "
@@ -943,8 +944,8 @@ std::vector<torch::lazy::BackendDataPtr> XLAGraphExecutor::ExecuteStablehlo(
   }

   std::vector<runtime::ComputationClient::DataPtr> result_data =
-      runtime::GetComputationClientOrDie()->ExecuteComputation(
-          *computations[0], UnwrapXlaData(arguments), device.toString());
+      GetValueOrThrow(runtime::GetComputationClientOrDie()->ExecuteComputation(
+          *computations[0], UnwrapXlaData(arguments), device.toString()));

   return WrapXlaData(result_data);
 }
@@ -1120,10 +1121,11 @@ XLAGraphExecutor::ScheduleSyncTensorsGraph(
       // tensor results. Both sharded and unsharded results should be
       // "Assign"ed to the corresponding data placeholders.
       std::vector<runtime::ComputationClient::DataPtr> outputs =
-          runtime::GetComputationClientOrDie()->ExecuteReplicated(
-              *async->cached_computation->computation,
-              UnwrapXlaData(async->parameters_data), devices,
-              execute_options);
+          GetValueOrThrow(
+              runtime::GetComputationClientOrDie()->ExecuteReplicated(
+                  *async->cached_computation->computation,
+                  UnwrapXlaData(async->parameters_data), devices,
+                  execute_options));
       results = WrapXlaData(outputs);
       TORCH_LAZY_COUNTER("ExecuteReplicated", 1);
       TF_VLOG(3) << "Executing IR graph hash "
@@ -1135,11 +1137,13 @@ XLAGraphExecutor::ScheduleSyncTensorsGraph(
                  << torch::lazy::HashToString(hash) << " on device "
                  << async->device << " ...";
       std::vector<runtime::ComputationClient::DataPtr> outputs =
-          runtime::GetComputationClientOrDie()->ExecuteComputation(
-              *async->cached_computation->computation,
-              UnwrapXlaData(async->parameters_data), async->device.toString(),
-              {/*explode_tuple=*/true,
-               /*eager_mode=*/use_eager_mode});
+          GetValueOrThrow(
+              runtime::GetComputationClientOrDie()->ExecuteComputation(
+                  *async->cached_computation->computation,
+                  UnwrapXlaData(async->parameters_data),
+                  async->device.toString(),
+                  {/*explode_tuple=*/true,
+                   /*eager_mode=*/use_eager_mode}));
       results = WrapXlaData(outputs);
       TORCH_LAZY_COUNTER("ExecuteComputation", 1);
       TF_VLOG(3) << "Executing IR graph hash "
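
Note on the calling conventions above (not part of the diff): with ExecuteComputation and ExecuteReplicated now returning absl::StatusOr<std::vector<DataPtr>>, every call site follows one of two patterns — propagate the error with XLA_ASSIGN_OR_RETURN inside code that itself returns a Status/StatusOr, or unwrap with GetValueOrThrow at boundaries such as tests and torch::lazy overrides that cannot return a Status. A minimal self-contained sketch of the two patterns follows; the names Execute, RunAndPropagate, and RunOrThrow are hypothetical stand-ins, and only absl::StatusOr itself is assumed:

#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"

using DataPtr = std::shared_ptr<int>;  // stand-in for ComputationClient::DataPtr

// Stand-in for the new StatusOr-returning execute methods.
absl::StatusOr<std::vector<DataPtr>> Execute(const std::string& device) {
  if (device.empty()) return absl::InvalidArgumentError("empty device");
  return std::vector<DataPtr>{std::make_shared<int>(42)};
}

// Pattern 1: propagate the failure, roughly what XLA_ASSIGN_OR_RETURN expands to.
absl::StatusOr<std::vector<DataPtr>> RunAndPropagate(const std::string& device) {
  absl::StatusOr<std::vector<DataPtr>> results = Execute(device);
  if (!results.ok()) return results.status();  // early-return the error
  return *std::move(results);                  // move the value out
}

// Pattern 2: unwrap at a boundary, roughly what GetValueOrThrow does for
// callers that cannot return a Status.
std::vector<DataPtr> RunOrThrow(const std::string& device) {
  absl::StatusOr<std::vector<DataPtr>> results = Execute(device);
  if (!results.ok()) {
    throw std::runtime_error(std::string(results.status().message()));
  }
  return *std::move(results);
}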