Skip to content

Commit d3dd3f9

Browse files
authored
Add GPU detection logic (#12)
There are a few pieces to GPU detection for supporting Mojo. The final goal is to have a valid value we can pass as `--target-accelerator=nvidia:80` to `mojo build`. In order to do this based on the target platform we have to detect the current GPU with nvidia-smi or rocm-smi, parse the output, and set up various config_settings to determine which toolchain should be used. If you use `--platforms=@mojo_host_platform` you get the rest of this logic for free. The currently known supported GPUs are seeded in the `mojo.gpu_toolchains` module extension, but new ones can be added in individual projects as well.
1 parent 433e86c commit d3dd3f9

File tree

9 files changed

+385
-14
lines changed

9 files changed

+385
-14
lines changed

.bazelrc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,8 @@
11
common --incompatible_strict_action_env
22
common --test_output=errors
3+
4+
# https://github.com/bazelbuild/bazel/issues/25145
5+
info --platforms=
6+
7+
common --platforms=@mojo_host_platform
8+
common --host_platform=@mojo_host_platform

BUILD.bazel

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,8 @@ toolchain_type(
1010
name = "toolchain_type",
1111
visibility = ["//visibility:public"],
1212
)
13+
14+
# Toolchain type under which mojo_gpu_toolchain instances are registered;
# resolved per-target to pick the GPU-specific Mojo toolchain.
toolchain_type(
    name = "gpu_toolchain_type",
    visibility = ["//visibility:public"],
)

MODULE.bazel

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,10 @@ bazel_dep(name = "rules_python", version = "1.0.0")
1111

1212
mojo = use_extension("//mojo:extensions.bzl", "mojo")
1313
mojo.toolchain()
14-
use_repo(mojo, "mojo_toolchains")
14+
mojo.gpu_toolchains()
15+
use_repo(mojo, "mojo_gpu_toolchains", "mojo_host_platform", "mojo_toolchains")
1516

16-
register_toolchains("@mojo_toolchains//...")
17+
register_toolchains("@mojo_toolchains//...", "@mojo_gpu_toolchains//...")
1718

1819
_DEFAULT_PYTHON_VERSION = "3.12"
1920

mojo/extensions.bzl

Lines changed: 77 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
"""MODULE.bazel extensions for Mojo toolchains."""
22

3+
load("//mojo:mojo_host_platform.bzl", "mojo_host_platform")
4+
load("//mojo/private:mojo_gpu_toolchains_repository.bzl", "mojo_gpu_toolchains_repository")
5+
36
_PLATFORMS = ["linux_aarch64", "linux_x86_64", "macos_arm64"]
47
_DEFAULT_VERSION = "25.4.0.dev2025050902"
58
_KNOWN_SHAS = {
@@ -98,20 +101,36 @@ def _mojo_impl(mctx):
98101
if not module.is_root:
99102
continue
100103

104+
toolchains = module.tags.toolchain
101105
if len(module.tags.toolchain) > 1:
102106
fail("mojo.toolchain() can only be called once per module.")
107+
if toolchains:
108+
has_toolchains = True
109+
tags = toolchains[0]
110+
111+
for platform in _PLATFORMS:
112+
name = "mojo_toolchain_{}".format(platform)
113+
_mojo_toolchain_repository(
114+
name = name,
115+
version = tags.version,
116+
platform = platform,
117+
url_override = tags.url_override,
118+
use_prebuilt_packages = tags.use_prebuilt_packages,
119+
)
103120

104-
has_toolchains = True
105-
tags = module.tags.toolchain[0]
106-
107-
for platform in _PLATFORMS:
108-
name = "mojo_toolchain_{}".format(platform)
109-
_mojo_toolchain_repository(
110-
name = name,
111-
version = tags.version,
112-
platform = platform,
113-
url_override = tags.url_override,
114-
use_prebuilt_packages = tags.use_prebuilt_packages,
121+
gpu_toolchains = module.tags.gpu_toolchains
122+
if len(gpu_toolchains) > 1:
123+
fail("mojo.gpu_toolchain() can only be called once per module.")
124+
if gpu_toolchains:
125+
gpu_toolchain = gpu_toolchains[0]
126+
mojo_gpu_toolchains_repository(
127+
name = "mojo_gpu_toolchains",
128+
supported_gpus = gpu_toolchain.supported_gpus,
129+
)
130+
131+
mojo_host_platform(
132+
name = "mojo_host_platform",
133+
gpu_mapping = gpu_toolchain.gpu_mapping,
115134
)
116135

117136
_mojo_toolchain_hub(
@@ -140,10 +159,57 @@ _toolchain_tag = tag_class(
140159
},
141160
)
142161

162+
# Tag class for mojo.gpu_toolchains(): seeds the known GPUs and the strings
# used to recognize them in nvidia-smi / amd-smi / rocm-smi output.
_gpu_toolchains_tag = tag_class(
    doc = "Tags for configuring Mojo GPU toolchains.",
    attrs = {
        "supported_gpus": attr.string_dict(
            default = {
                "780M": "amdgpu:gfx1103",
                "a10": "nvidia:86",
                "a100": "nvidia:80",
                "a3000": "nvidia:86",
                "b100": "nvidia:100a",
                "b200": "nvidia:100a",
                "h100": "nvidia:90a",
                "h200": "nvidia:90a",
                "l4": "nvidia:89",
                "mi300x": "amdgpu:gfx942",
                "mi325": "amdgpu:gfx942",
                "rtx5090": "nvidia:120a",
            },
            doc = "The GPUs supported by this toolchain, mapping to Mojo's target accelerators.",
        ),
        "gpu_mapping": attr.string_dict(
            default = {
                " A10G": "a10",
                "A100-": "a100",
                " H100 ": "h100",
                " H200 ": "h200",
                # BUGFIX: these previously mapped to "L4" (uppercase), which does not
                # match the "l4" key defined in supported_gpus above; every other
                # non-empty value here matches a supported_gpus key exactly.
                " L4 ": "l4",
                " Ada ": "l4",
                " A3000 ": "a3000",
                "B100": "b100",
                "B200": "b200",
                " RTX 5090": "rtx5090",
                # Empty values mark GPUs that are recognized but intentionally
                # not supported, so detection ignores them instead of failing.
                "Laptop GPU": "",
                "RTX 4070 Ti": "",
                "RTX 4080 SUPER": "",
                "NVIDIA GeForce RTX 3090": "",
                "MI300X": "mi300x",
                "MI325": "mi325",
                "Navi": "radeon",
                "AMD Radeon Graphics": "radeon",
            },
            doc = "The output from nvidia-smi or rocm-smi to the corresponding GPU name in SUPPORTED_GPUS.",
        ),
    },
)
207+
143208
# Entry point consumed by MODULE.bazel via
# `mojo = use_extension("//mojo:extensions.bzl", "mojo")`.
mojo = module_extension(
    doc = "Mojo toolchain extension.",
    implementation = _mojo_impl,
    tag_classes = {
        # mojo.toolchain(): download/configure the Mojo compiler toolchains.
        "toolchain": _toolchain_tag,
        # mojo.gpu_toolchains(): configure GPU detection and supported GPUs.
        "gpu_toolchains": _gpu_toolchains_tag,
    },
)

mojo/mojo_host_platform.bzl

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
"""Setup a host platform that takes into account current GPU hardware"""
2+
3+
def _verbose_log(rctx, msg):
    """Print msg, but only when the MOJO_VERBOSE_GPU_DETECT env var is set."""
    enabled = rctx.getenv("MOJO_VERBOSE_GPU_DETECT")
    if not enabled:
        return

    # buildifier: disable=print
    print(msg)
7+
8+
def _log_result(rctx, binary, result):
    """Send the exit status and output of a GPU query command to the verbose log."""
    template = "\n------ {}:\nexit status: {}\nstdout: {}\nstderr: {}\n------ end gpu-query info"
    message = template.format(binary, result.return_code, result.stdout, result.stderr)
    _verbose_log(rctx, message)
14+
15+
def _get_amdgpu_constraint(series, gpu_mapping):
    """Map an AMD card-series string to a GPU constraint label.

    Returns None when the GPU is recognized but mapped to an empty value
    (i.e. intentionally unsupported); fails on a completely unknown GPU.
    """
    for marker, gpu in gpu_mapping.items():
        if marker not in series:
            continue

        # An empty mapping value marks a recognized-but-unsupported GPU.
        if not gpu:
            return None
        return "@mojo_gpu_toolchains//:{}_gpu".format(gpu)

    fail("Unrecognized amd-smi/rocm-smi output, please add it to your gpu_mapping in the MODULE.bazel file: {}".format(series))
24+
25+
def _get_rocm_constraint(blob, gpu_mapping):
    """Map rocm-smi JSON output (an object keyed by card) to a GPU constraint.

    Only the first reported card is inspected.
    """
    for card in blob.values():
        return _get_amdgpu_constraint(card["Card Series"], gpu_mapping)
    fail("Unrecognized rocm-smi output, please report: {}".format(blob))
30+
31+
def _get_amd_constraint(blob, gpu_mapping):
    """Map amd-smi JSON output (an array of GPU entries) to a GPU constraint.

    Only the first reported GPU is inspected.
    """
    for entry in blob:
        return _get_amdgpu_constraint(entry["asic"]["market_name"], gpu_mapping)
    fail("Unrecognized amd-smi output, please report: {}".format(blob))
36+
37+
def _get_nvidia_constraint(lines, gpu_mapping):
    """Map nvidia-smi's first GPU-name line to a GPU constraint label.

    Returns None when the GPU is recognized but mapped to an empty value
    (i.e. intentionally unsupported); fails on a completely unknown GPU.
    """
    first_line = lines[0]
    for marker, gpu in gpu_mapping.items():
        if marker not in first_line:
            continue

        # An empty mapping value marks a recognized-but-unsupported GPU.
        return "@mojo_gpu_toolchains//:{}_gpu".format(gpu) if gpu else None

    fail("Unrecognized nvidia-smi output, please add it to your gpu_mapping in the MODULE.bazel file: {}".format(lines))
47+
48+
def _impl(rctx):
    """Generate a host platform() that carries constraints for any detected GPU.

    GPU detection only runs on x86_64 Linux; on any other host the generated
    platform is just the host platform with no GPU constraint_values.
    """
    constraints = []

    if rctx.os.name == "linux" and rctx.os.arch == "amd64":
        # A system may have both rocm-smi and nvidia-smi installed, check both.
        nvidia_smi = rctx.which("nvidia-smi")

        # amd-smi supersedes rocm-smi
        amd_smi = rctx.which("amd-smi")
        rocm_smi = rctx.which("rocm-smi")

        _verbose_log(rctx, "nvidia-smi path: {}, rocm-smi path: {}, amd-smi path: {}".format(nvidia_smi, rocm_smi, amd_smi))

        # NVIDIA
        if nvidia_smi:
            result = rctx.execute([nvidia_smi, "--query-gpu=gpu_name", "--format=csv,noheader"])
            _log_result(rctx, nvidia_smi, result)
            if result.return_code == 0:
                # One output line per installed GPU.
                lines = result.stdout.splitlines()
                if len(lines) == 0:
                    fail("nvidia-smi succeeded but had no GPUs, please report this issue")

                # None means the GPU is recognized but intentionally unsupported.
                constraint = _get_nvidia_constraint(lines, rctx.attr.gpu_mapping)
                if constraint:
                    constraints.extend([
                        "@mojo_gpu_toolchains//:nvidia_gpu",
                        "@mojo_gpu_toolchains//:has_gpu",
                        constraint,
                    ])

                    if len(lines) > 1:
                        constraints.append("@mojo_gpu_toolchains//:has_multi_gpu")
                    if len(lines) >= 4:
                        constraints.append("@mojo_gpu_toolchains//:has_4_gpus")

        # AMD
        if amd_smi:
            result = rctx.execute([amd_smi, "static", "--json"])
            _log_result(rctx, amd_smi, result)

            if result.return_code == 0:
                constraints.extend([
                    "@mojo_gpu_toolchains//:amd_gpu",
                    "@mojo_gpu_toolchains//:has_gpu",
                ])

                # amd-smi emits a JSON array with one entry per GPU.
                blob = json.decode(result.stdout)
                if len(blob) == 0:
                    fail("amd-smi succeeded but didn't actually have any GPUs, please report this issue")

                # BUGFIX: _get_amd_constraint returns None for a recognized-but-
                # unsupported GPU; appending it unconditionally would render an
                # invalid "None" label into constraint_values below.
                constraint = _get_amd_constraint(blob, rctx.attr.gpu_mapping)
                if constraint:
                    constraints.append(constraint)
                if len(blob) > 1:
                    constraints.append("@mojo_gpu_toolchains//:has_multi_gpu")
                if len(blob) >= 4:
                    constraints.append("@mojo_gpu_toolchains//:has_4_gpus")

        elif rocm_smi:
            result = rctx.execute([rocm_smi, "--json", "--showproductname"])
            _log_result(rctx, rocm_smi, result)

            if result.return_code == 0:
                constraints.extend([
                    "@mojo_gpu_toolchains//:amd_gpu",
                    "@mojo_gpu_toolchains//:has_gpu",
                ])

                # rocm-smi emits a JSON object with one key per card.
                blob = json.decode(result.stdout)
                if len(blob.keys()) == 0:
                    fail("rocm-smi succeeded but didn't actually have any GPUs, please report this issue")

                # BUGFIX: same None guard as the amd-smi branch above.
                constraint = _get_rocm_constraint(blob, rctx.attr.gpu_mapping)
                if constraint:
                    constraints.append(constraint)
                if len(blob.keys()) > 1:
                    constraints.append("@mojo_gpu_toolchains//:has_multi_gpu")
                if len(blob.keys()) >= 4:
                    constraints.append("@mojo_gpu_toolchains//:has_4_gpus")

    # BUGFIX: the workspace name must be a quoted string literal; the previous
    # template emitted `workspace(name = mojo_host_platform)`, which is not
    # valid WORKSPACE syntax.
    rctx.file("WORKSPACE.bazel", 'workspace(name = "{}")'.format(rctx.attr.name))
    rctx.file("BUILD.bazel", """
platform(
    name = "mojo_host_platform",
    parents = ["@platforms//host"],
    visibility = ["//visibility:public"],
    constraint_values = [{constraints}],
    exec_properties = {{
        "no-remote-exec": "1",
    }},
)
""".format(constraints = ", ".join(['"{}"'.format(x) for x in constraints])))
136+
137+
# Repository rule that synthesizes the @mojo_host_platform repository whose
# platform() reflects the GPU hardware detected on the local machine.
mojo_host_platform = repository_rule(
    implementation = _impl,
    # Marks this repository as dependent on system configuration (the rule
    # probes locally installed nvidia-smi/amd-smi/rocm-smi binaries).
    configure = True,
    environ = [
        # Set to any non-empty value to print GPU detection diagnostics.
        "MOJO_VERBOSE_GPU_DETECT",
    ],
    attrs = {
        "gpu_mapping": attr.string_dict(
            doc = "A dictionary of GPU strings from nvidia-smi or amd-smi, mapped to supported GPUs defined by mojo.gpu_toolchains()",
        ),
    },
)
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
"""Bazel toolchain representing the currently targeted GPU hardware"""
2+
3+
load("//mojo:providers.bzl", "MojoGPUToolchainInfo")
4+
5+
def _mojo_gpu_toolchain_impl(ctx):
    """Expose the rule's GPU attributes as a MojoGPUToolchainInfo provider."""
    accelerator = ctx.attr.target_accelerator

    # The brand is the prefix of the accelerator spec, e.g. "nvidia" in "nvidia:80".
    gpu_info = MojoGPUToolchainInfo(
        brand = accelerator.split(":")[0],
        has_4_gpus = ctx.attr.has_4_gpus,
        multi_gpu = ctx.attr.multi_gpu,
        name = ctx.attr.name,
        target_accelerator = accelerator,
    )
    return [platform_common.ToolchainInfo(mojo_gpu_toolchain_info = gpu_info)]
18+
19+
# Rule describing one targeted GPU configuration for Mojo builds.
mojo_gpu_toolchain = rule(
    implementation = _mojo_gpu_toolchain_impl,
    attrs = {
        # Accelerator spec such as "nvidia:80" or "amdgpu:gfx942", suitable for
        # `mojo build --target-accelerator=...`.
        "target_accelerator": attr.string(mandatory = True),
        # Whether the targeted host has more than one GPU.
        "multi_gpu": attr.bool(mandatory = True),
        # Whether the targeted host has at least four GPUs.
        "has_4_gpus": attr.bool(mandatory = True),
    },
)

0 commit comments

Comments
 (0)