1
- #![ allow( unused_imports) ]
2
- #![ allow( unused_variables) ]
3
- use std:: ffi:: { CStr , CString } ;
1
+ use std:: ffi:: CString ;
4
2
use std:: io:: { self , Write } ;
5
3
use std:: path:: { Path , PathBuf } ;
6
4
use std:: sync:: Arc ;
7
5
use std:: { fs, slice, str} ;
8
6
9
7
use libc:: { c_char, c_int, c_uint, c_void, size_t} ;
10
8
use llvm:: {
11
- IntPredicate , LLVMGetNextBasicBlock , LLVMRustDISetInstMetadata ,
9
+ IntPredicate ,
12
10
LLVMRustLLVMHasZlibCompressionForDebugSymbols , LLVMRustLLVMHasZstdCompressionForDebugSymbols ,
13
11
} ;
14
12
use rustc_ast:: expand:: autodiff_attrs:: { AutoDiffItem , DiffActivity , DiffMode } ;
@@ -47,27 +45,26 @@ use crate::errors::{
47
45
} ;
48
46
use crate :: llvm:: diagnostic:: OptimizationDiagnosticKind ;
49
47
use crate :: llvm:: {
50
- self , enzyme_rust_forward_diff, enzyme_rust_reverse_diff, AttributeKind , BasicBlock ,
48
+ self , enzyme_rust_forward_diff, enzyme_rust_reverse_diff, AttributeKind ,
51
49
CreateEnzymeLogic , CreateTypeAnalysis , DiagnosticInfo , EnzymeLogicRef , EnzymeTypeAnalysisRef ,
52
- FreeTypeAnalysis , LLVMAddFunction , LLVMAppendBasicBlockInContext , LLVMBuildCall2 ,
50
+ FreeTypeAnalysis , LLVMAppendBasicBlockInContext , LLVMBuildCall2 ,
53
51
LLVMBuildCondBr , LLVMBuildExtractValue , LLVMBuildICmp , LLVMBuildRet , LLVMBuildRetVoid ,
54
52
LLVMCountParams , LLVMCountStructElementTypes , LLVMCreateBuilderInContext ,
55
- LLVMCreateStringAttribute , LLVMDeleteFunction , LLVMDisposeBuilder , LLVMDumpModule ,
56
- LLVMGetBasicBlockTerminator , LLVMGetFirstBasicBlock , LLVMGetFirstFunction ,
57
- LLVMGetModuleContext , LLVMGetNextFunction , LLVMGetParams , LLVMGetReturnType ,
53
+ LLVMCreateStringAttribute , LLVMDisposeBuilder , LLVMDumpModule ,
54
+ LLVMGetFirstBasicBlock , LLVMGetFirstFunction ,
55
+ LLVMGetNextFunction , LLVMGetParams , LLVMGetReturnType ,
58
56
LLVMGetStringAttributeAtIndex , LLVMGlobalGetValueType , LLVMIsEnumAttribute ,
59
57
LLVMIsStringAttribute , LLVMMetadataAsValue , LLVMPositionBuilderAtEnd ,
60
- LLVMRemoveStringAttributeAtIndex , LLVMReplaceAllUsesWith , LLVMRustAddEnumAttributeAtIndex ,
61
- LLVMRustAddFunctionAttributes , LLVMRustDIGetInstMetadata , LLVMRustDIGetInstMetadataOfTy ,
62
- LLVMRustEraseBBFromParent , LLVMRustEraseInstBefore , LLVMRustEraseInstFromParent ,
58
+ LLVMRemoveStringAttributeAtIndex , LLVMRustAddEnumAttributeAtIndex ,
59
+ LLVMRustAddFunctionAttributes , LLVMRustDIGetInstMetadata ,
60
+ LLVMRustEraseInstBefore , LLVMRustEraseInstFromParent ,
63
61
LLVMRustGetEnumAttributeAtIndex , LLVMRustGetFunctionType , LLVMRustGetLastInstruction ,
64
- LLVMRustGetTerminator , LLVMRustHasDbgMetadata , LLVMRustHasMetadata ,
65
- LLVMRustRemoveEnumAttributeAtIndex , LLVMRustRemoveFncAttr ,
66
- LLVMRustgetFirstNonPHIOrDbgOrLifetime , LLVMSetValueName2 , LLVMVerifyFunction ,
62
+ LLVMRustGetTerminator , LLVMRustHasMetadata ,
63
+ LLVMRustRemoveEnumAttributeAtIndex ,
64
+ LLVMVerifyFunction ,
67
65
LLVMVoidTypeInContext , PassManager , Value ,
68
66
} ;
69
67
use crate :: type_:: Type ;
70
- use crate :: typetree:: to_enzyme_typetree;
71
68
use crate :: { base, common, llvm_util, DiffTypeTree , LlvmCodegenBackend , ModuleLlvm } ;
72
69
73
70
pub fn llvm_err < ' a > ( dcx : DiagCtxtHandle < ' _ > , err : LlvmError < ' a > ) -> FatalError {
@@ -669,50 +666,27 @@ fn get_params(fnc: &Value) -> Vec<&Value> {
669
666
670
667
// DESIGN:
671
668
// Today we have our placeholder function, and our Enzyme generated one.
672
- // We create a wrapper function and delete the placeholder body.
673
- // We then call the wrapper from the placeholder.
669
+ // We create a wrapper function and delete the placeholder body. You can see the
670
+ // placeholder by running `cargo expand` on an autodiff invocation. We call the wrapper
671
+ // from the placeholder. This function is a bit longer, because it matches the Rust level
672
+ // autodiff macro with LLVM level Enzyme autodiff expectations.
674
673
//
675
- // Soon, we won't delete the whole placeholder, but just the loop,
676
- // and the two inline asm sections. For now we can still call the wrapper.
677
- // In the future we call our Enzyme generated function directly and unwrap the return
678
- // struct in our original placeholder.
679
- //
680
- // define internal double @_ZN2ad3bar17ha38374e821680177E(ptr align 8 %0, ptr align 8 %1, double %2) unnamed_addr #17 !dbg !13678 {
681
- // %4 = alloca double, align 8
682
- // %5 = alloca ptr, align 8
683
- // %6 = alloca ptr, align 8
684
- // %7 = alloca { ptr, double }, align 8
685
- // store ptr %0, ptr %6, align 8
686
- // call void @llvm.dbg.declare(metadata ptr %6, metadata !13682, metadata !DIExpression()), !dbg !13685
687
- // store ptr %1, ptr %5, align 8
688
- // call void @llvm.dbg.declare(metadata ptr %5, metadata !13683, metadata !DIExpression()), !dbg !13685
689
- // store double %2, ptr %4, align 8
690
- // call void @llvm.dbg.declare(metadata ptr %4, metadata !13684, metadata !DIExpression()), !dbg !13686
691
- // call void asm sideeffect alignstack inteldialect "NOP", "~{dirflag},~{fpsr},~{flags},~{memory}"(), !dbg !13687, !srcloc !23
692
- // %8 = call double @_ZN2ad3foo17h95b548a9411653b2E(ptr align 8 %0), !dbg !13687
693
- // %9 = call double @_ZN4core4hint9black_box17h7bd67a41b0f12bdfE(double %8), !dbg !13687
694
- // store ptr %1, ptr %7, align 8, !dbg !13687
695
- // %10 = getelementptr inbounds { ptr, double }, ptr %7, i32 0, i32 1, !dbg !13687
696
- // store double %2, ptr %10, align 8, !dbg !13687
697
- // %11 = getelementptr inbounds { ptr, double }, ptr %7, i32 0, i32 0, !dbg !13687
698
- // %12 = load ptr, ptr %11, align 8, !dbg !13687, !nonnull !23, !align !1047, !noundef !23
699
- // %13 = getelementptr inbounds { ptr, double }, ptr %7, i32 0, i32 1, !dbg !13687
700
- // %14 = load double, ptr %13, align 8, !dbg !13687, !noundef !23
701
- // %15 = call { ptr, double } @_ZN4core4hint9black_box17h669f3b22afdcb487E(ptr align 8 %12, double %14), !dbg !13687
702
- // %16 = extractvalue { ptr, double } %15, 0, !dbg !13687
703
- // %17 = extractvalue { ptr, double } %15, 1, !dbg !13687
704
- // br label %18, !dbg !13687
705
- //
706
- //18: ; preds = %18, %3
707
- // br label %18, !dbg !13687
674
+ // Think of computing the derivative with respect to &[f32] by marking it as duplicated.
675
+ // The user will then pass an extra &mut [f32] and we want to add the derivative to that.
676
+ // On the LLVM/Enzyme level, however, &[f32] becomes `ptr, i64` and we mark ptr as duplicated,
677
+ // and i64 (len) as const. Enzyme will then expect `ptr, ptr, i64` as arguments. See how the
678
+ // second i64 from the mut slice isn't used? That's why we add a safety check to assert
679
+ // that the second (mut) slice is at least as long as the first (const) slice. Otherwise,
680
+ // Enzyme would write out of bounds if the first (const) slice is longer than the second.
708
681
709
682
unsafe fn create_call < ' a > (
710
683
tgt : & ' a Value ,
711
684
src : & ' a Value ,
712
- rev_mode : bool ,
713
685
llmod : & ' a llvm:: Module ,
714
686
llcx : & llvm:: Context ,
715
- size_positions : & [ usize ] ,
687
+ // FIXME: Instead of recomputing the positions as we do it below, we should
688
+ // start using this list of positions that indicate length integers.
689
+ _size_positions : & [ usize ] ,
716
690
ad : & [ AutoDiff ] ,
717
691
) {
718
692
unsafe {
@@ -756,9 +730,10 @@ unsafe fn create_call<'a>(
756
730
inner_pos += 1 ;
757
731
outer_pos += 1 ;
758
732
} else {
759
- // out: (ptr, <>int1, ptr, int2)
733
+ // out: rust: (&[f32], &mut [f32])
734
+ // out: llvm: (ptr, <>int1, ptr, int2)
760
735
// inner: (ptr, <>ptr, int)
761
- // goal: (ptr, ptr, int1), skipping int2
736
+ // goal: call (ptr, ptr, int1), skipping int2
762
737
// we are here: <>
763
738
assert ! ( llvm:: LLVMRustGetTypeKind ( outer_arg_ty) == llvm:: TypeKind :: Integer ) ;
764
739
assert ! ( llvm:: LLVMRustGetTypeKind ( inner_arg_ty) == llvm:: TypeKind :: Pointer ) ;
@@ -872,17 +847,17 @@ unsafe fn create_call<'a>(
872
847
) ;
873
848
874
849
// Add dummy dbg info to our newly generated call, if we have any.
875
- let inst = LLVMRustgetFirstNonPHIOrDbgOrLifetime ( bb) . unwrap ( ) ;
876
850
let md_ty = llvm:: LLVMGetMDKindIDInContext (
877
851
llcx,
878
852
"dbg" . as_ptr ( ) as * const c_char ,
879
853
"dbg" . len ( ) as c_uint ,
880
854
) ;
881
855
856
+
882
857
if LLVMRustHasMetadata ( last_inst, md_ty) {
883
858
let md = LLVMRustDIGetInstMetadata ( last_inst) ;
884
859
let md_val = LLVMMetadataAsValue ( llcx, md) ;
885
- let md2 = llvm:: LLVMSetMetadata ( struct_ret, md_ty, md_val) ;
860
+ let _md2 = llvm:: LLVMSetMetadata ( struct_ret, md_ty, md_val) ;
886
861
} else {
887
862
trace ! ( "No dbg info" ) ;
888
863
}
@@ -938,8 +913,8 @@ unsafe fn get_panic_name(llmod: &llvm::Module) -> CString {
938
913
// For now we only check if shadow arguments are large enough. In this case we look for Rust panic
939
914
// functions in the module and call it. Due to hashing we can't hardcode the panic function name.
940
915
// Note: This worked even for panic=abort tests so seems solid enough for now.
941
- // TODO : Pick a panic function which allows displaying an errormessage .
942
- // TODO : We probably want to keep a handle at higher level and pass it down instead of searching.
916
+ // FIXME : Pick a panic function which allows displaying an error message .
917
+ // FIXME : We probably want to keep a handle at higher level and pass it down instead of searching.
943
918
unsafe fn add_panic_msg_to_global < ' a > (
944
919
llmod : & ' a llvm:: Module ,
945
920
llcx : & ' a llvm:: Context ,
@@ -961,7 +936,7 @@ unsafe fn add_panic_msg_to_global<'a>(
961
936
let i8_array_type = LLVMArrayType2 ( LLVMInt8TypeInContext ( llcx) , msg_len as u64 ) ;
962
937
963
938
// Create the string constant
964
- let string_const_val =
939
+ let _string_const_val =
965
940
LLVMConstStringInContext2 ( llcx, cmsg. as_ptr ( ) as * const i8 , msg_len as usize , 0 ) ;
966
941
967
942
// Create the array initializer
@@ -1098,8 +1073,7 @@ pub(crate) unsafe fn enzyme_ad(
1098
1073
1099
1074
let f_return_type = LLVMGetReturnType ( LLVMGlobalGetValueType ( res) ) ;
1100
1075
1101
- let rev_mode = item. attrs . mode == DiffMode :: Reverse ;
1102
- create_call ( target_fnc, res, rev_mode, llmod, llcx, & size_positions, ad) ;
1076
+ create_call ( target_fnc, res, llmod, llcx, & size_positions, ad) ;
1103
1077
// TODO: implement drop for wrapper type?
1104
1078
FreeTypeAnalysis ( type_analysis) ;
1105
1079
}
@@ -1133,10 +1107,6 @@ pub(crate) unsafe fn differentiate(
1133
1107
1134
1108
// Before dumping the module, we want all the tt to become part of the module.
1135
1109
for ( i, item) in diff_items. iter ( ) . enumerate ( ) {
1136
- let llvm_data_layout = unsafe { llvm:: LLVMGetDataLayoutStr ( & * llmod) } ;
1137
- let llvm_data_layout =
1138
- std:: str:: from_utf8 ( unsafe { CStr :: from_ptr ( llvm_data_layout) } . to_bytes ( ) )
1139
- . expect ( "got a non-UTF8 data-layout from LLVM" ) ;
1140
1110
let tt: FncTree = FncTree { args : item. inputs . clone ( ) , ret : item. output . clone ( ) } ;
1141
1111
let name = CString :: new ( item. source . clone ( ) ) . unwrap ( ) ;
1142
1112
let fn_def: & llvm:: Value =
0 commit comments