11#include < gtest/gtest.h>
22#include < torch/torch.h>
33
4+ #include < cstdlib>
5+
46#include " test/cpp/torch_xla_test.h"
57#include " torch_xla/csrc/xla_generator.h"
68
@@ -18,6 +20,20 @@ class XLAGeneratorTest : public ::torch_xla::cpp_test::TorchXlaTest {
1820 at::Generator gen_;
1921};
2022
23+ // Ensure PJRT is configured to a CPU backend for tests that touch the PJRT
24+ // runtime.
25+ static void EnsurePjrtCpuBackend () {
26+ const char * pjrt = std::getenv (" PJRT_DEVICE" );
27+ if (pjrt == nullptr || pjrt[0 ] == ' \0 ' ) {
28+ // Use CPU backend with a single device by default.
29+ setenv (" PJRT_DEVICE" , " CPU" , 1 );
30+ }
31+ const char * cpu_devices = std::getenv (" CPU_NUM_DEVICES" );
32+ if (cpu_devices == nullptr || cpu_devices[0 ] == ' \0 ' ) {
33+ setenv (" CPU_NUM_DEVICES" , " 1" , 0 );
34+ }
35+ }
36+
2137TEST_F (XLAGeneratorTest, Constructor) {
2238 // Check that the generator was created for the correct device
2339 ASSERT_EQ (gen_.device ().type (), at::DeviceType::XLA);
@@ -102,5 +118,106 @@ TEST_F(XLAGeneratorTest, Clone) {
102118 ASSERT_NE (cloned_gen.current_seed (), gen_.current_seed ());
103119}
104120
121+ TEST_F (XLAGeneratorTest, GetDefaultXLAGenerator) {
122+ EnsurePjrtCpuBackend ();
123+ // Test getting default generator for device 0
124+ auto result = at::detail::GetDefaultXLAGenerator (0 );
125+ ASSERT_TRUE (result.ok ()) << " Failed to get default generator: "
126+ << result.status ();
127+
128+ const at::Generator& default_gen = result.value ();
129+ ASSERT_EQ (default_gen.device ().type (), at::DeviceType::XLA);
130+ ASSERT_EQ (default_gen.device ().index (), 0 );
131+
132+ // Test getting default generator with -1 (should default to device 0)
133+ auto result_default = at::detail::GetDefaultXLAGenerator (-1 );
134+ ASSERT_TRUE (result_default.ok ())
135+ << " Failed to get default generator with -1: " << result_default.status ();
136+
137+ const at::Generator& default_gen_neg1 = result_default.value ();
138+ ASSERT_EQ (default_gen_neg1.device ().type (), at::DeviceType::XLA);
139+ ASSERT_EQ (default_gen_neg1.device ().index (), 0 );
140+
141+ // Test that subsequent calls return the same generator instance
142+ auto result2 = at::detail::GetDefaultXLAGenerator (0 );
143+ ASSERT_TRUE (result2.ok ());
144+ const at::Generator& default_gen2 = result2.value ();
145+ ASSERT_EQ (std::addressof (default_gen), std::addressof (default_gen2));
146+ }
147+
148+ TEST_F (XLAGeneratorTest, GetDefaultXLAGeneratorInvalidDevice) {
149+ EnsurePjrtCpuBackend ();
150+ // Test with invalid device indices
151+ auto result_neg2 = at::detail::GetDefaultXLAGenerator (-2 );
152+ ASSERT_FALSE (result_neg2.ok ());
153+ ASSERT_TRUE (absl::IsInvalidArgument (result_neg2.status ()));
154+
155+ // Test with very large device index (assuming there aren't 1000 XLA devices)
156+ auto result_large = at::detail::GetDefaultXLAGenerator (1000 );
157+ ASSERT_FALSE (result_large.ok ());
158+ ASSERT_TRUE (absl::IsInvalidArgument (result_large.status ()));
159+ }
160+
161+ TEST_F (XLAGeneratorTest, CreateXLAGenerator) {
162+ EnsurePjrtCpuBackend ();
163+ // Test creating generator for device 0
164+ auto result = at::detail::CreateXLAGenerator (0 );
165+ ASSERT_TRUE (result.ok ()) << " Failed to create generator: " << result.status ();
166+
167+ at::Generator created_gen = result.value ();
168+ ASSERT_EQ (created_gen.device ().type (), at::DeviceType::XLA);
169+ ASSERT_EQ (created_gen.device ().index (), 0 );
170+
171+ // Test that the generator is initialized with default seed
172+ ASSERT_EQ (created_gen.current_seed (), c10::default_rng_seed_val);
173+
174+ // Test creating generator with -1 (should use current device)
175+ auto result_default = at::detail::CreateXLAGenerator (-1 );
176+ ASSERT_TRUE (result_default.ok ())
177+ << " Failed to create generator with -1: " << result_default.status ();
178+
179+ at::Generator created_gen_neg1 = result_default.value ();
180+ ASSERT_EQ (created_gen_neg1.device ().type (), at::DeviceType::XLA);
181+ // Device index should be >= 0 (actual device depends on current XLA device)
182+ ASSERT_GE (created_gen_neg1.device ().index (), 0 );
183+ }
184+
185+ TEST_F (XLAGeneratorTest, CreateXLAGeneratorUniqueness) {
186+ EnsurePjrtCpuBackend ();
187+ // Test that each call creates a new generator instance
188+ auto result1 = at::detail::CreateXLAGenerator (0 );
189+ auto result2 = at::detail::CreateXLAGenerator (0 );
190+
191+ ASSERT_TRUE (result1.ok ());
192+ ASSERT_TRUE (result2.ok ());
193+
194+ at::Generator gen1 = result1.value ();
195+ at::Generator gen2 = result2.value ();
196+
197+ // Should be different instances
198+ ASSERT_NE (std::addressof (gen1), std::addressof (gen2));
199+
200+ // But should have same device and initial seed
201+ ASSERT_EQ (gen1.device (), gen2.device ());
202+ ASSERT_EQ (gen1.current_seed (), gen2.current_seed ());
203+
204+ // Modifying one should not affect the other
205+ gen1.set_current_seed (12345 );
206+ ASSERT_NE (gen1.current_seed (), gen2.current_seed ());
207+ }
208+
209+ TEST_F (XLAGeneratorTest, CreateXLAGeneratorInvalidDevice) {
210+ EnsurePjrtCpuBackend ();
211+ // Test with invalid device indices
212+ auto result_neg2 = at::detail::CreateXLAGenerator (-2 );
213+ ASSERT_FALSE (result_neg2.ok ());
214+ ASSERT_TRUE (absl::IsInvalidArgument (result_neg2.status ()));
215+
216+ // Test with very large device index (assuming there aren't 1000 XLA devices)
217+ auto result_large = at::detail::CreateXLAGenerator (1000 );
218+ ASSERT_FALSE (result_large.ok ());
219+ ASSERT_TRUE (absl::IsInvalidArgument (result_large.status ()));
220+ }
221+
105222} // namespace cpp_test
106- } // namespace torch_xla
223+ } // namespace torch_xla
0 commit comments