Skip to content

Commit 914d708

Browse files
committed
format
1 parent 79d4b42 commit 914d708

File tree

2 files changed

+14
-12
lines changed

2 files changed

+14
-12
lines changed

torch_xla/csrc/xla_generator.cpp

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,19 @@
55
#include <ATen/core/Tensor.h>
66
#include <c10/core/Device.h>
77
#include <c10/core/DeviceType.h>
8-
#include <c10/core/TensorImpl.h>
98
#include <c10/core/GeneratorImpl.h>
10-
#include <c10/util/intrusive_ptr.h>
9+
#include <c10/core/TensorImpl.h>
1110
#include <c10/util/CallOnce.h>
11+
#include <c10/util/intrusive_ptr.h>
1212

1313
// XLA headers
14-
#include "torch_xla/csrc/runtime/computation_client.h"
15-
#include "torch_xla/csrc/aten_xla_bridge.h"
16-
1714
#include <cstring>
1815
#include <deque>
1916
#include <vector>
2017

18+
#include "torch_xla/csrc/aten_xla_bridge.h"
19+
#include "torch_xla/csrc/runtime/computation_client.h"
20+
2121
namespace at {
2222

2323
namespace detail {
@@ -55,7 +55,7 @@ static void initXLAGenVector() {
5555
}();
5656
}
5757

58-
} // anonymous namespace
58+
} // anonymous namespace
5959

6060
/**
6161
* PyTorch maintains a collection of default generators that get
@@ -69,7 +69,7 @@ const at::Generator& getDefaultXLAGenerator(c10::DeviceIndex device_index) {
6969
initXLAGenVector();
7070
c10::DeviceIndex idx = device_index;
7171
if (idx == -1) {
72-
idx = 0; // Default to device 0 for XLA
72+
idx = 0; // Default to device 0 for XLA
7373
} else {
7474
TORCH_CHECK(idx >= 0 && idx < num_xla_devices);
7575
}
@@ -87,17 +87,19 @@ at::Generator createXLAGenerator(c10::DeviceIndex device_index) {
8787
initXLAGenVector();
8888
c10::DeviceIndex idx = device_index;
8989
if (idx == -1) {
90-
idx = torch_xla::bridge::GetCurrentAtenDevice().index(); // Use current XLA device
90+
idx = torch_xla::bridge::GetCurrentAtenDevice()
91+
.index(); // Use current XLA device
9192
}
92-
TORCH_CHECK(idx >= 0 && idx < num_xla_devices, "The device_index is invalid.");
93+
TORCH_CHECK(idx >= 0 && idx < num_xla_devices,
94+
"The device_index is invalid.");
9395
auto gen = at::make_generator<XLAGeneratorImpl>(idx);
9496
auto xla_gen = at::check_generator<XLAGeneratorImpl>(gen);
9597
xla_gen->set_current_seed(c10::default_rng_seed_val);
9698
return gen;
9799
}
98100

99-
} // namespace detail
100-
} // namespace at
101+
} // namespace detail
102+
} // namespace at
101103

102104
namespace at {
103105

torch_xla/csrc/xla_generator.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,6 @@ namespace detail {
6060
const at::Generator& getDefaultXLAGenerator(c10::DeviceIndex device_index = -1);
6161
at::Generator createXLAGenerator(c10::DeviceIndex device_index = -1);
6262

63-
} // namespace detail
63+
} // namespace detail
6464

6565
} // namespace at

0 commit comments

Comments
 (0)