1010#include < c10/util/CallOnce.h>
1111#include < c10/util/intrusive_ptr.h>
1212
13- // XLA headers
1413#include < cstring>
1514#include < deque>
1615#include < vector>
1716
17+ #include " absl/status/status.h"
1818#include " torch_xla/csrc/aten_xla_bridge.h"
1919#include " torch_xla/csrc/runtime/computation_client.h"
2020#include " torch_xla/csrc/runtime/runtime.h"
21+ #include " torch_xla/csrc/status.h"
2122
2223namespace at {
2324
@@ -38,18 +39,12 @@ static std::vector<at::Generator> default_gens_xla;
3839 * Populates the global variables related to XLA generators
3940 * Warning: this function must only be called once!
4041 */
41- static void initXLAGenVector () {
42+ static void InitXLAGenVector () {
4243 // Ensures we only call deviceCount only once.
4344 static bool num_xla_device_init_flag [[maybe_unused]] = []() {
4445 // Get local num of XLA devices
45- auto maybe_client = torch_xla::runtime::GetComputationClient ();
46- if (!maybe_client.ok ()) {
47- // If runtime client initialization failed, default to 1 device
48- num_xla_devices = 1 ;
49- } else {
50- auto * client = maybe_client.value ();
51- num_xla_devices = static_cast <int64_t >(client->GetNumDevices ());
52- }
46+ XLA_ASSIGN_OR_THROW (auto c_client, torch_xla::runtime::GetComputationClient ());
47+ num_xla_devices = static_cast <int64_t >(c_client->GetNumDevices ());
5348 xla_gens_init_flag.resize (num_xla_devices);
5449 default_gens_xla.resize (num_xla_devices);
5550 return true ;
@@ -63,16 +58,16 @@ static void initXLAGenVector() {
6358 * initialized once. The purpose of these default generators is to
6459 * maintain a global running state of the pseudo random number generation,
6560 * when a user does not explicitly mention any generator.
66- * getDefaultXLAGenerator gets the default generator for a particular
61+ * GetDefaultXLAGenerator gets the default generator for a particular
6762 * XLA device.
6863 */
69- const at::Generator& getDefaultXLAGenerator (c10::DeviceIndex device_index) {
70- initXLAGenVector ();
64+ absl::StatusOr< const at::Generator&> GetDefaultXLAGenerator (c10::DeviceIndex device_index) {
65+ InitXLAGenVector ();
7166 c10::DeviceIndex idx = device_index;
7267 if (idx == -1 ) {
7368 idx = 0 ; // Default to device 0 for XLA
74- } else {
75- TORCH_CHECK (idx >= 0 && idx < num_xla_devices );
69+ } else if (idx < - 1 || idx >= num_xla_devices) {
70+ return absl::InvalidArgumentError ( " Invalid device index for XLA generator. Provided index: " + std::to_string (idx) );
7671 }
7772 c10::call_once (xla_gens_init_flag[idx], [&] {
7873 default_gens_xla[idx] = at::make_generator<XLAGeneratorImpl>(idx);
@@ -84,15 +79,16 @@ const at::Generator& getDefaultXLAGenerator(c10::DeviceIndex device_index) {
8479/* *
8580 * Utility to create a XLAGeneratorImpl. Returns a shared_ptr
8681 */
87- at::Generator createXLAGenerator (c10::DeviceIndex device_index) {
88- initXLAGenVector ();
82+ absl::StatusOr< at::Generator> CreateXLAGenerator (c10::DeviceIndex device_index) {
83+ InitXLAGenVector ();
8984 c10::DeviceIndex idx = device_index;
9085 if (idx == -1 ) {
9186 idx = torch_xla::bridge::GetCurrentAtenDevice ()
9287 .index (); // Use current XLA device
9388 }
94- TORCH_CHECK (idx >= 0 && idx < num_xla_devices,
95- " The device_index is invalid." );
89+ else if (idx < -1 || idx >= num_xla_devices) {
90+ return absl::InvalidArgumentError (" Invalid device index for XLA generator. Provided index: " + std::to_string (idx));
91+ }
9692 auto gen = at::make_generator<XLAGeneratorImpl>(idx);
9793 auto xla_gen = at::check_generator<XLAGeneratorImpl>(gen);
9894 xla_gen->set_current_seed (c10::default_rng_seed_val);
0 commit comments