Skip to content

Commit e89e659

Browse files
committed
Add unit tests for GetDefaultXLAGenerator and CreateXLAGenerator
1 parent e745336 commit e89e659

File tree

1 file changed

+118
-1
lines changed

1 file changed

+118
-1
lines changed

test/cpp/test_xla_generator.cpp

Lines changed: 118 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#include <gtest/gtest.h>
22
#include <torch/torch.h>
33

4+
#include <cstdlib>
5+
46
#include "test/cpp/torch_xla_test.h"
57
#include "torch_xla/csrc/xla_generator.h"
68

@@ -18,6 +20,20 @@ class XLAGeneratorTest : public ::torch_xla::cpp_test::TorchXlaTest {
1820
at::Generator gen_;
1921
};
2022

23+
// Ensure PJRT is configured to a CPU backend for tests that touch the PJRT
24+
// runtime.
25+
static void EnsurePjrtCpuBackend() {
26+
const char* pjrt = std::getenv("PJRT_DEVICE");
27+
if (pjrt == nullptr || pjrt[0] == '\0') {
28+
// Use CPU backend with a single device by default.
29+
setenv("PJRT_DEVICE", "CPU", 1);
30+
}
31+
const char* cpu_devices = std::getenv("CPU_NUM_DEVICES");
32+
if (cpu_devices == nullptr || cpu_devices[0] == '\0') {
33+
setenv("CPU_NUM_DEVICES", "1", 0);
34+
}
35+
}
36+
2137
TEST_F(XLAGeneratorTest, Constructor) {
2238
// Check that the generator was created for the correct device
2339
ASSERT_EQ(gen_.device().type(), at::DeviceType::XLA);
@@ -102,5 +118,106 @@ TEST_F(XLAGeneratorTest, Clone) {
102118
ASSERT_NE(cloned_gen.current_seed(), gen_.current_seed());
103119
}
104120

121+
TEST_F(XLAGeneratorTest, GetDefaultXLAGenerator) {
122+
EnsurePjrtCpuBackend();
123+
// Test getting default generator for device 0
124+
auto result = at::detail::GetDefaultXLAGenerator(0);
125+
ASSERT_TRUE(result.ok()) << "Failed to get default generator: "
126+
<< result.status();
127+
128+
const at::Generator& default_gen = result.value();
129+
ASSERT_EQ(default_gen.device().type(), at::DeviceType::XLA);
130+
ASSERT_EQ(default_gen.device().index(), 0);
131+
132+
// Test getting default generator with -1 (should default to device 0)
133+
auto result_default = at::detail::GetDefaultXLAGenerator(-1);
134+
ASSERT_TRUE(result_default.ok())
135+
<< "Failed to get default generator with -1: " << result_default.status();
136+
137+
const at::Generator& default_gen_neg1 = result_default.value();
138+
ASSERT_EQ(default_gen_neg1.device().type(), at::DeviceType::XLA);
139+
ASSERT_EQ(default_gen_neg1.device().index(), 0);
140+
141+
// Test that subsequent calls return the same generator instance
142+
auto result2 = at::detail::GetDefaultXLAGenerator(0);
143+
ASSERT_TRUE(result2.ok());
144+
const at::Generator& default_gen2 = result2.value();
145+
ASSERT_EQ(std::addressof(default_gen), std::addressof(default_gen2));
146+
}
147+
148+
TEST_F(XLAGeneratorTest, GetDefaultXLAGeneratorInvalidDevice) {
149+
EnsurePjrtCpuBackend();
150+
// Test with invalid device indices
151+
auto result_neg2 = at::detail::GetDefaultXLAGenerator(-2);
152+
ASSERT_FALSE(result_neg2.ok());
153+
ASSERT_TRUE(absl::IsInvalidArgument(result_neg2.status()));
154+
155+
// Test with very large device index (assuming there aren't 1000 XLA devices)
156+
auto result_large = at::detail::GetDefaultXLAGenerator(1000);
157+
ASSERT_FALSE(result_large.ok());
158+
ASSERT_TRUE(absl::IsInvalidArgument(result_large.status()));
159+
}
160+
161+
TEST_F(XLAGeneratorTest, CreateXLAGenerator) {
162+
EnsurePjrtCpuBackend();
163+
// Test creating generator for device 0
164+
auto result = at::detail::CreateXLAGenerator(0);
165+
ASSERT_TRUE(result.ok()) << "Failed to create generator: " << result.status();
166+
167+
at::Generator created_gen = result.value();
168+
ASSERT_EQ(created_gen.device().type(), at::DeviceType::XLA);
169+
ASSERT_EQ(created_gen.device().index(), 0);
170+
171+
// Test that the generator is initialized with default seed
172+
ASSERT_EQ(created_gen.current_seed(), c10::default_rng_seed_val);
173+
174+
// Test creating generator with -1 (should use current device)
175+
auto result_default = at::detail::CreateXLAGenerator(-1);
176+
ASSERT_TRUE(result_default.ok())
177+
<< "Failed to create generator with -1: " << result_default.status();
178+
179+
at::Generator created_gen_neg1 = result_default.value();
180+
ASSERT_EQ(created_gen_neg1.device().type(), at::DeviceType::XLA);
181+
// Device index should be >= 0 (actual device depends on current XLA device)
182+
ASSERT_GE(created_gen_neg1.device().index(), 0);
183+
}
184+
185+
TEST_F(XLAGeneratorTest, CreateXLAGeneratorUniqueness) {
186+
EnsurePjrtCpuBackend();
187+
// Test that each call creates a new generator instance
188+
auto result1 = at::detail::CreateXLAGenerator(0);
189+
auto result2 = at::detail::CreateXLAGenerator(0);
190+
191+
ASSERT_TRUE(result1.ok());
192+
ASSERT_TRUE(result2.ok());
193+
194+
at::Generator gen1 = result1.value();
195+
at::Generator gen2 = result2.value();
196+
197+
// Should be different instances
198+
ASSERT_NE(std::addressof(gen1), std::addressof(gen2));
199+
200+
// But should have same device and initial seed
201+
ASSERT_EQ(gen1.device(), gen2.device());
202+
ASSERT_EQ(gen1.current_seed(), gen2.current_seed());
203+
204+
// Modifying one should not affect the other
205+
gen1.set_current_seed(12345);
206+
ASSERT_NE(gen1.current_seed(), gen2.current_seed());
207+
}
208+
209+
TEST_F(XLAGeneratorTest, CreateXLAGeneratorInvalidDevice) {
210+
EnsurePjrtCpuBackend();
211+
// Test with invalid device indices
212+
auto result_neg2 = at::detail::CreateXLAGenerator(-2);
213+
ASSERT_FALSE(result_neg2.ok());
214+
ASSERT_TRUE(absl::IsInvalidArgument(result_neg2.status()));
215+
216+
// Test with very large device index (assuming there aren't 1000 XLA devices)
217+
auto result_large = at::detail::CreateXLAGenerator(1000);
218+
ASSERT_FALSE(result_large.ok());
219+
ASSERT_TRUE(absl::IsInvalidArgument(result_large.status()));
220+
}
221+
105222
} // namespace cpp_test
106-
} // namespace torch_xla
223+
} // namespace torch_xla

0 commit comments

Comments
 (0)