Skip to content

Commit 79ff99b

Browse files
committed
Add helper functions getDefaultXLAGenerator and createXLAGenerator to XLA random number generator
1 parent fa49099 commit 79ff99b

File tree

2 files changed

+99
-1
lines changed

2 files changed

+99
-1
lines changed

torch_xla/csrc/xla_generator.cpp

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,98 @@
66
#include <c10/core/Device.h>
77
#include <c10/core/DeviceType.h>
88
#include <c10/core/TensorImpl.h>
9+
#include <c10/core/GeneratorImpl.h>
910
#include <c10/util/intrusive_ptr.h>
11+
#include <c10/util/CallOnce.h>
12+
13+
// XLA headers
14+
#include "torch_xla/csrc/runtime/computation_client.h"
15+
#include "torch_xla/csrc/aten_xla_bridge.h"
1016

1117
#include <cstring>
18+
#include <deque>
19+
#include <vector>
20+
21+
namespace at {
22+
23+
namespace detail {
24+
25+
namespace {
26+
27+
// Total number of XLA devices in the system.
28+
static int64_t num_xla_devices;
29+
30+
// Ensures default_gens_xla is initialized once.
31+
static std::deque<c10::once_flag> xla_gens_init_flag;
32+
33+
// Default, global XLA generators, one per XLA device.
34+
static std::vector<at::Generator> default_gens_xla;
35+
36+
/*
37+
* Populates the global variables related to XLA generators
38+
* Warning: this function must only be called once!
39+
*/
40+
static void initXLAGenVector() {
41+
// Ensures we only call deviceCount only once.
42+
static bool num_xla_device_init_flag [[maybe_unused]] = []() {
43+
// Get local num of XLA devices
44+
auto maybe_client = torch_xla::runtime::GetComputationClient();
45+
if (!maybe_client.ok()) {
46+
// If runtime client initialization failed, default to 1 device
47+
num_xla_devices = 1;
48+
} else {
49+
auto* client = maybe_client.value();
50+
num_xla_devices = static_cast<int64_t>(client->GetNumDevices());
51+
}
52+
xla_gens_init_flag.resize(num_xla_devices);
53+
default_gens_xla.resize(num_xla_devices);
54+
return true;
55+
}();
56+
}
57+
58+
} // anonymous namespace
59+
60+
/**
61+
* PyTorch maintains a collection of default generators that get
62+
* initialized once. The purpose of these default generators is to
63+
* maintain a global running state of the pseudo random number generation,
64+
* when a user does not explicitly mention any generator.
65+
* getDefaultXLAGenerator gets the default generator for a particular
66+
* XLA device.
67+
*/
68+
const at::Generator& getDefaultXLAGenerator(c10::DeviceIndex device_index) {
69+
initXLAGenVector();
70+
c10::DeviceIndex idx = device_index;
71+
if (idx == -1) {
72+
idx = 0; // Default to device 0 for XLA
73+
} else {
74+
TORCH_CHECK(idx >= 0 && idx < num_xla_devices);
75+
}
76+
c10::call_once(xla_gens_init_flag[idx], [&] {
77+
default_gens_xla[idx] = at::make_generator<XLAGeneratorImpl>(idx);
78+
default_gens_xla[idx].seed();
79+
});
80+
return default_gens_xla[idx];
81+
}
82+
83+
/**
84+
* Utility to create a XLAGeneratorImpl. Returns a shared_ptr
85+
*/
86+
at::Generator createXLAGenerator(c10::DeviceIndex device_index) {
87+
initXLAGenVector();
88+
c10::DeviceIndex idx = device_index;
89+
if (idx == -1) {
90+
idx = torch_xla::bridge::GetCurrentAtenDevice().index(); // Use current XLA device
91+
}
92+
TORCH_CHECK(idx >= 0 && idx < num_xla_devices, "The device_index is invalid.");
93+
auto gen = at::make_generator<XLAGeneratorImpl>(idx);
94+
auto xla_gen = at::check_generator<XLAGeneratorImpl>(gen);
95+
xla_gen->set_current_seed(c10::default_rng_seed_val);
96+
return gen;
97+
}
98+
99+
} // namespace detail
100+
} // namespace at
12101

13102
namespace at {
14103

torch_xla/csrc/xla_generator.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
#include <ATen/core/Generator.h>
44
#include <ATen/core/Tensor.h>
5+
#include <c10/core/Device.h>
6+
#include <c10/core/DeviceType.h>
57
#include <c10/util/intrusive_ptr.h>
68

79
#include <cstdint>
@@ -53,4 +55,11 @@ struct TORCH_API XLAGeneratorImpl : public c10::GeneratorImpl {
5355
c10::intrusive_ptr<XLAGeneratorState> state_;
5456
};
5557

56-
} // namespace at
58+
namespace detail {
59+
60+
const at::Generator& getDefaultXLAGenerator(c10::DeviceIndex device_index = -1);
61+
at::Generator createXLAGenerator(c10::DeviceIndex device_index = -1);
62+
63+
} // namespace detail
64+
65+
} // namespace at

0 commit comments

Comments
 (0)