@@ -48,16 +48,15 @@ fn get_params(fnc: &Value) -> Vec<&Value> {
48
48
}
49
49
}
50
50
51
- // The lowering of one `#[autodiff]` macro happens in multiple steps.
52
- // First we transalte generate a new dummy function, who's llvm-ir we now have as outer_fn.
53
- // We kept track of the original function to which the `#[autodiff]` macro was applied to, which we
54
- // now have as fn_to_diff. In our current implementation, we use the enzyme pass to carry out the
55
- // differentiation, following naming and calling conventions documented here: <https://enzyme.mit.edu/getting_started/CallingConvention/>
56
- //
57
- // Our `outer_fn` had some dummy code inserted at higher levels, so we first remove most of the
58
- // existing body. We then insert an `__enzyme_<autodiff/fwddiff>_<unique_id>` call, which the pass
59
- // will then pick up. FIXME(ZuseZ4): We will later want to upstream safety checks to the `outer_fn`,
60
- // in order to cover some assumptions of enzyme/autodiff, which could lead to UB otherwise.
51
+ /// When differentiating `fn_to_diff`, take a `outer_fn` and generate another
52
+ /// function with expected naming and calling conventions[^1] which will be
53
+ /// discovered by the enzyme LLVM pass and its body populated with the differentiated
54
+ /// `fn_to_diff`. `outer_fn` is then modified to have a call to the generated
55
+ /// function and handle the differences between the Rust calling convention and
56
+ /// Enzyme.
57
+ /// [^1]: <https://enzyme.mit.edu/getting_started/CallingConvention/>
58
+ // FIXME(ZuseZ4): `outer_fn` should include upstream safety checks to
59
+ // cover some assumptions of enzyme/autodiff, which could lead to UB otherwise.
61
60
pub ( crate ) fn generate_enzyme_call < ' ll > (
62
61
llmod : & ' ll llvm:: Module ,
63
62
llcx : & ' ll llvm:: Context ,
@@ -69,7 +68,7 @@ pub(crate) fn generate_enzyme_call<'ll>(
69
68
let output = attrs. ret_activity ;
70
69
71
70
// We have to pick the name depending on whether we want forward or reverse mode autodiff.
72
- // FIXME(ZuseZ4): The new pass based approach should not need the * First method anymore, since
71
+ // FIXME(ZuseZ4): The new pass based approach should not need the {Forward/Reverse} First method anymore, since
73
72
// it will handle higher-order derivatives correctly automatically (in theory). Currently
74
73
// higher-order derivatives fail, so we should debug that before adjusting this code.
75
74
let mut ad_name: String = match attrs. mode {
@@ -87,16 +86,38 @@ pub(crate) fn generate_enzyme_call<'ll>(
87
86
let outer_fn_name = std:: ffi:: CStr :: from_bytes_with_nul ( name) . unwrap ( ) . to_str ( ) . unwrap ( ) ;
88
87
ad_name. push_str ( outer_fn_name. to_string ( ) . as_str ( ) ) ;
89
88
90
- // Assuming that our fn_to_diff is the fnc square, want to generate the following llvm-ir, which
91
- // would allow the enzyme pass to generate a function body for `__enzyme_autodiff_square`
89
+ // Let us assume the user wrote the following function square:
92
90
//
91
+ // ```llvm
92
+ // define double @square(double %x) {
93
+ // entry:
94
+ // %0 = fmul double %x, %x
95
+ // ret double %0
96
+ // }
97
+ // ```
98
+ //
99
+ // The user now applies autodiff to the function square, in which case fn_to_diff will be `square`.
100
+ // Our macro generates the following placeholder code (slightly simplified):
101
+ //
102
+ // ```llvm
103
+ // define double @dsquare(double %x) {
104
+ // ; placeholder code
105
+ // return 0.0;
106
+ // }
107
+ // ```
108
+ //
109
+ // so our `outer_fn` will be `dsquare`. The unsafe code section below now removes the placeholder
110
+ // code and inserts an autodiff call. We also add a declaration for the __enzyme_autodiff call.
111
+ // Again, the arguments to all functions are slightly simplified.
112
+ // ```llvm
93
113
// declare double @__enzyme_autodiff_square(...)
94
114
//
95
115
// define double @dsquare(double %x) {
96
116
// entry:
97
117
// %0 = tail call double (...) @__enzyme_autodiff_square(double (double)* nonnull @square, double %x)
98
118
// ret double %0
99
119
// }
120
+ // ```
100
121
unsafe {
101
122
// On LLVM-IR, we can luckily declare __enzyme_ functions without specifying the input
102
123
// arguments. We do however need to declare them with their correct return type.
0 commit comments