Commit 79d4b42 (parent: 79ff99b)

implement XLAHooks and register it to PyTorch when loaded.

3 files changed: +158 -0 lines

torch_xla/csrc/BUILD

Lines changed: 19 additions & 0 deletions
@@ -270,6 +270,7 @@ ptxla_cc_library(
         ":status",
         ":tensor",
         ":version",
+        ":xla_hooks",
         "//torch_xla/csrc/runtime",
         "//torch_xla/csrc/runtime:pjrt_computation_client",
         "//torch_xla/csrc/runtime:metrics",
@@ -374,3 +375,21 @@ cc_library(
         "@com_google_absl//absl/status:statusor",
     ],
 )
+
+ptxla_cc_library(
+    name = "xla_hooks",
+    srcs = [
+        "xla_hooks.cpp",
+    ],
+    hdrs = [
+        "xla_hooks.h",
+    ],
+    deps = [
+        "//torch_xla/csrc:device",
+        "//torch_xla/csrc:tensor",
+        "//torch_xla/csrc/runtime:computation_client",
+        "//torch_xla/csrc/runtime",
+        "//torch_xla/csrc/runtime:xla_util",
+    ],
+)

torch_xla/csrc/xla_hooks.cpp

Lines changed: 99 additions & 0 deletions (new file)
#include "xla_hooks.h"

#include <sstream>
#include <iostream>

// PyTorch integration headers
#include <ATen/core/Generator.h>
#include <ATen/detail/XLAHooksInterface.h>
#include <c10/util/Exception.h>
#include <c10/core/Device.h>
#include <c10/core/DeviceType.h>
#include <c10/util/intrusive_ptr.h>
#include <c10/util/Logging.h>

// XLA headers
#include "xla_generator.h"
#include "xla_backend_impl.h"
#include "torch_xla/csrc/aten_xla_bridge.h"
#include "torch_xla/csrc/runtime/computation_client.h"
#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/runtime/runtime.h"

namespace torch_xla::detail {

void XLAHooks::init() const {
  C10_LOG_API_USAGE_ONCE("aten.init.xla");

  // Initialize XLA backend - this registers XLA functions and sets up
  // the backend infrastructure.
  torch_xla::InitXlaBackend();
}

bool XLAHooks::hasXLA() const {
  return isAvailable();
}

bool XLAHooks::isAvailable() const {
  try {
    return deviceCount() > 0;
  } catch (...) {
    // If device enumeration fails, XLA is not available.
    return false;
  }
}

std::string XLAHooks::showConfig() const {
  std::ostringstream oss;
  oss << "XLA Backend Configuration:\n";
  oss << " - XLA devices available: " << deviceCount() << "\n";
  return oss.str();
}

c10::DeviceIndex XLAHooks::deviceCount() const {
  auto maybe_client = torch_xla::runtime::GetComputationClient();
  if (!maybe_client.ok()) {
    // If runtime client initialization failed, return 0 devices.
    return 0;
  }

  auto* client = maybe_client.value();
  return static_cast<c10::DeviceIndex>(client->GetNumDevices());
}

c10::DeviceIndex XLAHooks::getCurrentDevice() const {
  return bridge::GetCurrentAtenDevice().index();
}

bool XLAHooks::hasPrimaryContext(c10::DeviceIndex device_index) const {
  TORCH_CHECK(false, "hasPrimaryContext is not implemented.");
}

bool XLAHooks::isPinnedPtr(const void* data) const {
  TORCH_CHECK(false, "isPinnedPtr is not implemented.");
}

c10::Allocator* XLAHooks::getPinnedMemoryAllocator() const {
  TORCH_CHECK(false, "getPinnedMemoryAllocator is not implemented.");
}

c10::Device XLAHooks::getDeviceFromPtr(void* data) const {
  TORCH_CHECK(false, "getDeviceFromPtr is not implemented.");
}

const at::Generator& XLAHooks::getDefaultGenerator(
    c10::DeviceIndex device_index) const {
  return at::detail::getDefaultXLAGenerator(device_index);
}

at::Generator XLAHooks::getNewGenerator(c10::DeviceIndex device_index) const {
  // Create and return a new XLA generator using the make_generator
  // template function.
  return at::make_generator<at::XLAGeneratorImpl>(device_index);
}

}  // namespace torch_xla::detail

// Register XLA hooks with PyTorch on module load.
namespace at {
REGISTER_XLA_HOOKS(torch_xla::detail::XLAHooks)
}  // namespace at
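
For context, here is a minimal sketch of how PyTorch core would consume the hooks registered above. The accessor at::detail::getXLAHooks() is assumed from the pattern of PyTorch's other per-backend hooks (e.g. at::detail::getCUDAHooks()); verify it against the ATen headers of the PyTorch version you build with.

#include <ATen/detail/XLAHooksInterface.h>

#include <iostream>

void QueryXlaThroughHooks() {
  // Returns the registered torch_xla::detail::XLAHooks instance once this
  // module has been loaded; otherwise the default stub interface.
  const at::XLAHooksInterface& hooks = at::detail::getXLAHooks();

  if (hooks.isBuilt() && hooks.isAvailable()) {
    hooks.init();  // Runs XLAHooks::init(), i.e. torch_xla::InitXlaBackend().
    std::cout << hooks.showConfig() << "current device index: "
              << static_cast<int>(hooks.getCurrentDevice()) << "\n";
  }
}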

torch_xla/csrc/xla_hooks.h

Lines changed: 40 additions & 0 deletions (new file)
#pragma once

#include <string>

// PyTorch integration headers
#include <ATen/detail/XLAHooksInterface.h>
#include <c10/core/DeviceType.h>
#include <c10/core/Device.h>
#include <c10/util/Exception.h>
#include <ATen/core/Generator.h>

namespace torch_xla::detail {

// XLA hooks implementation following PyTorch patterns.
struct XLAHooks : public at::XLAHooksInterface {
  XLAHooks(const at::XLAHooksArgs& args) {}

  // Core accelerator interface methods
  void init() const override;
  bool hasXLA() const override;
  bool isAvailable() const override;
  bool isBuilt() const override { return true; }
  std::string showConfig() const override;

  // Device management
  c10::DeviceIndex deviceCount() const override;
  c10::DeviceIndex getCurrentDevice() const override;
  bool hasPrimaryContext(c10::DeviceIndex device_index) const override;

  // Memory management
  bool isPinnedPtr(const void* data) const override;
  c10::Allocator* getPinnedMemoryAllocator() const override;
  c10::Device getDeviceFromPtr(void* data) const override;

  // Generator methods
  const at::Generator& getDefaultGenerator(
      c10::DeviceIndex device_index = -1) const override;
  at::Generator getNewGenerator(
      c10::DeviceIndex device_index = -1) const override;
};

}  // namespace torch_xla::detail
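
REGISTER_XLA_HOOKS relies on static registration: loading libtorch_xla constructs a file-scope registrar object that installs an XLAHooks factory in a global registry, which PyTorch then queries on first use. Below is a self-contained sketch of that pattern, assuming the macro wraps c10's class registry the way the other backend hooks do; the names HooksInterface, Registrar, and hooks_factory are illustrative stand-ins, not PyTorch's actual API.

#include <functional>
#include <iostream>
#include <memory>

// Stand-in for at::XLAHooksInterface: the default "backend absent" base.
struct HooksInterface {
  virtual ~HooksInterface() = default;
  virtual bool isBuilt() const { return false; }
};

// Global factory slot; empty until a backend library registers itself.
std::function<std::unique_ptr<HooksInterface>()>& hooks_factory() {
  static std::function<std::unique_ptr<HooksInterface>()> factory;
  return factory;
}

// Constructing a Registrar<T> installs a factory for T. A file-scope
// instance runs when the shared library is loaded (e.g. via dlopen).
template <typename T>
struct Registrar {
  Registrar() {
    hooks_factory() = [] { return std::make_unique<T>(); };
  }
};

struct MyHooks : HooksInterface {
  bool isBuilt() const override { return true; }
};

// Analogue of REGISTER_XLA_HOOKS(torch_xla::detail::XLAHooks).
static Registrar<MyHooks> register_my_hooks;

int main() {
  // Analogue of PyTorch creating the hooks object on first use.
  std::unique_ptr<HooksInterface> hooks = hooks_factory()();
  std::cout << std::boolalpha << hooks->isBuilt() << "\n";  // prints true
  return 0;
}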
