@@ -8,6 +8,8 @@ use std::{
8
8
str:: FromStr ,
9
9
} ;
10
10
11
+ use strum:: IntoEnumIterator ;
12
+
11
13
use cust_raw:: nvvm_sys;
12
14
13
15
pub use cust_raw:: nvvm_sys:: LIBDEVICE_BITCODE ;
@@ -255,6 +257,10 @@ impl FromStr for NvvmOption {
255
257
"72" => NvvmArch :: Compute72 ,
256
258
"75" => NvvmArch :: Compute75 ,
257
259
"80" => NvvmArch :: Compute80 ,
260
+ "86" => NvvmArch :: Compute86 ,
261
+ "87" => NvvmArch :: Compute87 ,
262
+ "89" => NvvmArch :: Compute89 ,
263
+ "90" => NvvmArch :: Compute90 ,
258
264
_ => return Err ( "unknown arch" ) ,
259
265
} ;
260
266
Self :: Arch ( arch)
@@ -265,7 +271,7 @@ impl FromStr for NvvmOption {
265
271
}
266
272
267
273
/// Nvvm architecture, default is `Compute52`
268
- #[ derive( Debug , Clone , Copy , PartialEq , Eq ) ]
274
+ #[ derive( Debug , Clone , Copy , PartialEq , Eq , strum :: EnumIter ) ]
269
275
pub enum NvvmArch {
270
276
Compute35 ,
271
277
Compute37 ,
@@ -279,6 +285,10 @@ pub enum NvvmArch {
279
285
Compute72 ,
280
286
Compute75 ,
281
287
Compute80 ,
288
+ Compute86 ,
289
+ Compute87 ,
290
+ Compute89 ,
291
+ Compute90 ,
282
292
}
283
293
284
294
impl Display for NvvmArch {
@@ -295,6 +305,53 @@ impl Default for NvvmArch {
295
305
}
296
306
}
297
307
308
+ impl NvvmArch {
309
+ /// Get the numeric capability value (e.g., 35 for Compute35)
310
+ pub fn capability_value ( & self ) -> u32 {
311
+ match self {
312
+ Self :: Compute35 => 35 ,
313
+ Self :: Compute37 => 37 ,
314
+ Self :: Compute50 => 50 ,
315
+ Self :: Compute52 => 52 ,
316
+ Self :: Compute53 => 53 ,
317
+ Self :: Compute60 => 60 ,
318
+ Self :: Compute61 => 61 ,
319
+ Self :: Compute62 => 62 ,
320
+ Self :: Compute70 => 70 ,
321
+ Self :: Compute72 => 72 ,
322
+ Self :: Compute75 => 75 ,
323
+ Self :: Compute80 => 80 ,
324
+ Self :: Compute86 => 86 ,
325
+ Self :: Compute87 => 87 ,
326
+ Self :: Compute89 => 89 ,
327
+ Self :: Compute90 => 90 ,
328
+ }
329
+ }
330
+
331
+ /// Get the target feature string (e.g., "compute_35" for Compute35)
332
+ pub fn target_feature ( & self ) -> String {
333
+ let cap = self . capability_value ( ) ;
334
+ format ! ( "compute_{cap}" )
335
+ }
336
+
337
+ /// Get all target features up to and including this architecture.
338
+ /// This ensures that `cfg(target_feature = "compute_50")` works on compute_60+ devices.
339
+ pub fn all_target_features ( & self ) -> Vec < String > {
340
+ let current = self . capability_value ( ) ;
341
+
342
+ NvvmArch :: iter ( )
343
+ . filter ( |arch| arch. capability_value ( ) <= current)
344
+ . map ( |arch| arch. target_feature ( ) )
345
+ . collect ( )
346
+ }
347
+
348
+ /// Create an iterator over all architectures from Compute35 up to and including this one
349
+ pub fn iter_up_to ( & self ) -> impl Iterator < Item = Self > {
350
+ let current = self . capability_value ( ) ;
351
+ NvvmArch :: iter ( ) . filter ( move |arch| arch. capability_value ( ) <= current)
352
+ }
353
+ }
354
+
298
355
/// Wrapper type around a raw `nvvm_sys::nvvmProgram` value, which it keeps in
/// its private `raw` field.
pub struct NvvmProgram {
299
356
// Raw value from the `cust_raw::nvvm_sys` bindings; not exposed publicly.
raw : nvvm_sys:: nvvmProgram ,
300
357
}
@@ -409,6 +466,194 @@ impl NvvmProgram {
409
466
mod tests {
410
467
use std:: str:: FromStr ;
411
468
469
#[test]
fn nvvm_arch_capability_value() {
    use crate::NvvmArch;

    // Every variant paired with the capability number it must report.
    let expectations = [
        (NvvmArch::Compute35, 35),
        (NvvmArch::Compute37, 37),
        (NvvmArch::Compute50, 50),
        (NvvmArch::Compute52, 52),
        (NvvmArch::Compute53, 53),
        (NvvmArch::Compute60, 60),
        (NvvmArch::Compute61, 61),
        (NvvmArch::Compute62, 62),
        (NvvmArch::Compute70, 70),
        (NvvmArch::Compute72, 72),
        (NvvmArch::Compute75, 75),
        (NvvmArch::Compute80, 80),
        (NvvmArch::Compute86, 86),
        (NvvmArch::Compute87, 87),
        (NvvmArch::Compute89, 89),
        (NvvmArch::Compute90, 90),
    ];

    for (arch, value) in expectations {
        assert_eq!(arch.capability_value(), value);
    }
}
490
+
491
#[test]
fn nvvm_arch_target_feature_format() {
    use crate::NvvmArch;

    // Spot-check the lowest, a middle, and the highest architecture.
    let samples = [
        (NvvmArch::Compute35, "compute_35"),
        (NvvmArch::Compute61, "compute_61"),
        (NvvmArch::Compute90, "compute_90"),
    ];

    for (arch, name) in samples {
        assert_eq!(arch.target_feature(), name);
    }
}
499
+
500
#[test]
fn nvvm_arch_all_target_features_includes_lower_capabilities() {
    use crate::NvvmArch;

    // Feature names for all architectures, lowest capability first.
    let ordered = [
        "compute_35",
        "compute_37",
        "compute_50",
        "compute_52",
        "compute_53",
        "compute_60",
        "compute_61",
        "compute_62",
        "compute_70",
        "compute_72",
        "compute_75",
        "compute_80",
        "compute_86",
        "compute_87",
        "compute_89",
        "compute_90",
    ];

    // Each architecture must yield exactly the leading `count` names of `ordered`.
    let cases = [
        // Compute35 only includes itself
        (NvvmArch::Compute35, 1),
        // Compute50 includes all lower capabilities
        (NvvmArch::Compute50, 3),
        // Compute61 includes all lower capabilities
        (NvvmArch::Compute61, 7),
        // Compute90 includes all capabilities
        (NvvmArch::Compute90, ordered.len()),
    ];

    for (arch, count) in cases {
        assert_eq!(arch.all_target_features(), &ordered[..count]);
    }
}
554
+
555
#[test]
fn target_feature_synthesis_supports_conditional_compilation_patterns() {
    use crate::NvvmArch;

    // When targeting Compute61, all lower capabilities should be enabled too.
    let enabled = NvvmArch::Compute61.all_target_features();
    let has = |name: &str| enabled.iter().any(|feature| feature == name);

    // compute_60 (for f64 atomics)
    assert!(has("compute_60"));

    // compute_50 (for 64-bit integer atomics)
    assert!(has("compute_50"));

    // compute_35 (baseline)
    assert!(has("compute_35"));

    // The target itself
    assert!(has("compute_61"));

    // Higher capabilities must stay disabled
    assert!(!has("compute_62"));
    assert!(!has("compute_70"));
}
578
+
579
#[test]
fn target_feature_synthesis_enables_correct_cfg_patterns() {
    use crate::NvvmArch;

    // Features produced when the build targets Compute70.
    let enabled = NvvmArch::Compute70.all_target_features();
    let is_enabled = |name: &str| enabled.iter().any(|candidate| candidate == name);

    // Everything up to and including compute_70 must be present.
    for feature in [
        "compute_35",
        "compute_37",
        "compute_50",
        "compute_52",
        "compute_53",
        "compute_60",
        "compute_61",
        "compute_62",
        "compute_70",
    ] {
        assert!(
            is_enabled(feature),
            "Compute70 should enable {feature} for cfg(target_feature = \" {feature}\" )"
        );
    }

    // Anything above compute_70 must be absent.
    for feature in ["compute_72", "compute_75", "compute_80", "compute_90"] {
        assert!(
            !is_enabled(feature),
            "Compute70 should NOT enable {feature}"
        );
    }
}
616
+
617
#[test]
fn nvvm_arch_iter_up_to_includes_only_lower_or_equal() {
    use crate::NvvmArch;

    // Architectures in the order the iterator is expected to yield them.
    let ordered = [
        NvvmArch::Compute35,
        NvvmArch::Compute37,
        NvvmArch::Compute50,
        NvvmArch::Compute52,
        NvvmArch::Compute53,
        NvvmArch::Compute60,
        NvvmArch::Compute61,
        NvvmArch::Compute62,
        NvvmArch::Compute70,
        NvvmArch::Compute72,
        NvvmArch::Compute75,
    ];

    // Each architecture must yield exactly the leading `count` entries of `ordered`.
    let cases = [
        // Compute35 only includes itself
        (NvvmArch::Compute35, 1),
        // Compute52 includes all up to 52
        (NvvmArch::Compute52, 4),
        // Compute75 includes all up to 75
        (NvvmArch::Compute75, ordered.len()),
    ];

    for (arch, count) in cases {
        let archs: Vec<_> = arch.iter_up_to().collect();
        assert_eq!(archs, &ordered[..count]);
    }
}
656
+
412
657
#[ test]
413
658
fn options_parse_correctly ( ) {
414
659
use crate :: NvvmArch :: * ;
0 commit comments