Skip to content

Commit 3ace346

Browse files
authored
feat: add cuda_redist_json repo rule (#286)
`cuda_redist_json` repo rule download the `redistrib.json` (redist_json) from the `urls`, or download from nvidia's default repo if only `version` is specified. In the repo, we generate a `redist.bzl` file. The `redist.bzl` contains macros with wrapped `cuda_component` repo rules of the components specified in the `cuda_redist_json`'s `components` attribute. For example, cuda_redist_json( name = "rules_cuda_redist_json", components = [ "cccl", "cudart", "nvcc", ], version = "12.6.3", ) - Downloads `redistrib.json` from `https://developer.download.nvidia.com/compute/cuda/redist/redistrib_12.6.3.json` - Generates a `redist.bzl` with the content as follows: def rules_cuda_components(): cuda_component( name = "local_cuda_cccl_v12.6.77", component_name = "cccl", # sha256, strip_prefix and urls automatically filled by reading redist.json ) cuda_component( name = "local_cuda_cudart_v12.6.77", component_name = "cudart", # sha256, strip_prefix and urls automatically filled by reading redist.json ) cuda_component( name = "local_cuda_nvcc_v12.6.85", component_name = "nvcc", # sha256, strip_prefix and urls automatically filled by reading redist.json ) return {"cccl": "@local_cuda_cccl_v12.6.77", "cudart": "@local_cuda_cudart_v12.6.77", "nvcc": "@local_cuda_nvcc_v12.6.85"} def rules_cuda_components_and_toolchains(register_toolchains = False): components_mapping = rules_cuda_components() rules_cuda_toolchains( components_mapping= components_mapping, register_toolchains = register_toolchains, version = "12.6.3", ) User then load("@rules_cuda_redist_json//:redist.bzl", "rules_cuda_components_and_toolchains") rules_cuda_components_and_toolchains(register_toolchains = True) Only WORKSPACE based project is addressed in this PR.
1 parent 942ad49 commit 3ace346

File tree

11 files changed

+268
-1
lines changed

11 files changed

+268
-1
lines changed

cuda/extensions.bzl

+32-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Entry point for extensions used by bzlmod."""
22

33
load("//cuda/private:compat.bzl", "components_mapping_compat")
4-
load("//cuda/private:repositories.bzl", "cuda_component", "local_cuda")
4+
load("//cuda/private:repositories.bzl", "cuda_component", "cuda_redist_json", "local_cuda")
55

66
cuda_component_tag = tag_class(attrs = {
77
"name": attr.string(mandatory = True, doc = "Repo name for the deliverable cuda_component"),
@@ -33,6 +33,30 @@ cuda_component_tag = tag_class(attrs = {
3333
),
3434
})
3535

36+
cuda_redist_json_tag = tag_class(attrs = {
37+
"name": attr.string(mandatory = True, doc = "Repo name for the cuda_redist_json"),
38+
"components": attr.string_list(mandatory = True, doc = "components to be used"),
39+
"integrity": attr.string(
40+
doc = "Expected checksum in Subresource Integrity format of the file downloaded. " +
41+
"This must match the checksum of the file downloaded.",
42+
),
43+
"sha256": attr.string(
44+
doc = "The expected SHA-256 of the file downloaded. " +
45+
"This must match the SHA-256 of the file downloaded.",
46+
),
47+
"urls": attr.string_list(
48+
doc = "A list of URLs to a file that will be made available to Bazel. " +
49+
"Each entry must be a file, http or https URL. Redirections are followed. " +
50+
"Authentication is not supported. " +
51+
"URLs are tried in order until one succeeds, so you should list local mirrors first. " +
52+
"If all downloads fail, the rule will fail.",
53+
),
54+
"version": attr.string(
55+
doc = "Generate a URL by using the specified version." +
56+
"This URL will be tried after all URLs specified in the `urls` attribute.",
57+
),
58+
})
59+
3660
cuda_toolkit_tag = tag_class(attrs = {
3761
"name": attr.string(mandatory = True, doc = "Name for the toolchain repository", default = "local_cuda"),
3862
"toolkit_path": attr.string(
@@ -70,17 +94,23 @@ def _impl(module_ctx):
7094
# Toolchain configuration is only allowed in the root module, or in rules_cuda.
7195
root, rules_cuda = _find_modules(module_ctx)
7296
components = None
97+
redist_jsons = None
7398
toolkits = None
7499
if root.tags.toolkit:
75100
components = root.tags.component
101+
redist_jsons = root.tags.redist_json
76102
toolkits = root.tags.toolkit
77103
else:
78104
components = rules_cuda.tags.component
105+
redist_jsons = rules_cuda.tags.redist_json
79106
toolkits = rules_cuda.tags.toolkit
80107

81108
for component in components:
82109
cuda_component(**_module_tag_to_dict(component))
83110

111+
for redist_json in redist_jsons:
112+
cuda_redist_json(**_module_tag_to_dict(redist_json))
113+
84114
registrations = {}
85115
for toolkit in toolkits:
86116
if toolkit.name in registrations.keys():
@@ -97,6 +127,7 @@ toolchain = module_extension(
97127
implementation = _impl,
98128
tag_classes = {
99129
"component": cuda_component_tag,
130+
"redist_json": cuda_redist_json_tag,
100131
"toolkit": cuda_toolkit_tag,
101132
},
102133
)

cuda/private/repositories.bzl

+68
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,74 @@ def default_components_mapping(components):
337337
"""
338338
return {c: "@local_cuda_" + c for c in components}
339339

340+
def _cuda_redist_json_impl(repository_ctx):
341+
the_url = None # the url that successfully fetch redist json, we then use it to fetch deliverables
342+
urls = [u for u in repository_ctx.attr.urls]
343+
344+
redist_ver = repository_ctx.attr.version
345+
if redist_ver:
346+
urls.append("https://developer.download.nvidia.com/compute/cuda/redist/redistrib_{}.json".format(redist_ver))
347+
348+
if len(urls) == 0:
349+
fail("`urls` or `version` must be specified.")
350+
351+
for url in urls:
352+
ret = repository_ctx.download(
353+
output = "redist.json",
354+
integrity = repository_ctx.attr.integrity,
355+
sha256 = repository_ctx.attr.sha256,
356+
url = url,
357+
)
358+
if ret.success:
359+
the_url = url
360+
break
361+
362+
if the_url == None:
363+
fail("Failed to retrieve the redist json file.")
364+
365+
# convert redist.json to list of spec (list of dicts with cuda_components attrs)
366+
specs = []
367+
redist = json.decode(repository_ctx.read("redist.json"))
368+
if not redist_ver:
369+
redist_ver = redist["release_label"]
370+
for c in repository_ctx.attr.components:
371+
c_full = FULL_COMPONENT_NAME[c]
372+
os = None
373+
if _is_linux(repository_ctx):
374+
os = "linux"
375+
elif _is_windows(repository_ctx):
376+
os = "windows"
377+
378+
arch = "x86_64" # TODO: support cross compiling
379+
platform = "{os}-{arch}".format(os = os, arch = arch)
380+
381+
payload = redist[c_full][platform]
382+
payload_relative_path = payload["relative_path"]
383+
payload_url = the_url.rsplit("/", 1)[0] + "/" + payload_relative_path
384+
archive_name = payload_relative_path.rsplit("/", 1)[1].split("-archive.")[0] + "-archive"
385+
386+
specs.append({
387+
"component_name": c,
388+
"urls": [payload_url],
389+
"sha256": payload["sha256"],
390+
"strip_prefix": archive_name,
391+
"version": redist[c_full]["version"],
392+
})
393+
394+
template_helper.generate_redist_bzl(repository_ctx, specs, redist_ver)
395+
repository_ctx.symlink(Label("//cuda/private:templates/BUILD.redist_json"), "BUILD")
396+
397+
cuda_redist_json = repository_rule(
398+
implementation = _cuda_redist_json_impl,
399+
attrs = {
400+
"components": attr.string_list(mandatory = True),
401+
"integrity": attr.string(mandatory = False),
402+
"sha256": attr.string(mandatory = False),
403+
"urls": attr.string_list(mandatory = False),
404+
"version": attr.string(mandatory = False),
405+
},
406+
)
407+
340408
def rules_cuda_dependencies():
341409
"""Populate the dependencies for rules_cuda. This will setup other bazel rules as workspace dependencies"""
342410
maybe(

cuda/private/template_helper.bzl

+45
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,50 @@ def _generate_defs_bzl(repository_ctx, is_local_ctk):
8989
}
9090
repository_ctx.template("defs.bzl", tpl_label, substitutions = substitutions, executable = False)
9191

92+
def _generate_redist_bzl(repository_ctx, component_specs, redist_version):
93+
"""Generate `@rules_cuda_redist_json//:redist.bzl`
94+
95+
Args:
96+
repository_ctx: repository_ctx
97+
component_specs: list of dict, dict keys are component_name, urls, sha256, strip_prefix and version
98+
"""
99+
100+
rules_cuda_components_body = []
101+
mapping = {}
102+
103+
component_tpl = """cuda_component(
104+
name = "{repo_name}",
105+
component_name = "{component_name}",
106+
sha256 = {sha256},
107+
strip_prefix = {strip_prefix},
108+
urls = {urls},
109+
)"""
110+
111+
for spec in component_specs:
112+
repo_name = "local_cuda_" + spec["component_name"]
113+
version = spec.get("version", None)
114+
if version != None:
115+
repo_name = repo_name + "_v" + version
116+
117+
rules_cuda_components_body.append(
118+
component_tpl.format(
119+
repo_name = repo_name,
120+
component_name = spec["component_name"],
121+
sha256 = repr(spec["sha256"]),
122+
strip_prefix = repr(spec["strip_prefix"]),
123+
urls = repr(spec["urls"]),
124+
),
125+
)
126+
mapping[spec["component_name"]] = "@" + repo_name
127+
128+
tpl_label = Label("//cuda/private:templates/redist.bzl.tpl")
129+
substitutions = {
130+
"%{rules_cuda_components_body}": "\n\n ".join(rules_cuda_components_body),
131+
"%{components_mapping}": repr(mapping),
132+
"%{version}": redist_version,
133+
}
134+
repository_ctx.template("redist.bzl", tpl_label, substitutions = substitutions, executable = False)
135+
92136
def _generate_toolchain_build(repository_ctx, cuda):
93137
tpl_label = Label(
94138
"//cuda/private:templates/BUILD.local_toolchain_" +
@@ -127,6 +171,7 @@ def _generate_toolchain_clang_build(repository_ctx, cuda, clang_path):
127171
template_helper = struct(
128172
generate_build = _generate_build,
129173
generate_defs_bzl = _generate_defs_bzl,
174+
generate_redist_bzl = _generate_redist_bzl,
130175
generate_toolchain_build = _generate_toolchain_build,
131176
generate_toolchain_clang_build = _generate_toolchain_clang_build,
132177
)
+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
package(
2+
default_visibility = ["//visibility:public"],
3+
)
4+
5+
filegroup(
6+
name = "redist_bzl",
7+
srcs = [":redist.bzl"],
8+
)
9+
10+
filegroup(
11+
name = "redist_json",
12+
srcs = [":redist.json"],
13+
)
14+
15+
exports_files([
16+
"redist.bzl",
17+
"redist.json",
18+
])

cuda/private/templates/redist.bzl.tpl

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
load("@rules_cuda//cuda:repositories.bzl", "cuda_component", "rules_cuda_toolchains")
2+
3+
def rules_cuda_components():
4+
# See template_helper.generate_redist_bzl(...) for body generation logic
5+
%{rules_cuda_components_body}
6+
7+
return %{components_mapping}
8+
9+
def rules_cuda_components_and_toolchains(register_toolchains = False):
10+
components_mapping = rules_cuda_components()
11+
rules_cuda_toolchains(
12+
components_mapping= components_mapping,
13+
register_toolchains = register_toolchains,
14+
version = "%{version}",
15+
)

cuda/repositories.bzl

+2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
load(
22
"//cuda/private:repositories.bzl",
33
_cuda_component = "cuda_component",
4+
_cuda_redist_json = "cuda_redist_json",
45
_default_components_mapping = "default_components_mapping",
56
_local_cuda = "local_cuda",
67
_rules_cuda_dependencies = "rules_cuda_dependencies",
@@ -10,6 +11,7 @@ load("//cuda/private:toolchain.bzl", _register_detected_cuda_toolchains = "regis
1011

1112
# rules
1213
cuda_component = _cuda_component
14+
cuda_redist_json = _cuda_redist_json
1315
local_cuda = _local_cuda
1416

1517
# macros

tests/integration/test_all.sh

+11
Original file line numberDiff line numberDiff line change
@@ -71,3 +71,14 @@ pushd "$this_dir/toolchain_components"
7171
bazel build --enable_bzlmod //:use_rule
7272
bazel clean && bazel shutdown
7373
popd
74+
75+
# toolchain configured with deliverables (redistrib.json with workspace)
76+
pushd "$this_dir/toolchain_redist_json"
77+
bazel build --enable_workspace //... --@rules_cuda//cuda:enable=False
78+
bazel build --enable_workspace //... --@rules_cuda//cuda:enable=True
79+
bazel build --enable_workspace //:optinally_use_rule --@rules_cuda//cuda:enable=False
80+
bazel build --enable_workspace //:optinally_use_rule --@rules_cuda//cuda:enable=True
81+
bazel build --enable_workspace //:use_library
82+
bazel build --enable_workspace //:use_rule
83+
bazel clean && bazel shutdown
84+
popd
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
../BUILD.to_symlink
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
module(name = "bzlmod_components")
2+
3+
# FIXME: cuda_redist_json is not exposed in bzlmod now. Fallback to manually specified components for tests
4+
bazel_dep(name = "rules_cuda", version = "0.0.0")
5+
local_path_override(
6+
module_name = "rules_cuda",
7+
path = "../../..",
8+
)
9+
10+
cuda = use_extension("@rules_cuda//cuda:extensions.bzl", "toolchain")
11+
cuda.component(
12+
name = "local_cuda_cccl",
13+
component_name = "cccl",
14+
sha256 = "9c3145ef01f73e50c0f5fcf923f0899c847f487c529817daa8f8b1a3ecf20925",
15+
strip_prefix = "cuda_cccl-linux-x86_64-12.6.77-archive",
16+
urls = ["https://developer.download.nvidia.com/compute/cuda/redist/cuda_cccl/linux-x86_64/cuda_cccl-linux-x86_64-12.6.77-archive.tar.xz"],
17+
)
18+
cuda.component(
19+
name = "local_cuda_cudart",
20+
component_name = "cudart",
21+
sha256 = "f74689258a60fd9c5bdfa7679458527a55e22442691ba678dcfaeffbf4391ef9",
22+
strip_prefix = "cuda_cudart-linux-x86_64-12.6.77-archive",
23+
urls = ["https://developer.download.nvidia.com/compute/cuda/redist/cuda_cudart/linux-x86_64/cuda_cudart-linux-x86_64-12.6.77-archive.tar.xz"],
24+
)
25+
cuda.component(
26+
name = "local_cuda_nvcc",
27+
component_name = "nvcc",
28+
sha256 = "840deff234d9bef20d6856439c49881cb4f29423b214f9ecd2fa59b7ac323817",
29+
strip_prefix = "cuda_nvcc-linux-x86_64-12.6.85-archive",
30+
urls = ["https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvcc/linux-x86_64/cuda_nvcc-linux-x86_64-12.6.85-archive.tar.xz"],
31+
)
32+
cuda.toolkit(
33+
name = "local_cuda",
34+
components_mapping = {
35+
"cccl": "@local_cuda_cccl",
36+
"cudart": "@local_cuda_cudart",
37+
"nvcc": "@local_cuda_nvcc",
38+
},
39+
version = "12.6",
40+
)
41+
use_repo(
42+
cuda,
43+
"local_cuda",
44+
"local_cuda_cccl",
45+
"local_cuda_cudart",
46+
"local_cuda_nvcc",
47+
)
48+
49+
bazel_dep(name = "rules_cuda_examples", version = "0.0.0")
50+
local_path_override(
51+
module_name = "rules_cuda_examples",
52+
path = "../../../examples",
53+
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
local_repository(
2+
name = "rules_cuda",
3+
path = "../../..",
4+
)
5+
6+
# buildifier: disable=load-on-top
7+
load("@rules_cuda//cuda:repositories.bzl", "cuda_redist_json", "rules_cuda_dependencies")
8+
9+
rules_cuda_dependencies()
10+
11+
cuda_redist_json(
12+
name = "rules_cuda_redist_json",
13+
components = [
14+
"cccl",
15+
"cudart",
16+
"nvcc",
17+
],
18+
version = "12.6.3",
19+
)
20+
21+
load("@rules_cuda_redist_json//:redist.bzl", "rules_cuda_components_and_toolchains")
22+
23+
rules_cuda_components_and_toolchains(register_toolchains = True)

tests/integration/toolchain_redist_json/WORKSPACE.bzlmod

Whitespace-only changes.

0 commit comments

Comments
 (0)