Skip to content

Commit 119beb4

Browse files
committed
addressing reviewer feedback
1 parent 4d9468e commit 119beb4

File tree

4 files changed

+50
-22
lines changed

4 files changed

+50
-22
lines changed

compiler/rustc_codegen_gcc/src/lib.rs

+9
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ use gccjit::{CType, Context, OptimizationLevel};
9393
#[cfg(feature = "master")]
9494
use gccjit::{TargetInfo, Version};
9595
use rustc_ast::expand::allocator::AllocatorKind;
96+
use rustc_ast::expand::autodiff_attrs::AutoDiffItem;
9697
use rustc_codegen_ssa::back::lto::{LtoModuleCodegen, SerializedModule, ThinModule};
9798
use rustc_codegen_ssa::back::write::{
9899
CodegenContext, FatLtoInput, ModuleConfig, TargetMachineFactoryFn,
@@ -439,6 +440,14 @@ impl WriteBackendMethods for GccCodegenBackend {
439440
) -> Result<ModuleCodegen<Self::Module>, FatalError> {
440441
back::write::link(cgcx, dcx, modules)
441442
}
443+
fn autodiff(
444+
_cgcx: &CodegenContext<Self>,
445+
_module: &ModuleCodegen<Self::Module>,
446+
_diff_fncs: Vec<AutoDiffItem>,
447+
_config: &ModuleConfig,
448+
) -> Result<(), FatalError> {
449+
unimplemented!()
450+
}
442451
}
443452

444453
/// This is the entrypoint for a hot plugged rustc_codegen_gccjit

compiler/rustc_codegen_llvm/messages.ftl

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
codegen_llvm_autodiff_without_lto = using the autodiff feature requires using fat-lto
2+
13
codegen_llvm_copy_bitcode = failed to copy bitcode to object file: {$err}
24
35
codegen_llvm_dynamic_linking_with_lto =
@@ -47,6 +49,8 @@ codegen_llvm_parse_bitcode_with_llvm_err = failed to parse bitcode for LTO modul
4749
codegen_llvm_parse_target_machine_config =
4850
failed to parse target machine config to target machine: {$error}
4951
52+
codegen_llvm_prepare_autodiff = failed to prepare autodiff: src: {$src}, target: {$target}, {$error}
53+
codegen_llvm_prepare_autodiff_with_llvm_err = failed to prepare autodiff: {$llvm_err}, src: {$src}, target: {$target}, {$error}
5054
codegen_llvm_prepare_thin_lto_context = failed to prepare thin LTO context
5155
codegen_llvm_prepare_thin_lto_context_with_llvm_err = failed to prepare thin LTO context: {$llvm_err}
5256
@@ -56,10 +60,6 @@ codegen_llvm_prepare_thin_lto_module_with_llvm_err = failed to prepare thin LTO
5660
codegen_llvm_run_passes = failed to run LLVM passes
5761
codegen_llvm_run_passes_with_llvm_err = failed to run LLVM passes: {$llvm_err}
5862
59-
codegen_llvm_prepare_autodiff = failed to prepare autodiff: src: {$src}, target: {$target}, {$error}
60-
codegen_llvm_prepare_autodiff_with_llvm_err = failed to prepare autodiff: {$llvm_err}, src: {$src}, target: {$target}, {$error}
61-
codegen_llvm_autodiff_without_lto = using the autodiff feature requires using fat-lto
62-
6363
codegen_llvm_sanitizer_memtag_requires_mte =
6464
`-Zsanitizer=memtag` requires `-Ctarget-feature=+mte`
6565

compiler/rustc_codegen_llvm/src/back/write.rs

+3-5
Original file line numberDiff line numberDiff line change
@@ -525,7 +525,7 @@ pub(crate) unsafe fn llvm_optimize(
525525
// source code. However, benchmarks show that optimizations increasing the code size
526526
// tend to reduce AD performance. Therefore deactivate them before AD, then differentiate the code
527527
// and finally re-optimize the module, now with all optimizations available.
528-
// TODO: In a future update we could figure out how to only optimize individual functions getting
528+
// FIXME(ZuseZ4): In a future update we could figure out how to only optimize individual functions getting
529529
// differentiated.
530530

531531
let unroll_loops;
@@ -683,8 +683,7 @@ pub(crate) fn differentiate(
683683
crate::builder::generate_enzyme_call(llmod, llcx, fn_def, fn_target, item.attrs.clone());
684684
}
685685

686-
// FIXME(ZuseZ4): In the following upstream PR, we want to add code to handle SanitizeHWAddress,
687-
// to prevent some illegal/unsupported optimizations.
686+
// FIXME(ZuseZ4): support SanitizeHWAddress and prevent illegal/unsupported opts
688687

689688
if let Some(opt_level) = config.opt_level {
690689
let opt_stage = match cgcx.lto {
@@ -736,8 +735,7 @@ pub(crate) unsafe fn optimize(
736735
unsafe { llvm::LLVMWriteBitcodeToFile(llmod, out.as_ptr()) };
737736
}
738737

739-
// FIXME(ZuseZ4): In the following PR, we have to add code to apply the sanitize_hwaddress
740-
// attribute to all functions in the module, to prevent some illegal/unsupported optimizations.
738+
// FIXME(ZuseZ4): support SanitizeHWAddress and prevent illegal/unsupported opts
741739

742740
if let Some(opt_level) = config.opt_level {
743741
let opt_stage = match cgcx.lto {

compiler/rustc_codegen_llvm/src/builder.rs

+34-13
Original file line numberDiff line numberDiff line change
@@ -48,16 +48,15 @@ fn get_params(fnc: &Value) -> Vec<&Value> {
4848
}
4949
}
5050

51-
// The lowering of one `#[autodiff]` macro happens in multiple steps.
52-
// First we generate a new dummy function, whose llvm-ir we now have as outer_fn.
53-
// We kept track of the original function to which the `#[autodiff]` macro was applied to, which we
54-
// now have as fn_to_diff. In our current implementation, we use the enzyme pass to carry out the
55-
// differentiation, following naming and calling conventions documented here: <https://enzyme.mit.edu/getting_started/CallingConvention/>
56-
//
57-
// Our `outer_fn` had some dummy code inserted at higher levels, so we first remove most of the
58-
// existing body. We then insert an `__enzyme_<autodiff/fwddiff>_<unique_id>` call, which the pass
59-
// will then pick up. FIXME(ZuseZ4): We will later want to upstream safety checks to the `outer_fn`,
60-
// in order to cover some assumptions of enzyme/autodiff, which could lead to UB otherwise.
51+
/// When differentiating `fn_to_diff`, take an `outer_fn` and generate another
52+
/// function with expected naming and calling conventions[^1] which will be
53+
/// discovered by the enzyme LLVM pass and its body populated with the differentiated
54+
/// `fn_to_diff`. `outer_fn` is then modified to have a call to the generated
55+
/// function and handle the differences between the Rust calling convention and
56+
/// Enzyme.
57+
/// [^1]: <https://enzyme.mit.edu/getting_started/CallingConvention/>
58+
// FIXME(ZuseZ4): `outer_fn` should include upstream safety checks to
59+
// cover some assumptions of enzyme/autodiff, which could lead to UB otherwise.
6160
pub(crate) fn generate_enzyme_call<'ll>(
6261
llmod: &'ll llvm::Module,
6362
llcx: &'ll llvm::Context,
@@ -69,7 +68,7 @@ pub(crate) fn generate_enzyme_call<'ll>(
6968
let output = attrs.ret_activity;
7069

7170
// We have to pick the name depending on whether we want forward or reverse mode autodiff.
72-
// FIXME(ZuseZ4): The new pass based approach should not need the *First method anymore, since
71+
// FIXME(ZuseZ4): The new pass based approach should not need the {Forward/Reverse}First method anymore, since
7372
// it will handle higher-order derivatives correctly automatically (in theory). Currently
7473
// higher-order derivatives fail, so we should debug that before adjusting this code.
7574
let mut ad_name: String = match attrs.mode {
@@ -87,16 +86,38 @@ pub(crate) fn generate_enzyme_call<'ll>(
8786
let outer_fn_name = std::ffi::CStr::from_bytes_with_nul(name).unwrap().to_str().unwrap();
8887
ad_name.push_str(outer_fn_name.to_string().as_str());
8988

90-
// Assuming that our fn_to_diff is the fnc square, want to generate the following llvm-ir, which
91-
// would allow the enzyme pass to generate a function body for `__enzyme_autodiff_square`
89+
// Let us assume the user wrote the following function square:
9290
//
91+
// ```llvm
92+
// define double @square(double %x) {
93+
// entry:
94+
// %0 = fmul double %x, %x
95+
// ret double %0
96+
// }
97+
// ```
98+
//
99+
// The user now applies autodiff to the function square, in which case fn_to_diff will be `square`.
100+
// Our macro generates the following placeholder code (slightly simplified):
101+
//
102+
// ```llvm
103+
// define double @dsquare(double %x) {
104+
// ; placeholder code
105+
// ret double 0.0
106+
// }
107+
// ```
108+
//
109+
// so our `outer_fn` will be `dsquare`. The unsafe code section below now removes the placeholder
110+
// code and inserts an autodiff call. We also add a declaration for the __enzyme_autodiff call.
111+
// Again, the arguments to all functions are slightly simplified.
112+
// ```llvm
93113
// declare double @__enzyme_autodiff_square(...)
94114
//
95115
// define double @dsquare(double %x) {
96116
// entry:
97117
// %0 = tail call double (...) @__enzyme_autodiff_square(double (double)* nonnull @square, double %x)
98118
// ret double %0
99119
// }
120+
// ```
100121
unsafe {
101122
// On LLVM-IR, we can luckily declare __enzyme_ functions without specifying the input
102123
// arguments. We do however need to declare them with their correct return type.

0 commit comments

Comments
 (0)