@@ -12,29 +12,37 @@ def _log_result(rctx, binary, result):
1212 .format (binary , result .return_code , result .stdout , result .stderr ),
1313 )
1414
15- def _get_amdgpu_constraint (series , gpu_mapping ):
15+ def _fail (rctx , msg ):
16+ if rctx .getenv ("MOJO_IGNORE_UNKNOWN_GPUS" ) == "1" :
17+ # buildifier: disable=print
18+ print ("WARNING: ignoring unknown GPU, to support it, add it to the gpu_mapping in the MODULE.bazel: {}" .format (msg ))
19+ else :
20+ fail (msg )
21+
22+ def _get_amdgpu_constraint (rctx , series , gpu_mapping ):
1623 for gpu_name , constraint in gpu_mapping .items ():
1724 if gpu_name in series :
1825 if constraint :
1926 return "@mojo_gpu_toolchains//:{}_gpu" .format (constraint )
2027 else :
2128 return None
2229
23- fail ("Unrecognized amd-smi/rocm-smi output, please add it to your gpu_mapping in the MODULE.bazel file: {}" .format (series ))
30+ _fail (rctx , "Unrecognized amd-smi/rocm-smi output, please add it to your gpu_mapping in the MODULE.bazel file: {}" .format (series ))
31+ return None
2432
25- def _get_rocm_constraint (blob , gpu_mapping ):
33+ def _get_rocm_constraint (rctx , blob , gpu_mapping ):
2634 for value in blob .values ():
2735 series = value ["Card Series" ]
28- return _get_amdgpu_constraint (series , gpu_mapping )
36+ return _get_amdgpu_constraint (rctx , series , gpu_mapping )
2937 fail ("Unrecognized rocm-smi output, please report: {}" .format (blob ))
3038
31- def _get_amd_constraint (blob , gpu_mapping ):
39+ def _get_amd_constraint (rctx , blob , gpu_mapping ):
3240 for value in blob :
3341 series = value ["asic" ]["market_name" ]
34- return _get_amdgpu_constraint (series , gpu_mapping )
42+ return _get_amdgpu_constraint (rctx , series , gpu_mapping )
3543 fail ("Unrecognized amd-smi output, please report: {}" .format (blob ))
3644
37- def _get_nvidia_constraint (lines , gpu_mapping ):
45+ def _get_nvidia_constraint (rctx , lines , gpu_mapping ):
3846 line = lines [0 ]
3947 for gpu_name , constraint in gpu_mapping .items ():
4048 if gpu_name in line :
@@ -43,7 +51,8 @@ def _get_nvidia_constraint(lines, gpu_mapping):
4351 else :
4452 return None
4553
46- fail ("Unrecognized nvidia-smi output, please add it to your gpu_mapping in the MODULE.bazel file: {}" .format (lines ))
54+ _fail (rctx , "Unrecognized nvidia-smi output, please add it to your gpu_mapping in the MODULE.bazel file: {}" .format (lines ))
55+ return None
4756
4857def _impl (rctx ):
4958 constraints = []
@@ -67,59 +76,63 @@ def _impl(rctx):
6776 if len (lines ) == 0 :
6877 fail ("nvidia-smi succeeded but had no GPUs, please report this issue" )
6978
70- constraint = _get_nvidia_constraint (lines , rctx .attr .gpu_mapping )
79+ constraint = _get_nvidia_constraint (rctx , lines , rctx .attr .gpu_mapping )
7180 if constraint :
7281 constraints .extend ([
7382 "@mojo_gpu_toolchains//:nvidia_gpu" ,
7483 "@mojo_gpu_toolchains//:has_gpu" ,
7584 constraint ,
7685 ])
7786
78- if len (lines ) > 1 :
79- constraints .append ("@mojo_gpu_toolchains//:has_multi_gpu" )
80- if len (lines ) >= 4 :
81- constraints .append ("@mojo_gpu_toolchains//:has_4_gpus" )
87+ if len (lines ) > 1 :
88+ constraints .append ("@mojo_gpu_toolchains//:has_multi_gpu" )
89+ if len (lines ) >= 4 :
90+ constraints .append ("@mojo_gpu_toolchains//:has_4_gpus" )
8291
8392 # AMD
8493 if amd_smi :
8594 result = rctx .execute ([amd_smi , "static" , "--json" ])
8695 _log_result (rctx , amd_smi , result )
8796
8897 if result .return_code == 0 :
89- constraints .extend ([
90- "@mojo_gpu_toolchains//:amd_gpu" ,
91- "@mojo_gpu_toolchains//:has_gpu" ,
92- ])
93-
9498 blob = json .decode (result .stdout )
9599 if len (blob ) == 0 :
96100 fail ("amd-smi succeeded but didn't actually have any GPUs, please report this issue" )
97101
98- constraints .append (_get_amd_constraint (blob , rctx .attr .gpu_mapping ))
99- if len (blob ) > 1 :
100- constraints .append ("@mojo_gpu_toolchains//:has_multi_gpu" )
101- if len (blob ) >= 4 :
102- constraints .append ("@mojo_gpu_toolchains//:has_4_gpus" )
102+ amd_constraint = _get_amd_constraint (rctx , blob , rctx .attr .gpu_mapping )
103+ if amd_constraint :
104+ constraints .extend ([
105+ amd_constraint ,
106+ "@mojo_gpu_toolchains//:amd_gpu" ,
107+ "@mojo_gpu_toolchains//:has_gpu" ,
108+ ])
109+
110+ if len (blob ) > 1 :
111+ constraints .append ("@mojo_gpu_toolchains//:has_multi_gpu" )
112+ if len (blob ) >= 4 :
113+ constraints .append ("@mojo_gpu_toolchains//:has_4_gpus" )
103114
104115 elif rocm_smi :
105116 result = rctx .execute ([rocm_smi , "--json" , "--showproductname" ])
106117 _log_result (rctx , rocm_smi , result )
107118
108119 if result .return_code == 0 :
109- constraints .extend ([
110- "@mojo_gpu_toolchains//:amd_gpu" ,
111- "@mojo_gpu_toolchains//:has_gpu" ,
112- ])
113-
114120 blob = json .decode (result .stdout )
115121 if len (blob .keys ()) == 0 :
116122 fail ("rocm-smi succeeded but didn't actually have any GPUs, please report this issue" )
117123
118- constraints .append (_get_rocm_constraint (blob , rctx .attr .gpu_mapping ))
119- if len (blob .keys ()) > 1 :
120- constraints .append ("@mojo_gpu_toolchains//:has_multi_gpu" )
121- if len (blob .keys ()) >= 4 :
122- constraints .append ("@mojo_gpu_toolchains//:has_4_gpus" )
124+ rocm_constraint = _get_rocm_constraint (rctx , blob , rctx .attr .gpu_mapping )
125+ if rocm_constraint :
126+ constraints .extend ([
127+ rocm_constraint ,
128+ "@mojo_gpu_toolchains//:amd_gpu" ,
129+ "@mojo_gpu_toolchains//:has_gpu" ,
130+ ])
131+
132+ if len (blob .keys ()) > 1 :
133+ constraints .append ("@mojo_gpu_toolchains//:has_multi_gpu" )
134+ if len (blob .keys ()) >= 4 :
135+ constraints .append ("@mojo_gpu_toolchains//:has_4_gpus" )
123136
124137 rctx .file ("WORKSPACE.bazel" , "workspace(name = {})" .format (rctx .attr .name ))
125138 rctx .file ("BUILD.bazel" , """
@@ -138,6 +151,7 @@ mojo_host_platform = repository_rule(
138151 implementation = _impl ,
139152 configure = True ,
140153 environ = [
154+ "MOJO_IGNORE_UNKNOWN_GPUS" ,
141155 "MOJO_VERBOSE_GPU_DETECT" ,
142156 ],
143157 attrs = {
0 commit comments