Skip to content

Commit c65ce8f

Browse files
committed
improve error reporting and function naming.
1 parent 60b408d commit c65ce8f

File tree

2 files changed

+22
-22
lines changed

2 files changed

+22
-22
lines changed

torch_xla/csrc/xla_generator.cpp

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,15 @@
1010
#include <c10/util/CallOnce.h>
1111
#include <c10/util/intrusive_ptr.h>
1212

13-
// XLA headers
1413
#include <cstring>
1514
#include <deque>
1615
#include <vector>
1716

17+
#include "absl/status/status.h"
1818
#include "torch_xla/csrc/aten_xla_bridge.h"
1919
#include "torch_xla/csrc/runtime/computation_client.h"
2020
#include "torch_xla/csrc/runtime/runtime.h"
21+
#include "torch_xla/csrc/status.h"
2122

2223
namespace at {
2324

@@ -38,18 +39,12 @@ static std::vector<at::Generator> default_gens_xla;
3839
* Populates the global variables related to XLA generators
3940
* Warning: this function must only be called once!
4041
*/
41-
static void initXLAGenVector() {
42+
static void InitXLAGenVector() {
4243
// Ensures we only call deviceCount only once.
4344
static bool num_xla_device_init_flag [[maybe_unused]] = []() {
4445
// Get local num of XLA devices
45-
auto maybe_client = torch_xla::runtime::GetComputationClient();
46-
if (!maybe_client.ok()) {
47-
// If runtime client initialization failed, default to 1 device
48-
num_xla_devices = 1;
49-
} else {
50-
auto* client = maybe_client.value();
51-
num_xla_devices = static_cast<int64_t>(client->GetNumDevices());
52-
}
46+
XLA_ASSIGN_OR_THROW(auto c_client, torch_xla::runtime::GetComputationClient());
47+
num_xla_devices = static_cast<int64_t>(c_client->GetNumDevices());
5348
xla_gens_init_flag.resize(num_xla_devices);
5449
default_gens_xla.resize(num_xla_devices);
5550
return true;
@@ -63,16 +58,16 @@ static void initXLAGenVector() {
6358
* initialized once. The purpose of these default generators is to
6459
* maintain a global running state of the pseudo random number generation,
6560
* when a user does not explicitly mention any generator.
66-
* getDefaultXLAGenerator gets the default generator for a particular
61+
* GetDefaultXLAGenerator gets the default generator for a particular
6762
* XLA device.
6863
*/
69-
const at::Generator& getDefaultXLAGenerator(c10::DeviceIndex device_index) {
70-
initXLAGenVector();
64+
absl::StatusOr<const at::Generator&> GetDefaultXLAGenerator(c10::DeviceIndex device_index) {
65+
InitXLAGenVector();
7166
c10::DeviceIndex idx = device_index;
7267
if (idx == -1) {
7368
idx = 0; // Default to device 0 for XLA
74-
} else {
75-
TORCH_CHECK(idx >= 0 && idx < num_xla_devices);
69+
} else if (idx < -1 || idx >= num_xla_devices) {
70+
return absl::InvalidArgumentError("Invalid device index for XLA generator. Provided index: " + std::to_string(idx));
7671
}
7772
c10::call_once(xla_gens_init_flag[idx], [&] {
7873
default_gens_xla[idx] = at::make_generator<XLAGeneratorImpl>(idx);
@@ -84,15 +79,16 @@ const at::Generator& getDefaultXLAGenerator(c10::DeviceIndex device_index) {
8479
/**
8580
* Utility to create a XLAGeneratorImpl. Returns a shared_ptr
8681
*/
87-
at::Generator createXLAGenerator(c10::DeviceIndex device_index) {
88-
initXLAGenVector();
82+
absl::StatusOr<at::Generator> CreateXLAGenerator(c10::DeviceIndex device_index) {
83+
InitXLAGenVector();
8984
c10::DeviceIndex idx = device_index;
9085
if (idx == -1) {
9186
idx = torch_xla::bridge::GetCurrentAtenDevice()
9287
.index(); // Use current XLA device
9388
}
94-
TORCH_CHECK(idx >= 0 && idx < num_xla_devices,
95-
"The device_index is invalid.");
89+
else if (idx < -1 || idx >= num_xla_devices) {
90+
return absl::InvalidArgumentError("Invalid device index for XLA generator. Provided index: " + std::to_string(idx));
91+
}
9692
auto gen = at::make_generator<XLAGeneratorImpl>(idx);
9793
auto xla_gen = at::check_generator<XLAGeneratorImpl>(gen);
9894
xla_gen->set_current_seed(c10::default_rng_seed_val);

torch_xla/csrc/xla_generator.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,14 @@
44
#include <ATen/core/Tensor.h>
55
#include <c10/core/Device.h>
66
#include <c10/core/DeviceType.h>
7+
#include <c10/core/GeneratorImpl.h>
8+
#include <c10/core/TensorImpl.h>
79
#include <c10/util/intrusive_ptr.h>
8-
910
#include <cstdint>
1011

12+
#include "absl/status/status.h"
13+
#include "absl/status/statusor.h"
14+
1115
namespace at {
1216

1317
// Holds the actual state variables for the XLA generator.
@@ -57,8 +61,8 @@ struct TORCH_API XLAGeneratorImpl : public c10::GeneratorImpl {
5761

5862
namespace detail {
5963

60-
const at::Generator& getDefaultXLAGenerator(c10::DeviceIndex device_index = -1);
61-
at::Generator createXLAGenerator(c10::DeviceIndex device_index = -1);
64+
absl::StatusOr<const at::Generator&> GetDefaultXLAGenerator(c10::DeviceIndex device_index = -1);
65+
absl::StatusOr<at::Generator> CreateXLAGenerator(c10::DeviceIndex device_index = -1);
6266

6367
} // namespace detail
6468

0 commit comments

Comments
 (0)