@@ -4,10 +4,11 @@ use std::path::{Path, PathBuf};
4
4
use std:: sync:: Arc ;
5
5
use std:: { fs, slice, str} ;
6
6
7
- use libc:: { c_char, c_int, c_void, size_t} ;
7
+ use libc:: { c_char, c_int, c_uint , c_void, size_t} ;
8
8
use llvm:: {
9
9
LLVMRustLLVMHasZlibCompressionForDebugSymbols , LLVMRustLLVMHasZstdCompressionForDebugSymbols ,
10
10
} ;
11
+ use rustc_ast:: expand:: autodiff_attrs:: AutoDiffItem ;
11
12
use rustc_codegen_ssa:: back:: link:: ensure_removed;
12
13
use rustc_codegen_ssa:: back:: versioned_llvm_target;
13
14
use rustc_codegen_ssa:: back:: write:: {
@@ -28,7 +29,7 @@ use rustc_session::config::{
28
29
use rustc_span:: InnerSpan ;
29
30
use rustc_span:: symbol:: sym;
30
31
use rustc_target:: spec:: { CodeModel , RelocModel , SanitizerSet , SplitDebuginfo , TlsModel } ;
31
- use tracing:: debug;
32
+ use tracing:: { debug, trace } ;
32
33
33
34
use crate :: back:: lto:: ThinBuffer ;
34
35
use crate :: back:: owned_target_machine:: OwnedTargetMachine ;
@@ -40,8 +41,14 @@ use crate::errors::{
40
41
CopyBitcode , FromLlvmDiag , FromLlvmOptimizationDiag , LlvmError , UnknownCompression ,
41
42
WithLlvmError , WriteBytecode ,
42
43
} ;
44
+ use crate :: llvm:: AttributePlace :: Function ;
43
45
use crate :: llvm:: diagnostic:: OptimizationDiagnosticKind :: * ;
44
- use crate :: llvm:: { self , DiagnosticInfo , PassManager } ;
46
+ use crate :: llvm:: {
47
+ self , AttributeKind , DiagnosticInfo , LLVMGetFirstFunction ,
48
+ LLVMGetNextFunction , LLVMGetStringAttributeAtIndex , LLVMIsEnumAttribute , LLVMIsStringAttribute ,
49
+ LLVMRemoveStringAttributeAtIndex , LLVMRustGetEnumAttributeAtIndex ,
50
+ LLVMRustRemoveEnumAttributeAtIndex , PassManager ,
51
+ } ;
45
52
use crate :: type_:: Type ;
46
53
use crate :: { LlvmCodegenBackend , ModuleLlvm , base, common, llvm_util} ;
47
54
@@ -517,9 +524,38 @@ pub(crate) unsafe fn llvm_optimize(
517
524
config : & ModuleConfig ,
518
525
opt_level : config:: OptLevel ,
519
526
opt_stage : llvm:: OptStage ,
527
+ skip_size_increasing_opts : bool ,
520
528
) -> Result < ( ) , FatalError > {
521
- let unroll_loops =
522
- opt_level != config:: OptLevel :: Size && opt_level != config:: OptLevel :: SizeMin ;
529
+ // Enzyme:
530
+ // The whole point of compiler based AD is to differentiate optimized IR instead of unoptimized
531
+ // source code. However, benchmarks show that optimizations increasing the code size
532
+ // tend to reduce AD performance. Therefore deactivate them before AD, then differentiate the code
533
+ // and finally re-optimize the module, now with all optimizations available.
534
+ // TODO: In a future update we could figure out how to only optimize individual functions getting
535
+ // differentiated.
536
+
537
+ let unroll_loops;
538
+ let vectorize_slp;
539
+ let vectorize_loop;
540
+
541
+ // When we build rustc with enzyme/autodiff support, we want to postpone size-increasing
542
+ // optimizations until after differentiation. FIXME(ZuseZ4): Before shipping on nightly,
543
+ // we should make this more granular, or at least check that the user has at least one autodiff
544
+ // call in their code, to justify altering the compilation pipeline.
545
+ if skip_size_increasing_opts && cfg ! ( llvm_enzyme) {
546
+ unroll_loops = false ;
547
+ vectorize_slp = false ;
548
+ vectorize_loop = false ;
549
+ } else {
550
+ unroll_loops =
551
+ opt_level != config:: OptLevel :: Size && opt_level != config:: OptLevel :: SizeMin ;
552
+ vectorize_slp = config. vectorize_slp ;
553
+ vectorize_loop = config. vectorize_loop ;
554
+ }
555
+ trace ! (
556
+ "Enzyme: Running with unroll_loops: {}, vectorize_slp: {}, vectorize_loop: {}" ,
557
+ unroll_loops, vectorize_slp, vectorize_loop
558
+ ) ;
523
559
let using_thin_buffers = opt_stage == llvm:: OptStage :: PreLinkThinLTO || config. bitcode_needed ( ) ;
524
560
let pgo_gen_path = get_pgo_gen_path ( config) ;
525
561
let pgo_use_path = get_pgo_use_path ( config) ;
@@ -583,8 +619,8 @@ pub(crate) unsafe fn llvm_optimize(
583
619
using_thin_buffers,
584
620
config. merge_functions ,
585
621
unroll_loops,
586
- config . vectorize_slp ,
587
- config . vectorize_loop ,
622
+ vectorize_slp,
623
+ vectorize_loop,
588
624
config. no_builtins ,
589
625
config. emit_lifetime_markers ,
590
626
sanitizer_options. as_ref ( ) ,
@@ -606,6 +642,115 @@ pub(crate) unsafe fn llvm_optimize(
606
642
result. into_result ( ) . map_err ( |( ) | llvm_err ( dcx, LlvmError :: RunLlvmPasses ) )
607
643
}
608
644
645
+ pub ( crate ) fn differentiate (
646
+ module : & ModuleCodegen < ModuleLlvm > ,
647
+ cgcx : & CodegenContext < LlvmCodegenBackend > ,
648
+ diff_items : Vec < AutoDiffItem > ,
649
+ config : & ModuleConfig ,
650
+ ) -> Result < ( ) , FatalError > {
651
+ for item in & diff_items {
652
+ trace ! ( "{}" , item) ;
653
+ }
654
+
655
+ let llmod = module. module_llvm . llmod ( ) ;
656
+ let llcx = & module. module_llvm . llcx ;
657
+ let diag_handler = cgcx. create_dcx ( ) ;
658
+
659
+ // Before dumping the module, we want all the tt to become part of the module.
660
+ for item in diff_items. iter ( ) {
661
+ let name = CString :: new ( item. source . clone ( ) ) . unwrap ( ) ;
662
+ let fn_def: Option < & llvm:: Value > =
663
+ unsafe { llvm:: LLVMGetNamedFunction ( llmod, name. as_ptr ( ) ) } ;
664
+ let fn_def = match fn_def {
665
+ Some ( x) => x,
666
+ None => {
667
+ return Err ( llvm_err ( diag_handler. handle ( ) , LlvmError :: PrepareAutoDiff {
668
+ src : item. source . clone ( ) ,
669
+ target : item. target . clone ( ) ,
670
+ error : "could not find source function" . to_owned ( ) ,
671
+ } ) ) ;
672
+ }
673
+ } ;
674
+ let target_name = CString :: new ( item. target . clone ( ) ) . unwrap ( ) ;
675
+ debug ! ( "target name: {:?}" , & target_name) ;
676
+ let fn_target: Option < & llvm:: Value > =
677
+ unsafe { llvm:: LLVMGetNamedFunction ( llmod, target_name. as_ptr ( ) ) } ;
678
+ let fn_target = match fn_target {
679
+ Some ( x) => x,
680
+ None => {
681
+ return Err ( llvm_err ( diag_handler. handle ( ) , LlvmError :: PrepareAutoDiff {
682
+ src : item. source . clone ( ) ,
683
+ target : item. target . clone ( ) ,
684
+ error : "could not find target function" . to_owned ( ) ,
685
+ } ) ) ;
686
+ }
687
+ } ;
688
+
689
+ crate :: builder:: generate_enzyme_call ( llmod, llcx, fn_def, fn_target, item. attrs . clone ( ) ) ;
690
+ }
691
+
692
+ // We needed the SanitizeHWAddress attribute to prevent LLVM from optimizing enums in a way
693
+ // which Enzyme doesn't understand.
694
+ unsafe {
695
+ let mut f = LLVMGetFirstFunction ( llmod) ;
696
+ loop {
697
+ if let Some ( lf) = f {
698
+ f = LLVMGetNextFunction ( lf) ;
699
+ let myhwattr = "enzyme_hw" ;
700
+ let attr = LLVMGetStringAttributeAtIndex (
701
+ lf,
702
+ c_uint:: MAX ,
703
+ myhwattr. as_ptr ( ) as * const c_char ,
704
+ myhwattr. as_bytes ( ) . len ( ) as c_uint ,
705
+ ) ;
706
+ if LLVMIsStringAttribute ( attr) {
707
+ LLVMRemoveStringAttributeAtIndex (
708
+ lf,
709
+ c_uint:: MAX ,
710
+ myhwattr. as_ptr ( ) as * const c_char ,
711
+ myhwattr. as_bytes ( ) . len ( ) as c_uint ,
712
+ ) ;
713
+ } else {
714
+ LLVMRustRemoveEnumAttributeAtIndex (
715
+ lf,
716
+ c_uint:: MAX ,
717
+ AttributeKind :: SanitizeHWAddress ,
718
+ ) ;
719
+ }
720
+ } else {
721
+ break ;
722
+ }
723
+ }
724
+ }
725
+
726
+ if let Some ( opt_level) = config. opt_level {
727
+ let opt_stage = match cgcx. lto {
728
+ Lto :: Fat => llvm:: OptStage :: PreLinkFatLTO ,
729
+ Lto :: Thin | Lto :: ThinLocal => llvm:: OptStage :: PreLinkThinLTO ,
730
+ _ if cgcx. opts . cg . linker_plugin_lto . enabled ( ) => llvm:: OptStage :: PreLinkThinLTO ,
731
+ _ => llvm:: OptStage :: PreLinkNoLTO ,
732
+ } ;
733
+ // This is our second opt call, so now we run all opts,
734
+ // to make sure we get the best performance.
735
+ let skip_size_increasing_opts = false ;
736
+ trace ! ( "running Module Optimization after differentiation" ) ;
737
+ unsafe {
738
+ llvm_optimize (
739
+ cgcx,
740
+ diag_handler. handle ( ) ,
741
+ module,
742
+ config,
743
+ opt_level,
744
+ opt_stage,
745
+ skip_size_increasing_opts,
746
+ ) ?
747
+ } ;
748
+ }
749
+ trace ! ( "done with differentiate()" ) ;
750
+
751
+ Ok ( ( ) )
752
+ }
753
+
609
754
// Unsafe due to LLVM calls.
610
755
pub ( crate ) unsafe fn optimize (
611
756
cgcx : & CodegenContext < LlvmCodegenBackend > ,
@@ -628,14 +773,57 @@ pub(crate) unsafe fn optimize(
628
773
unsafe { llvm:: LLVMWriteBitcodeToFile ( llmod, out. as_ptr ( ) ) } ;
629
774
}
630
775
776
+ // This code enables Enzyme to differentiate code containing Rust enums.
777
+ // By adding the SanitizeHWAddress attribute we prevent LLVM from optimizing
778
+ // away the enums and allow Enzyme to understand why a value can be of different types in
779
+ // different code sections. We remove this attribute after Enzyme is done, to not affect the
780
+ // rest of the compilation.
781
+ //#[cfg(llvm_enzyme)]
782
+ unsafe {
783
+ let mut f = LLVMGetFirstFunction ( llmod) ;
784
+ loop {
785
+ if let Some ( lf) = f {
786
+ f = LLVMGetNextFunction ( lf) ;
787
+ let myhwattr = "enzyme_hw" ;
788
+ let prevattr = LLVMRustGetEnumAttributeAtIndex (
789
+ lf,
790
+ c_uint:: MAX ,
791
+ AttributeKind :: SanitizeHWAddress ,
792
+ ) ;
793
+ if LLVMIsEnumAttribute ( prevattr) {
794
+ let attr = llvm:: CreateAttrString ( llcx, myhwattr) ;
795
+ crate :: attributes:: apply_to_llfn ( lf, Function , & [ attr] ) ;
796
+ } else {
797
+ let attr = AttributeKind :: SanitizeHWAddress . create_attr ( llcx) ;
798
+ crate :: attributes:: apply_to_llfn ( lf, Function , & [ attr] ) ;
799
+ }
800
+ } else {
801
+ break ;
802
+ }
803
+ }
804
+ }
805
+
631
806
if let Some ( opt_level) = config. opt_level {
632
807
let opt_stage = match cgcx. lto {
633
808
Lto :: Fat => llvm:: OptStage :: PreLinkFatLTO ,
634
809
Lto :: Thin | Lto :: ThinLocal => llvm:: OptStage :: PreLinkThinLTO ,
635
810
_ if cgcx. opts . cg . linker_plugin_lto . enabled ( ) => llvm:: OptStage :: PreLinkThinLTO ,
636
811
_ => llvm:: OptStage :: PreLinkNoLTO ,
637
812
} ;
638
- return unsafe { llvm_optimize ( cgcx, dcx, module, config, opt_level, opt_stage) } ;
813
+
814
+ // If we know that we will later run AD, then we disable vectorization and loop unrolling
815
+ let skip_size_increasing_opts = cfg ! ( llvm_enzyme) ;
816
+ return unsafe {
817
+ llvm_optimize (
818
+ cgcx,
819
+ dcx,
820
+ module,
821
+ config,
822
+ opt_level,
823
+ opt_stage,
824
+ skip_size_increasing_opts,
825
+ )
826
+ } ;
639
827
}
640
828
Ok ( ( ) )
641
829
}
0 commit comments