Skip to content

Commit 5a70839

Browse files
committed
Add target_feature support for compute_*
This lets us gate code to virtual architectures at compile time using `cfg()`.
1 parent 3b646e6 commit 5a70839

File tree

7 files changed

+538
-6
lines changed

7 files changed

+538
-6
lines changed

crates/cuda_builder/src/lib.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,23 @@ pub struct CudaBuilder {
9393
/// the GT 1030, GTX 1050, GTX 1080, Tesla P40, etc. We default to this because
9494
/// Maxwell (5.x) will be deprecated in CUDA 12 and we are anticipating that. Moreover,
9595
/// `6.x` contains support for things like f64 atomic add and half precision float ops.
96+
///
97+
/// ## Target Features for Conditional Compilation
98+
///
99+
/// The chosen architecture enables a target feature that can be used for
100+
/// conditional compilation with `#[cfg(target_feature = "compute_XX")]`.
101+
/// This feature means "at least this capability", matching NVIDIA's semantics.
102+
///
103+
/// For other patterns (exact ranges, maximum capabilities), use boolean `cfg` logic.
104+
/// See the compute capabilities guide for examples.
105+
///
106+
/// For example, with `.arch(NvvmArch::Compute61)`:
107+
/// ```ignore
108+
/// #[cfg(target_feature = "compute_61")]
109+
/// {
110+
/// // Code that requires compute capability 6.1+
111+
/// }
112+
/// ```
96113
pub arch: NvvmArch,
97114
/// Flush denormal values to zero when performing single-precision floating point operations.
98115
/// `false` by default.
@@ -229,6 +246,11 @@ impl CudaBuilder {
229246
/// NOTE that this does not necessarily mean that code using a certain capability
230247
/// will not work on older capabilities. It only means that code relying on
231248
/// features specific to that capability may not work on them.
249+
///
250+
/// ## Target Features for Conditional Compilation
251+
///
252+
/// The chosen architecture enables target features for conditional compilation.
253+
/// See the documentation on the `arch` field for more details.
232254
pub fn arch(mut self, arch: NvvmArch) -> Self {
233255
self.arch = arch;
234256
self

crates/nvvm/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,4 @@ readme = "../../README.md"
1010

1111
[dependencies]
1212
cust_raw = { path = "../cust_raw", default-features = false, features = ["nvvm"] }
13+
strum = { version = "0.27", features = ["derive"] }

crates/nvvm/src/lib.rs

Lines changed: 246 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ use std::{
88
str::FromStr,
99
};
1010

11+
use strum::IntoEnumIterator;
12+
1113
use cust_raw::nvvm_sys;
1214

1315
pub use cust_raw::nvvm_sys::LIBDEVICE_BITCODE;
@@ -255,6 +257,10 @@ impl FromStr for NvvmOption {
255257
"72" => NvvmArch::Compute72,
256258
"75" => NvvmArch::Compute75,
257259
"80" => NvvmArch::Compute80,
260+
"86" => NvvmArch::Compute86,
261+
"87" => NvvmArch::Compute87,
262+
"89" => NvvmArch::Compute89,
263+
"90" => NvvmArch::Compute90,
258264
_ => return Err("unknown arch"),
259265
};
260266
Self::Arch(arch)
@@ -265,7 +271,7 @@ impl FromStr for NvvmOption {
265271
}
266272

267273
/// Nvvm architecture, default is `Compute52`
268-
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
274+
#[derive(Debug, Clone, Copy, PartialEq, Eq, strum::EnumIter)]
269275
pub enum NvvmArch {
270276
Compute35,
271277
Compute37,
@@ -279,6 +285,10 @@ pub enum NvvmArch {
279285
Compute72,
280286
Compute75,
281287
Compute80,
288+
Compute86,
289+
Compute87,
290+
Compute89,
291+
Compute90,
282292
}
283293

284294
impl Display for NvvmArch {
@@ -295,6 +305,53 @@ impl Default for NvvmArch {
295305
}
296306
}
297307

308+
impl NvvmArch {
309+
/// Returns the numeric compute capability of this architecture as a
/// two-digit integer (e.g. `35` for `Compute35`, `90` for `Compute90`).
///
/// The match is exhaustive over every `NvvmArch` variant, so adding a new
/// variant forces this mapping to be updated as well.
310+
pub fn capability_value(&self) -> u32 {
311+
match self {
312+
Self::Compute35 => 35,
313+
Self::Compute37 => 37,
314+
Self::Compute50 => 50,
315+
Self::Compute52 => 52,
316+
Self::Compute53 => 53,
317+
Self::Compute60 => 60,
318+
Self::Compute61 => 61,
319+
Self::Compute62 => 62,
320+
Self::Compute70 => 70,
321+
Self::Compute72 => 72,
322+
Self::Compute75 => 75,
323+
Self::Compute80 => 80,
324+
Self::Compute86 => 86,
325+
Self::Compute87 => 87,
326+
Self::Compute89 => 89,
327+
Self::Compute90 => 90,
328+
}
329+
}
330+
331+
/// Returns the target feature name for this architecture, built as
/// `compute_` followed by [`Self::capability_value`]
/// (e.g. `"compute_35"` for `Compute35`).
332+
pub fn target_feature(&self) -> String {
333+
let cap = self.capability_value();
334+
format!("compute_{cap}")
335+
}
336+
337+
/// Returns the target feature names of every architecture whose capability
/// value is less than or equal to this one's, including this architecture
/// itself. Entries appear in the order produced by `NvvmArch::iter()`.
338+
///
/// Intended so that a check such as `cfg(target_feature = "compute_50")`
/// also holds when a higher architecture (e.g. `compute_60`) is selected.
339+
pub fn all_target_features(&self) -> Vec<String> {
340+
let current = self.capability_value();
341+
342+
NvvmArch::iter()
343+
.filter(|arch| arch.capability_value() <= current)
344+
.map(|arch| arch.target_feature())
345+
.collect()
346+
}
347+
348+
/// Returns an iterator over every architecture whose capability value is
/// less than or equal to this one's, i.e. from `Compute35` up to and
/// including `self`. Uses the same filter as [`Self::all_target_features`],
/// but yields `NvvmArch` values rather than feature strings.
349+
pub fn iter_up_to(&self) -> impl Iterator<Item = Self> {
350+
let current = self.capability_value();
351+
NvvmArch::iter().filter(move |arch| arch.capability_value() <= current)
352+
}
353+
}
354+
298355
pub struct NvvmProgram {
299356
raw: nvvm_sys::nvvmProgram,
300357
}
@@ -409,6 +466,194 @@ impl NvvmProgram {
409466
mod tests {
410467
use std::str::FromStr;
411468

469+
#[test]
470+
fn nvvm_arch_capability_value() {
471+
use crate::NvvmArch;
472+
473+
assert_eq!(NvvmArch::Compute35.capability_value(), 35);
474+
assert_eq!(NvvmArch::Compute37.capability_value(), 37);
475+
assert_eq!(NvvmArch::Compute50.capability_value(), 50);
476+
assert_eq!(NvvmArch::Compute52.capability_value(), 52);
477+
assert_eq!(NvvmArch::Compute53.capability_value(), 53);
478+
assert_eq!(NvvmArch::Compute60.capability_value(), 60);
479+
assert_eq!(NvvmArch::Compute61.capability_value(), 61);
480+
assert_eq!(NvvmArch::Compute62.capability_value(), 62);
481+
assert_eq!(NvvmArch::Compute70.capability_value(), 70);
482+
assert_eq!(NvvmArch::Compute72.capability_value(), 72);
483+
assert_eq!(NvvmArch::Compute75.capability_value(), 75);
484+
assert_eq!(NvvmArch::Compute80.capability_value(), 80);
485+
assert_eq!(NvvmArch::Compute86.capability_value(), 86);
486+
assert_eq!(NvvmArch::Compute87.capability_value(), 87);
487+
assert_eq!(NvvmArch::Compute89.capability_value(), 89);
488+
assert_eq!(NvvmArch::Compute90.capability_value(), 90);
489+
}
490+
491+
#[test]
492+
fn nvvm_arch_target_feature_format() {
493+
use crate::NvvmArch;
494+
495+
assert_eq!(NvvmArch::Compute35.target_feature(), "compute_35");
496+
assert_eq!(NvvmArch::Compute61.target_feature(), "compute_61");
497+
assert_eq!(NvvmArch::Compute90.target_feature(), "compute_90");
498+
}
499+
500+
#[test]
501+
fn nvvm_arch_all_target_features_includes_lower_capabilities() {
502+
use crate::NvvmArch;
503+
504+
// Compute35 only includes itself
505+
let compute35_features = NvvmArch::Compute35.all_target_features();
506+
assert_eq!(compute35_features, vec!["compute_35"]);
507+
508+
// Compute50 includes all lower capabilities
509+
let compute50_features = NvvmArch::Compute50.all_target_features();
510+
assert_eq!(
511+
compute50_features,
512+
vec!["compute_35", "compute_37", "compute_50"]
513+
);
514+
515+
// Compute61 includes all lower capabilities
516+
let compute61_features = NvvmArch::Compute61.all_target_features();
517+
assert_eq!(
518+
compute61_features,
519+
vec![
520+
"compute_35",
521+
"compute_37",
522+
"compute_50",
523+
"compute_52",
524+
"compute_53",
525+
"compute_60",
526+
"compute_61"
527+
]
528+
);
529+
530+
// Compute90 includes all capabilities
531+
let compute90_features = NvvmArch::Compute90.all_target_features();
532+
assert_eq!(
533+
compute90_features,
534+
vec![
535+
"compute_35",
536+
"compute_37",
537+
"compute_50",
538+
"compute_52",
539+
"compute_53",
540+
"compute_60",
541+
"compute_61",
542+
"compute_62",
543+
"compute_70",
544+
"compute_72",
545+
"compute_75",
546+
"compute_80",
547+
"compute_86",
548+
"compute_87",
549+
"compute_89",
550+
"compute_90"
551+
]
552+
);
553+
}
554+
555+
#[test]
556+
fn target_feature_synthesis_supports_conditional_compilation_patterns() {
557+
use crate::NvvmArch;
558+
559+
// When targeting Compute61, should enable all lower capabilities
560+
let features = NvvmArch::Compute61.all_target_features();
561+
562+
// Should enable compute_60 (for f64 atomics)
563+
assert!(features.contains(&"compute_60".to_string()));
564+
565+
// Should enable compute_50 (for 64-bit integer atomics)
566+
assert!(features.contains(&"compute_50".to_string()));
567+
568+
// Should enable compute_35 (baseline)
569+
assert!(features.contains(&"compute_35".to_string()));
570+
571+
// Should enable the target itself
572+
assert!(features.contains(&"compute_61".to_string()));
573+
574+
// Should NOT enable higher capabilities
575+
assert!(!features.contains(&"compute_62".to_string()));
576+
assert!(!features.contains(&"compute_70".to_string()));
577+
}
578+
579+
#[test]
580+
fn target_feature_synthesis_enables_correct_cfg_patterns() {
581+
use crate::NvvmArch;
582+
583+
// Test that targeting Compute70 enables appropriate cfg patterns
584+
let features = NvvmArch::Compute70.all_target_features();
585+
586+
// These should all be true for compute_70 target
587+
let expected_enabled = [
588+
"compute_35",
589+
"compute_37",
590+
"compute_50",
591+
"compute_52",
592+
"compute_53",
593+
"compute_60",
594+
"compute_61",
595+
"compute_62",
596+
"compute_70",
597+
];
598+
599+
for feature in expected_enabled {
600+
assert!(
601+
features.contains(&feature.to_string()),
602+
"Compute70 should enable {feature} for cfg(target_feature = \"{feature}\")"
603+
);
604+
}
605+
606+
// These should NOT be enabled for compute_70 target
607+
let expected_disabled = ["compute_72", "compute_75", "compute_80", "compute_90"];
608+
609+
for feature in expected_disabled {
610+
assert!(
611+
!features.contains(&feature.to_string()),
612+
"Compute70 should NOT enable {feature}"
613+
);
614+
}
615+
}
616+
617+
#[test]
618+
fn nvvm_arch_iter_up_to_includes_only_lower_or_equal() {
619+
use crate::NvvmArch;
620+
621+
// Compute35 only includes itself
622+
let archs: Vec<_> = NvvmArch::Compute35.iter_up_to().collect();
623+
assert_eq!(archs, vec![NvvmArch::Compute35]);
624+
625+
// Compute52 includes all up to 52
626+
let archs: Vec<_> = NvvmArch::Compute52.iter_up_to().collect();
627+
assert_eq!(
628+
archs,
629+
vec![
630+
NvvmArch::Compute35,
631+
NvvmArch::Compute37,
632+
NvvmArch::Compute50,
633+
NvvmArch::Compute52,
634+
]
635+
);
636+
637+
// Compute75 includes all up to 75
638+
let archs: Vec<_> = NvvmArch::Compute75.iter_up_to().collect();
639+
assert_eq!(
640+
archs,
641+
vec![
642+
NvvmArch::Compute35,
643+
NvvmArch::Compute37,
644+
NvvmArch::Compute50,
645+
NvvmArch::Compute52,
646+
NvvmArch::Compute53,
647+
NvvmArch::Compute60,
648+
NvvmArch::Compute61,
649+
NvvmArch::Compute62,
650+
NvvmArch::Compute70,
651+
NvvmArch::Compute72,
652+
NvvmArch::Compute75,
653+
]
654+
);
655+
}
656+
412657
#[test]
413658
fn options_parse_correctly() {
414659
use crate::NvvmArch::*;

0 commit comments

Comments
 (0)