Skip to content

Commit 3cac1eb

Browse files
committed
support sm_100 and llvm v19
1 parent caaef11 commit 3cac1eb

40 files changed

+4011
-2683
lines changed

.devcontainer.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,4 @@
1717
"features": {
1818
"ghcr.io/devcontainers/features/git:1": {}
1919
}
20-
}
20+
}

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@ book
22
/target
33
Cargo.lock
44
**/.vscode
5-
.devcontainer
5+
.devcontainer

container/ubuntu22-cuda12/Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
FROM nvidia/cuda:12.8.1-cudnn-devel-ubuntu22.04
1+
FROM nvidia/cuda:12.9.0-cudnn-devel-ubuntu22.04
22

33
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get -qq -y install \
44
build-essential \

container/ubuntu24-cuda12/Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
FROM nvidia/cuda:12.8.1-cudnn-devel-ubuntu24.04
1+
FROM nvidia/cuda:12.9.0-cudnn-devel-ubuntu24.04
22

33
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get -qq -y install \
44
build-essential \

crates/cuda_std/src/cfg.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,11 @@ pub enum ComputeCapability {
1616
Compute72,
1717
Compute75,
1818
Compute80,
19+
Compute86,
20+
Compute87,
21+
Compute89,
22+
Compute90,
23+
Compute100
1924
}
2025

2126
impl ComputeCapability {
@@ -42,6 +47,11 @@ impl ComputeCapability {
4247
"720" => ComputeCapability::Compute72,
4348
"750" => ComputeCapability::Compute75,
4449
"800" => ComputeCapability::Compute80,
50+
"860" => ComputeCapability::Compute86, // Ampere (RTX 30 series, A100)
51+
"870" => ComputeCapability::Compute87, // Ampere (Jetson AGX Orin)
52+
"890" => ComputeCapability::Compute89, // Ada Lovelace (RTX 40 series)
53+
"900" => ComputeCapability::Compute90, // Hopper (H100)
54+
"1000" => ComputeCapability::Compute100, // Blackwell (RTX 50 series, H200, B100)
4555
_ => panic!("CUDA_ARCH had an invalid value"),
4656
}
4757
}

crates/cust/src/module.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,10 @@ pub enum JitTarget {
5656
Compute75 = 75,
5757
Compute80 = 80,
5858
Compute86 = 86,
59+
Compute87 = 87,
60+
Compute89 = 89,
61+
Compute90 = 90,
62+
Compute100 = 100,
5963
}
6064

6165
/// How to handle cases where a loaded module's data does not contain an exact match for the

crates/nvvm/src/lib.rs

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ use std::{
44
ffi::{CStr, CString},
55
fmt::Display,
66
mem::MaybeUninit,
7-
ptr::null_mut,
87
str::FromStr,
98
};
109

@@ -255,6 +254,11 @@ impl FromStr for NvvmOption {
255254
"72" => NvvmArch::Compute72,
256255
"75" => NvvmArch::Compute75,
257256
"80" => NvvmArch::Compute80,
257+
"86" => NvvmArch::Compute86,
258+
"87" => NvvmArch::Compute87,
259+
"89" => NvvmArch::Compute89,
260+
"90" => NvvmArch::Compute90,
261+
"100" => NvvmArch::Compute100,
258262
_ => return Err("unknown arch"),
259263
};
260264
Self::Arch(arch)
@@ -279,6 +283,11 @@ pub enum NvvmArch {
279283
Compute72,
280284
Compute75,
281285
Compute80,
286+
Compute86,
287+
Compute87,
288+
Compute89,
289+
Compute90,
290+
Compute100,
282291
}
283292

284293
impl Display for NvvmArch {
@@ -403,8 +412,21 @@ impl NvvmProgram {
403412

404413
/// Verify the program without actually compiling it. In the case of invalid IR, you can find
405414
/// more detailed error info by calling [`compiler_log`](Self::compiler_log).
406-
pub fn verify(&self) -> Result<(), NvvmError> {
407-
unsafe { nvvm_sys::nvvmVerifyProgram(self.raw, 0, null_mut()).to_result() }
415+
pub fn verify(&self, options: &[NvvmOption]) -> Result<(), NvvmError> {
416+
let option_strings: Vec<_> = options.iter().map(|opt| opt.to_string()).collect();
417+
let option_cstrings: Vec<_> = option_strings.iter()
418+
.map(|s| std::ffi::CString::new(s.as_str()).unwrap())
419+
.collect();
420+
let mut option_ptrs: Vec<_> = option_cstrings.iter()
421+
.map(|cs| cs.as_ptr())
422+
.collect();
423+
unsafe {
424+
nvvm_sys::nvvmVerifyProgram(
425+
self.raw,
426+
option_ptrs.len() as i32,
427+
option_ptrs.as_mut_ptr()
428+
).to_result()
429+
}
408430
}
409431
}
410432

@@ -433,6 +455,11 @@ mod tests {
433455
"-arch=compute_72",
434456
"-arch=compute_75",
435457
"-arch=compute_80",
458+
"-arch=compute_86",
459+
"-arch=compute_87",
460+
"-arch=compute_89",
461+
"-arch=compute_90",
462+
"-arch=compute_100",
436463
"-ftz=1",
437464
"-prec-sqrt=0",
438465
"-prec-div=0",
@@ -454,6 +481,11 @@ mod tests {
454481
Arch(Compute72),
455482
Arch(Compute75),
456483
Arch(Compute80),
484+
Arch(Compute86),
485+
Arch(Compute87),
486+
Arch(Compute89),
487+
Arch(Compute90),
488+
Arch(Compute100),
457489
Ftz,
458490
FastSqrt,
459491
FastDiv,

crates/rustc_codegen_nvvm/build.rs

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,11 +143,45 @@ pub fn tracked_env_var_os<K: AsRef<OsStr> + Display>(key: K) -> Option<OsString>
143143
env::var_os(key)
144144
}
145145

146+
fn run_llvm_as() {
147+
// Check if libintrinsics.ll exists
148+
let libintrinsics_path = Path::new("libintrinsics.ll");
149+
if !libintrinsics_path.exists() {
150+
fail("libintrinsics.ll not found");
151+
}
152+
153+
println!("cargo:rerun-if-changed=libintrinsics.ll");
154+
155+
let mut cmd = Command::new("llvm-as");
156+
cmd.arg("libintrinsics.ll");
157+
158+
let output = match cmd.stderr(Stdio::inherit()).output() {
159+
Ok(status) => status,
160+
Err(e) => fail(&format!(
161+
"failed to execute llvm-as: {:?}\nerror: {}",
162+
cmd, e
163+
)),
164+
};
165+
166+
if !output.status.success() {
167+
fail(&format!(
168+
"llvm-as failed: {:?}\nstatus: {}",
169+
cmd, output.status
170+
));
171+
}
172+
}
173+
146174
fn rustc_llvm_build() {
147175
let target = env::var("TARGET").expect("TARGET was not set");
148176
let llvm_config = find_llvm_config(&target);
149177

150-
let required_components = &["ipo", "bitreader", "bitwriter", "lto", "nvptx"];
178+
let required_components = &[
179+
"ipo",
180+
"bitreader",
181+
"bitwriter",
182+
"lto",
183+
"nvptx",
184+
];
151185

152186
let components = output(Command::new(&llvm_config).arg("--components"));
153187
let mut components = components.split_whitespace().collect::<Vec<_>>();
@@ -165,6 +199,9 @@ fn rustc_llvm_build() {
165199
println!("cargo:rustc-cfg=llvm_component=\"{}\"", component);
166200
}
167201

202+
// Run llvm-as on libintrinsics.ll
203+
run_llvm_as();
204+
168205
// Link in our own LLVM shims, compiled with the same flags as LLVM
169206
let mut cmd = Command::new(&llvm_config);
170207
cmd.arg("--cxxflags");
700 Bytes
Binary file not shown.

crates/rustc_codegen_nvvm/libintrinsics.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
; if you update this make sure to update libintrinsics.bc by running llvm-as (make sure you are using llvm-7 or it won't work when
66
; loaded into libnvvm).
77
source_filename = "libintrinsics"
8-
target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64"
8+
target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-i128:128:128-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64-a:8:8"
99
target triple = "nvptx64-nvidia-cuda"
1010

1111
; thread ----
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
#ifndef INCLUDED_RUSTC_LLVM_LLVMWRAPPER_H
2+
#define INCLUDED_RUSTC_LLVM_LLVMWRAPPER_H
3+
4+
#include "SuppressLLVMWarnings.h"
5+
6+
#include "llvm/Config/llvm-config.h" // LLVM_VERSION_MAJOR, LLVM_VERSION_MINOR
7+
#include "llvm/Support/raw_ostream.h" // llvm::raw_ostream
8+
#include <cstddef> // size_t etc
9+
#include <cstdint> // uint64_t etc
10+
11+
#define LLVM_VERSION_GE(major, minor) \
12+
(LLVM_VERSION_MAJOR > (major) || \
13+
LLVM_VERSION_MAJOR == (major) && LLVM_VERSION_MINOR >= (minor))
14+
15+
#define LLVM_VERSION_LT(major, minor) (!LLVM_VERSION_GE((major), (minor)))
16+
17+
extern "C" void LLVMRustSetLastError(const char *);
18+
19+
enum class LLVMRustResult { Success, Failure };
20+
21+
typedef struct OpaqueRustString *RustStringRef;
22+
typedef struct LLVMOpaqueTwine *LLVMTwineRef;
23+
typedef struct LLVMOpaqueSMDiagnostic *LLVMSMDiagnosticRef;
24+
25+
extern "C" void LLVMRustStringWriteImpl(RustStringRef buf,
26+
const char *slice_ptr,
27+
size_t slice_len);
28+
29+
class RawRustStringOstream : public llvm::raw_ostream {
30+
RustStringRef Str;
31+
uint64_t Pos;
32+
33+
void write_impl(const char *Ptr, size_t Size) override {
34+
LLVMRustStringWriteImpl(Str, Ptr, Size);
35+
Pos += Size;
36+
}
37+
38+
uint64_t current_pos() const override { return Pos; }
39+
40+
public:
41+
explicit RawRustStringOstream(RustStringRef Str) : Str(Str), Pos(0) {}
42+
43+
~RawRustStringOstream() {
44+
// LLVM requires this.
45+
flush();
46+
}
47+
};
48+
49+
#endif // INCLUDED_RUSTC_LLVM_LLVMWRAPPER_H

0 commit comments

Comments
 (0)